import io
import copy
import itertools
from datetime import datetime, timedelta, time, UTC
from db import get_session, get_transaction
import pandas as pd
import sqlalchemy as sql
from app.model.orm import (
Bioreplicate,
Community,
CommunityStrain,
Compartment,
Experiment,
ExperimentCompartment,
Measurement,
MeasurementContext,
Perturbation,
Project,
ProjectUser,
Study,
StudyMetabolite,
StudyStrain,
StudyUser,
Taxon,
)
from app.model.lib.util import group_by_unique_name, is_non_negative_float, find_duplicates
from app.model.lib.conversion import convert_time
from app.model.tasks.submissions import export_submission_data
[docs]
def persist_submission_to_database(submission_form):
    """
    Write the study described by ``submission_form`` to the database.

    All the work happens inside a single transaction: the project and study
    records are saved, the study's existing child records are cleared, and
    then techniques, compartments, communities, experiments, measurements and
    per-experiment averages are created from the submission.

    If the resulting study is published, an export task is queued for the
    submission.

    Returns a list of error messages. An empty list means the submission was
    persisted.
    """
    submission = submission_form.submission
    user_uuid = submission.userUniqueID

    # Nothing can be persisted without an uploaded data file:
    if submission.dataFile is None:
        return ["Data file has not been uploaded"]

    with get_transaction() as db_transaction:
        session = get_session(db_transaction)

        project = _save_project(session, submission_form, user_uuid)
        study = _save_study(session, submission_form, user_uuid)

        # Start from a clean slate before re-creating the study's contents:
        _clear_study(study)

        _save_study_techniques(session, submission_form, study)
        _save_compartments(session, submission_form, study)
        _save_communities(session, submission_form, study, user_uuid)
        _save_experiments(session, submission_form, study)
        session.flush()

        _save_measurements(session, study, submission_form)

        for experiment in study.experiments:
            _create_average_measurements(session, study, experiment)

        _finalize_submission(session, submission_form, study, project)

        session.commit()

        if study.isPublished:
            export_submission_data.delay(submission.id)

        return []
# TODO (2025-05-11) Add tests for this (extract a separate read_data_file
# step, so that the validation itself can operate on dataframes)
[docs]
def validate_data_file(submission_form, data_file=None):
    """
    Validate an uploaded spreadsheet against the submission's study design.

    The checks performed are:

    - every expected column is present in at least one sheet,
    - each sheet's "Biological Replicate" and "Compartment" values match the
      names in the study design,
    - values in the expected measurement columns pass
      ``is_non_negative_float`` (the intent being non-negative numbers or
      blanks),
    - time points are unique per (biological replicate, compartment) pair
      within a sheet.

    If ``data_file`` is not given, ``submission_form.submission.dataFile`` is
    used instead.

    Returns a list of error messages. The list is empty if the file is valid
    or if there is no file to validate.
    """
    submission = submission_form.submission
    data_file = data_file or submission.dataFile
    errors = []

    # TODO (2025-10-01) Show warnings that don't stop the upload
    # process. For now, warnings are collected, but not returned:
    warnings = []

    if not data_file:
        return []

    def describe(values):
        # Spreadsheet cells may hold numbers or NaN (blank cells), so values
        # are stringified before joining. Sorting keeps messages stable.
        return ', '.join(sorted({str(value) for value in values}))

    sheets = pd.read_excel(io.BytesIO(data_file.content), sheet_name=None)

    # Validate columns:
    community_columns, strain_columns, metabolite_columns = _get_expected_column_names(submission_form)
    expected_value_columns = {*community_columns, *strain_columns, *metabolite_columns}
    key_columns = ['Biological Replicate', 'Compartment', 'Time']
    expected_columns = {*key_columns, *expected_value_columns}

    # A column only needs to be present in one of the sheets:
    found_columns = set()
    for df in sheets.values():
        found_columns.update(df.columns)

    for missing_column in sorted(expected_columns.difference(found_columns), key=str):
        errors.append(f"Missing column: {missing_column}")

    # Validate row keys:
    expected_bioreplicates = {
        bioreplicate['name']
        for experiment in submission.studyDesign['experiments']
        for bioreplicate in experiment['bioreplicates']
    }
    expected_compartments = {c['name'] for c in submission.studyDesign['compartments']}

    for sheet_name, df in sheets.items():
        if 'Biological Replicate' in df:
            uploaded_bioreplicates = set(df['Biological Replicate'])

            if missing_bioreplicates := expected_bioreplicates.difference(uploaded_bioreplicates):
                errors.append(
                    f"{sheet_name}: Missing biological replicate(s): {describe(missing_bioreplicates)}"
                )

            if extra_bioreplicates := uploaded_bioreplicates.difference(expected_bioreplicates):
                errors.append(
                    f"{sheet_name}: Unexpected biological replicate(s): {describe(extra_bioreplicates)}"
                )

        if 'Compartment' in df:
            uploaded_compartments = set(df['Compartment'])

            if extra_compartments := uploaded_compartments.difference(expected_compartments):
                errors.append(
                    f"{sheet_name}: Unexpected compartment(s): {describe(extra_compartments)}"
                )

    # Validate values:
    for sheet_name, df in sheets.items():
        # Check for missing Time values:
        if 'Time' in df:
            missing_time_rows = [
                str(index + 1)
                for index, value in enumerate(df['Time'])
                if not is_non_negative_float(value, isnan_check=True)
            ]

            if missing_time_rows:
                row_description = _format_row_list_error(missing_time_rows)
                # TODO (2025-10-01) Show warnings in UI
                warnings.append(f"{sheet_name}: Missing or invalid time values on row(s) {row_description}")

        # For the other rows, we're looking for non-negative numbers or blanks
        value_columns = expected_value_columns.intersection(df.columns)

        for column in sorted(value_columns, key=str):
            invalid_rows = [
                str(index + 1)
                for index, value in enumerate(df[column])
                if not is_non_negative_float(value, isnan_check=False)
            ]

            if invalid_rows:
                row_description = _format_row_list_error(invalid_rows)
                errors.append(f"{sheet_name}: Invalid values in column \"{column}\" on row(s) {row_description}")

        # Check for unique timepoints. This needs all three key columns, and
        # a missing one has already been handled above, so the sheet is
        # skipped instead of failing with a KeyError:
        if not all(column in df for column in key_columns):
            continue

        grouped_timepoints = {}
        for _, row in df[key_columns].iterrows():
            timepoint = row['Time']

            # Blank cells are read as NaN. They are skipped explicitly,
            # because a truthiness check would keep NaN and drop a
            # legitimate time of 0:
            if pd.isna(timepoint):
                continue
            if isinstance(timepoint, str) and not timepoint.strip():
                continue

            key = (row['Biological Replicate'], row['Compartment'])
            grouped_timepoints.setdefault(key, []).append(timepoint)

        for (b, c), timepoints in grouped_timepoints.items():
            if duplicates := find_duplicates(timepoints):
                duplicates_description = ', '.join([str(d) for d in duplicates])
                errors.append(f"{sheet_name}: Time points are not unique for bioreplicate {b}, compartment {c}: {duplicates_description}")

    return errors
def _save_study(db_session, submission_form, user_uuid=None):
    """
    Create or update the ``Study`` record described by the submission.

    When ``submission_form.study_id`` is ``None``, a new study is created with
    a generated public id, owned by ``user_uuid``, and a ``StudyUser`` link is
    added for the submitting user. Otherwise the existing study is loaded and
    updated.

    The study becomes publishable 24 hours from now, or at the end of the
    embargo if that is later. Returns the study, which has been added to the
    session.
    """
    submission = submission_form.submission
    study_design = submission.studyDesign
    study_data = study_design['study']

    # An embargo, if given, runs until 23:59 UTC of the specified day:
    embargo_datetime = None
    if embargo_string := study_data.get('embargoExpiresAt', None):
        embargo_day = datetime.fromisoformat(embargo_string)
        embargo_datetime = datetime.combine(embargo_day, time(hour=23, minute=59, tzinfo=UTC))

    params = {
        'publicId': submission_form.study_id,
        'name': study_data['name'].strip(),
        'description': study_data.get('description', '').strip(),
        'url': study_data.get('url', '').strip(),
        'authors': study_data.get('authors', []),
        'authorCache': study_data.get('authorCache', ''),
        'uuid': submission.studyUniqueID,
        'projectUuid': submission.projectUniqueID,
        'timeUnits': study_design['timeUnits'],
        'embargoExpiresAt': embargo_datetime,
    }

    is_new_study = submission_form.study_id is None

    if is_new_study:
        params['ownerUuid'] = user_uuid

        study = Study(**Study.filter_keys(params))
        study.publicId = Study.generate_public_id(db_session)

        # Give the submitting user access to the new study:
        user_link = StudyUser(
            studyUniqueID=submission.studyUniqueID,
            userUniqueID=submission.userUniqueID,
        )
        db_session.add(user_link)
    else:
        study = db_session.get(Study, submission_form.study_id)
        study.update(**Study.filter_keys(params))

    # Publishing is delayed by at least a day, and by the embargo if it ends later:
    earliest_publish_time = datetime.now(UTC) + timedelta(hours=24)
    if embargo_datetime is None or embargo_datetime <= earliest_publish_time:
        study.publishableAt = earliest_publish_time
    else:
        study.publishableAt = embargo_datetime

    db_session.add(study)

    return study
def _save_project(db_session, submission_form, user_uuid=None):
    """
    Create or update the ``Project`` record described by the submission.

    When ``submission_form.project_id`` is ``None``, a new project is created
    with a generated public id, owned by ``user_uuid``, and a ``ProjectUser``
    link is added for the submitting user. Otherwise the existing project is
    loaded and updated.

    Returns the project, which has been added to the session.
    """
    submission = submission_form.submission
    project_data = submission.studyDesign['project']

    params = {
        'publicId': submission_form.project_id,
        'name': project_data['name'].strip(),
        'description': project_data.get('description', '').strip(),
        'uuid': submission.projectUniqueID,
    }

    is_new_project = submission_form.project_id is None

    if is_new_project:
        params['ownerUuid'] = user_uuid

        project = Project(**Project.filter_keys(params))
        project.publicId = Project.generate_public_id(db_session)

        # Give the submitting user access to the new project:
        user_link = ProjectUser(
            projectUniqueID=submission.projectUniqueID,
            userUniqueID=submission.userUniqueID,
        )
        db_session.add(user_link)
    else:
        project = db_session.get(Project, submission_form.project_id)
        project.update(**Project.filter_keys(params))

    db_session.add(project)

    return project
def _clear_study(study):
    """
    Empty out the collections attached to ``study`` and to its experiments.

    Each collection is replaced with a new empty list, so that the records
    can be re-created from the submission.
    """
    experiment_collections = (
        'experimentCompartments',
        'perturbations',
        'bioreplicates',
    )
    study_collections = (
        'measurements',
        'measurementContexts',
        'studyTechniques',
        'strains',
        'compartments',
        'communities',
        'studyMetabolites',
    )

    for experiment in study.experiments:
        for collection_name in experiment_collections:
            setattr(experiment, collection_name, [])

    for collection_name in study_collections:
        setattr(study, collection_name, [])
def _save_compartments(db_session, submission_form, study):
    """
    Build a ``Compartment`` for each compartment entry in the study design.

    The new records are assigned to ``study.compartments`` and added to the
    session. Returns the list of compartments.
    """
    compartment_entries = submission_form.submission.studyDesign['compartments']

    compartments = [
        Compartment(**Compartment.filter_keys(entry))
        for entry in compartment_entries
    ]

    study.compartments = compartments
    db_session.add_all(compartments)

    return compartments
def _save_communities(db_session, submission_form, study, user_uuid):
    """
    Create the study's communities, along with the strains they contain.

    A strain is built once per identifier and reused across communities.
    Identifiers for which no strain can be built are left out of the
    community. If any of the study's techniques includes an "unknown" column,
    an additional "unknown" strain is created.

    The communities are assigned to ``study.communities`` and returned.
    """
    submission = submission_form.submission

    strains_by_identifier = {}
    saved_communities = []

    for community_entry in submission.studyDesign['communities']:
        # Work on a copy, so that the stored study design is not modified:
        community_params = copy.deepcopy(community_entry)
        strain_identifiers = community_params.pop('strainIdentifiers')

        community = Community(**Community.filter_keys(community_params))
        db_session.add(community)
        db_session.flush()

        for identifier in strain_identifiers:
            if identifier not in strains_by_identifier:
                new_strain = _build_strain(db_session, identifier, submission, study, user_uuid)

                if new_strain:
                    strains_by_identifier[identifier] = new_strain
                    db_session.add(new_strain)
                    db_session.flush()

            # No strain could be built for this identifier, skip it:
            if identifier not in strains_by_identifier:
                continue

            link = CommunityStrain(
                community=community,
                strain=strains_by_identifier[identifier],
            )
            db_session.add(link)

        saved_communities.append(community)

    study.communities = saved_communities

    # If any of the techniques have an "unknown" column, create an appropriate
    # strain:
    needs_unknown_strain = any(st.includeUnknown for st in study.studyTechniques)
    if needs_unknown_strain:
        unknown_strain = _build_strain(db_session, "unknown", submission, study, user_uuid)
        db_session.add(unknown_strain)
        db_session.flush()

    return saved_communities
def _save_experiments(db_session, submission_form, study):
    """
    Create or update the study's experiments from the study design.

    For every experiment entry, this also creates its compartment links,
    bioreplicates and perturbations. An entry with a ``publicId`` updates the
    existing experiment, which must belong to ``study``. Otherwise, a new
    experiment is created and its generated public id is written back into
    the study design entry.

    The experiments are assigned to ``study.experiments`` and returned.

    Raises ``ValueError`` if an existing experiment belongs to another study.
    """
    submission = submission_form.submission
    time_units = submission.studyDesign['timeUnits']

    communities_by_name = group_by_unique_name(study.communities)
    compartments_by_name = group_by_unique_name(study.compartments)

    # Perturbation fields that refer to another record by name:
    # (key in the study design, id attribute to set, lookup table)
    named_references = (
        ('removedCompartmentName', 'removedCompartmentId', compartments_by_name),
        ('addedCompartmentName', 'addedCompartmentId', compartments_by_name),
        ('oldCommunityName', 'oldCommunityId', communities_by_name),
        ('newCommunityName', 'newCommunityId', communities_by_name),
    )

    saved_experiments = []

    for experiment_data in submission.studyDesign['experiments']:
        # Nested records are removed from a copy, the rest are plain attributes:
        params = copy.deepcopy(experiment_data)

        community_name = params.pop('communityName')
        compartment_names = params.pop('compartmentNames')
        bioreplicate_entries = params.pop('bioreplicates')
        perturbation_entries = params.pop('perturbations')
        public_id = params.pop('publicId', None)

        if public_id:
            experiment = db_session.get(Experiment, public_id)

            if experiment.studyId != study.publicId:
                raise ValueError(f"Experiment with ID {public_id} does not belong to study {study.publicId}")
        else:
            experiment = Experiment(publicId=Experiment.generate_public_id(db_session))
            # Remember the generated id in the study design:
            experiment_data['publicId'] = experiment.publicId

        experiment.update(
            **Experiment.filter_keys(params),
            community=communities_by_name[community_name],
        )
        db_session.add(experiment)

        for compartment_name in compartment_names:
            db_session.add(ExperimentCompartment(
                experiment=experiment,
                compartment=compartments_by_name[compartment_name],
            ))

        for bioreplicate_entry in bioreplicate_entries:
            db_session.add(Bioreplicate(
                **Bioreplicate.filter_keys(bioreplicate_entry),
                experiment=experiment,
            ))

        for perturbation_entry in perturbation_entries:
            perturbation_params = copy.deepcopy(perturbation_entry)

            # Times are given in the study's time units, but stored in seconds:
            start_time = perturbation_params.pop('startTime', '0')
            start_time_in_seconds = convert_time(int(start_time), time_units, 's')

            end_time = perturbation_params.pop('endTime', None)
            end_time_in_seconds = (
                convert_time(int(end_time), time_units, 's') if end_time else None
            )

            perturbation = Perturbation(
                experiment=experiment,
                startTimeInSeconds=start_time_in_seconds,
                endTimeInSeconds=end_time_in_seconds,
                description=perturbation_params.pop('description'),
            )

            for name_key, id_attribute, records_by_name in named_references:
                name = perturbation_params.pop(name_key, '')
                if name != '':
                    setattr(perturbation, id_attribute, records_by_name[name].id)

            db_session.add(perturbation)

        saved_experiments.append(experiment)

    study.experiments = saved_experiments
    db_session.add_all(saved_experiments)

    return saved_experiments
def _save_study_techniques(db_session, submission_form, study):
    """
    Attach the submission's techniques to ``study`` and save them.

    A ``StudyMetabolite`` is added for every metabolite id listed by the
    techniques' measurement techniques. Returns the list of study techniques.
    """
    saved_techniques = []

    for study_technique in submission_form.submission.build_techniques():
        study_technique.study = study

        for measurement_technique in study_technique.measurementTechniques:
            # Techniques without metabolite ids contribute nothing here:
            for chebi_id in measurement_technique.metaboliteIds or ():
                study_metabolite = StudyMetabolite(chebiId=chebi_id, study=study)
                db_session.add(study_metabolite)

        saved_techniques.append(study_technique)

    db_session.add_all(saved_techniques)

    return saved_techniques
def _save_measurements(db_session, study, submission_form):
    """
    Insert the measurements from every sheet of the uploaded data file.

    Each sheet is converted to CSV text and handed to
    ``Measurement.insert_from_csv_string``.
    """
    workbook = io.BytesIO(submission_form.submission.dataFile.content)

    # With sheet_name=None, all the sheets are read, keyed by their name:
    sheets = pd.read_excel(workbook, sheet_name=None)

    for sheet in sheets.values():
        csv_text = sheet.to_csv(index=False)
        Measurement.insert_from_csv_string(db_session, study, csv_text)
def _create_average_measurements(db_session, study, experiment):
    """
    Create averaged measurements across the bioreplicates of ``experiment``.

    A new bioreplicate with ``calculationType='average'`` is created to parent
    the averages. Then, for every combination of the study's measurement
    techniques and the experiment's compartments, the measurement contexts
    with non-null values are collected and averaged:

    - for techniques with a 'bioreplicate' subject type, into one context,
    - otherwise, into one context per subject.

    A combination is skipped if it has at most one context, or if its contexts
    were not all measured at the same time points. If every combination is
    skipped, the average bioreplicate is deleted again.
    """
    # Bioreplicates with a calculationType are left out of the average
    # (presumably those are calculated ones themselves -- TODO confirm):
    bioreplicate_ids = [b.id for b in experiment.bioreplicates if not b.calculationType]

    # The averaged measurements will be parented by a custom-generated bioreplicate:
    average_bioreplicate = Bioreplicate(
        name=f"Average({experiment.name})",
        calculationType='average',
        experiment=experiment,
    )
    db_session.add(average_bioreplicate)

    # Whether at least one technique/compartment combination got averaged:
    has_measurements = False

    for technique in study.measurementTechniques:
        for compartment in experiment.compartments:
            # We'll average values separately over techniques and compartments:
            measurement_contexts = db_session.scalars(
                sql.select(MeasurementContext)
                .distinct()
                .join(Measurement)
                .where(
                    MeasurementContext.compartmentId == compartment.id,
                    MeasurementContext.bioreplicateId.in_(bioreplicate_ids),
                    MeasurementContext.techniqueId == technique.id,
                    Measurement.value.is_not(None),
                )
                .order_by(MeasurementContext.subjectType, MeasurementContext.subjectId)
            ).all()

            # If there is a single context for this cluster of measurements, there is nothing to average:
            if len(measurement_contexts) <= 1:
                continue

            # If measurement time points don't match, don't average them:
            time_point_sets = set()
            for measurement_context in measurement_contexts:
                time_points = [m.timeInSeconds for m in measurement_context.measurements]
                time_point_sets.add(frozenset(time_points))
            if len(time_point_sets) > 1:
                continue

            if technique.subjectType == 'bioreplicate':
                # A single context for a group of bioreplicates
                #
                # NOTE(review): average_bioreplicate.id is only available once
                # the record has been flushed. Presumably, the query above
                # autoflushes the session -- confirm.
                _create_average_measurement_context(
                    db_session,
                    parent_records=(study, technique, compartment),
                    measurement_contexts=measurement_contexts,
                    average_bioreplicate=average_bioreplicate,
                    subject_id=average_bioreplicate.id,
                    subject_type='bioreplicate',
                    subject_name=average_bioreplicate.name,
                    subject_external_id=None,
                )
            else:
                # groupby only merges adjacent items, so the grouping relies on
                # the query's ordering by subject type and id:
                grouped_contexts = itertools.groupby(
                    measurement_contexts,
                    lambda mc: (mc.subjectId, mc.subjectType, mc.subjectName, mc.subjectExternalId),
                )

                for key, subject_contexts in grouped_contexts:
                    (
                        subject_id,
                        subject_type,
                        subject_name,
                        subject_external_id,
                    ) = key

                    # One context for each subject:
                    _create_average_measurement_context(
                        db_session,
                        parent_records=(study, technique, compartment),
                        measurement_contexts=list(subject_contexts),
                        average_bioreplicate=average_bioreplicate,
                        subject_id=subject_id,
                        subject_type=subject_type,
                        subject_name=subject_name,
                        subject_external_id=subject_external_id,
                    )

            has_measurements = True

    # Don't keep an average bioreplicate around if nothing was averaged:
    if not has_measurements:
        db_session.delete(average_bioreplicate)
def _create_average_measurement_context(
    db_session,
    parent_records,
    measurement_contexts,
    average_bioreplicate,
    subject_id,
    subject_type,
    subject_name,
    subject_external_id=None,
):
    """
    Create one averaged context, with its measurements, from several contexts.

    The values of ``measurement_contexts`` are grouped by time point, and the
    average and standard deviation of each group are stored as a measurement
    in a new context with ``calculationType='average'``. That context belongs
    to ``average_bioreplicate`` and describes the given subject.

    ``parent_records`` is a ``(study, technique, compartment)`` tuple. Nothing
    is created if the given contexts have no measurements.
    """
    study, technique, compartment = parent_records
    context_ids = [mc.id for mc in measurement_contexts]

    # Collect average measurement values for the given contexts:
    averages_query = (
        sql.select(
            Measurement.timeInSeconds,
            sql.func.avg(Measurement.value),
            sql.func.std(Measurement.value),
        )
        .where(Measurement.contextId.in_(context_ids))
        .group_by(Measurement.timeInSeconds)
        .order_by(Measurement.timeInSeconds)
    )
    averaged_rows = db_session.execute(averages_query).all()

    # We do not want to create unnecessary contexts
    if not averaged_rows:
        return

    # Create a parent context for the individual measurements:
    average_context = MeasurementContext(
        study=study,
        bioreplicate=average_bioreplicate,
        compartment=compartment,
        subjectId=subject_id,
        subjectType=subject_type,
        subjectName=subject_name,
        subjectExternalId=subject_external_id,
        technique=technique,
        calculationType='average',
    )
    db_session.add(average_context)

    # Create individual measurements
    for time_in_seconds, average_value, standard_deviation in averaged_rows:
        db_session.add(Measurement(
            timeInSeconds=time_in_seconds,
            value=average_value,
            std=standard_deviation,
            context=average_context,
            study=study,
        ))
def _finalize_submission(db_session, submission_form, study, project):
    """
    Link the study to its submission, then save and back up the submission.

    For a published study, the publication time is copied over to the
    submission as well.
    """
    submission = submission_form.submission

    study.lastSubmissionId = submission.id

    if study.isPublished:
        submission.publishedAt = study.publishedAt

    submission_form.save()
    submission_form.save_backup(
        study_id=study.publicId,
        project_id=project.publicId,
    )
def _find_custom_strain(submission, identifier):
    """
    Look up a custom strain entry in the study design by its name.

    Returns the first entry whose 'name' equals ``identifier``, or ``None`` if
    there isn't one.
    """
    matching_entries = (
        entry
        for entry in submission.studyDesign['custom_strains']
        if entry['name'] == identifier
    )

    return next(matching_entries, None)
def _get_expected_column_names(submission_form):
submission = submission_form.submission
community_columns = set()
strain_columns = set()
metabolite_columns = set()
# Validate column presence:
for index, study_technique in enumerate(submission.build_techniques()):
for measurement_technique in study_technique.measurementTechniques:
if measurement_technique.subjectType == 'bioreplicate':
column = measurement_technique.csv_column_name()
community_columns.add(column)
if study_technique.includeStd:
community_columns.add(f"{column} STD")
elif measurement_technique.subjectType == 'strain':
for taxon in submission_form.fetch_taxa():
column = measurement_technique.csv_column_name(taxon.name)
strain_columns.add(column)
if study_technique.includeStd:
strain_columns.add(f"{column} STD")
if study_technique.includeUnknown:
column = measurement_technique.csv_column_name("Unknown")
strain_columns.add(column)
if study_technique.includeStd:
strain_columns.add(f"{column} STD")
for strain in submission.studyDesign['custom_strains']:
column = measurement_technique.csv_column_name(strain['name'])
strain_columns.add(column)
if study_technique.includeStd:
strain_columns.add(f"{column} STD")
elif measurement_technique.subjectType == 'metabolite':
for metabolite in submission_form.fetch_metabolites_for_technique(index):
column = measurement_technique.csv_column_name(metabolite.name)
metabolite_columns.add(column)
if study_technique.includeStd:
metabolite_columns.add(f"{column} STD")
else:
raise ValueError(f"Unexpected measurement_technique subjectType: {measurement_technique.subjectType}")
return community_columns, strain_columns, metabolite_columns
def _build_strain(db_session, identifier, submission, study, user_uuid):
    """
    Build a ``StudyStrain`` for the given strain identifier.

    The identifier takes one of three forms:

    - ``existing|<ncbi id>``: a strain based on a ``Taxon`` record,
    - ``custom|<name>``: a strain based on a custom strain entry from the
      study design,
    - ``unknown``: a placeholder strain for unknown measurements.

    Returns the new strain (not yet added to the session), or ``None`` if a
    custom strain with that name is not found. Raises ``ValueError`` for any
    other identifier.
    """
    shared_params = {'study': study, 'userUniqueID': user_uuid}

    if identifier.startswith('existing|'):
        ncbi_id = identifier.removeprefix('existing|')
        taxon_query = sql.select(Taxon).where(Taxon.ncbiId == ncbi_id).limit(1)
        taxon = db_session.scalars(taxon_query).one()

        specific_params = {
            'name': taxon.name,
            'ncbiId': taxon.ncbiId,
            'defined': True,
        }

    elif identifier.startswith('custom|'):
        custom_name = identifier.removeprefix('custom|')
        custom_strain_data = _find_custom_strain(submission, custom_name)

        # Missing strain due to renames
        if custom_strain_data is None:
            return None

        specific_params = {
            'name': custom_strain_data['name'],
            'ncbiId': custom_strain_data['species'],
            'description': custom_strain_data['description'],
            'defined': False,
        }

    elif identifier == 'unknown':
        specific_params = {
            'name': "Unknown",
            'ncbiId': 0,
            'description': "Unknown measurements",
            'defined': False,
        }

    else:
        raise ValueError(f"Strain identifier {identifier!r} has an unexpected prefix")

    return StudyStrain(**(specific_params | shared_params))
def _format_row_list_error(row_list):
    """
    Describe a list of row numbers, given as strings, for an error message.

    Only the first three rows are listed, the rest are summarized as
    "and N more".
    """
    shown_limit = 3

    parts = list(row_list[:shown_limit])
    hidden_count = len(row_list) - shown_limit

    if hidden_count > 0:
        parts.append(f"and {hidden_count} more")

    return ', '.join(parts)