import io
import copy
import itertools
from datetime import datetime, timedelta, UTC
from db import get_session, get_transaction
import pandas as pd
import sqlalchemy as sql
from app.model.orm import (
Bioreplicate,
Community,
Compartment,
Experiment,
ExperimentCompartment,
Measurement,
MeasurementContext,
Perturbation,
Project,
ProjectUser,
Strain,
Study,
StudyMetabolite,
StudyUser,
Taxon,
)
from app.model.lib.util import group_by_unique_name, is_non_negative_float
[docs]
def persist_submission_to_database(submission_form):
    """Write a submission's project, study design and measurements to the database.

    Everything is written inside one transaction. The study's existing
    collections are emptied first, then rebuilt from the submission:
    compartments, communities (with strains), experiments, measurement
    techniques and the uploaded measurements. Finally, per-experiment averages
    are computed, the submission form is saved and the transaction is committed.

    Args:
        submission_form: form wrapping the submission to persist.

    Returns:
        A list of error messages. It contains one message when no data file
        has been uploaded (nothing is written in that case); otherwise it is
        empty.
    """
    submission = submission_form.submission
    user_uuid = submission.userUniqueID

    # Nothing can be persisted without the measurement spreadsheet.
    if submission.dataFile is None:
        return ["Data file has not been uploaded"]

    with get_transaction() as db_transaction:
        session = get_session(db_transaction)

        _save_project(session, submission_form)
        study = _save_study(session, submission_form)

        # Reset every collection owned by the study, so that an update
        # replaces the previous study design instead of merging into it.
        for relationship_name in (
            'measurements',
            'measurementContexts',
            'measurementTechniques',
            'strains',
            'experimentCompartments',
            'compartments',
            'communities',
            'experiments',
            'bioreplicates',
            'perturbations',
            'studyMetabolites',
        ):
            setattr(study, relationship_name, [])

        _save_compartments(session, submission_form, study)
        _save_communities(session, submission_form, study, user_uuid)
        _save_experiments(session, submission_form, study)
        _save_measurement_techniques(session, submission_form, study)

        # Measurements reference the records above by id, so assign ids first.
        session.flush()
        _save_measurements(session, study, submission)

        for experiment in study.experiments:
            _create_average_measurements(session, study, experiment)

        submission_form.save()
        session.commit()

    return []
# TODO (2025-05-11): Add tests. To make this testable, separate reading the data
# file (a read_data_file helper) from validation, so the checks operate on DataFrames.
[docs]
def validate_data_file(submission_form, data_file=None):
    """Check an uploaded Excel data file against the submission's study design.

    Three groups of checks are run. Their messages are returned in this order:

    1. Every expected data sheet exists and contains its expected value
       columns, plus the key columns ``Biological Replicate``,
       ``Compartment`` and ``Time``.
    2. The ``Biological Replicate`` and ``Compartment`` columns mention
       exactly the bioreplicates and compartments declared in the study design.
    3. ``Time`` holds a valid non-negative number on every row, and the value
       columns hold non-negative numbers or blanks.

    Args:
        submission_form: form wrapping the submission being validated.
        data_file: optional file object to validate instead of
            ``submission.dataFile``. Its ``content`` is read as the raw
            Excel bytes.

    Returns:
        A list of human-readable error messages. It is empty when the file is
        valid, or when there is no data file to validate.
    """
    submission = submission_form.submission
    data_file = data_file or submission.dataFile
    if not data_file:
        return []

    # sheet_name=None loads every sheet, keyed by sheet name.
    sheets = pd.read_excel(io.BytesIO(data_file.content), sheet_name=None)
    expected_value_columns = _get_expected_column_names(submission_form)

    errors = []
    errors.extend(_validate_sheet_columns(sheets, expected_value_columns))
    errors.extend(_validate_sheet_row_keys(sheets, expected_value_columns, submission.studyDesign))
    errors.extend(_validate_sheet_values(sheets, expected_value_columns))
    return errors


def _join_as_sorted_strings(values):
    """Render values as a sorted, de-duplicated, comma-separated string.

    Values are stringified first. Excel cells may be read as numbers or NaN,
    and ``str.join`` raises TypeError on non-string items.
    """
    return ', '.join(sorted({str(value) for value in values}))


def _validate_sheet_columns(sheets, expected_value_columns):
    """Report missing data sheets, and missing columns in the sheets that exist."""
    errors = []
    for sheet_name, value_columns in expected_value_columns.items():
        if sheet_name not in sheets:
            errors.append(f"Missing data sheet: {sheet_name}")
            continue
        required_columns = {*value_columns, 'Biological Replicate', 'Compartment', 'Time'}
        # Sorted so the messages come out in a stable order.
        for missing_column in sorted(required_columns.difference(sheets[sheet_name].columns)):
            errors.append(f"{sheet_name}: Missing column {missing_column}")
    return errors


def _validate_sheet_row_keys(sheets, expected_value_columns, study_design):
    """Compare the bioreplicate and compartment keys of each sheet with the study design."""
    # Key column -> (label used in messages, names declared in the study design)
    expected_keys = {
        'Biological Replicate': (
            'biological replicate(s)',
            {
                bioreplicate['name']
                for experiment in study_design['experiments']
                for bioreplicate in experiment['bioreplicates']
            },
        ),
        'Compartment': (
            'compartment(s)',
            {compartment['name'] for compartment in study_design['compartments']},
        ),
    }

    errors = []
    for sheet_name in expected_value_columns:
        if sheet_name not in sheets:
            continue
        df = sheets[sheet_name]
        for column, (label, expected_names) in expected_keys.items():
            # A missing key column was already reported by the column check.
            if column not in df:
                continue
            uploaded_names = set(df[column])
            if missing_names := expected_names.difference(uploaded_names):
                errors.append(f"{sheet_name}: Missing {label}: {_join_as_sorted_strings(missing_names)}")
            if extra_names := uploaded_names.difference(expected_names):
                errors.append(f"{sheet_name}: Unexpected {label}: {_join_as_sorted_strings(extra_names)}")
    return errors


def _validate_sheet_values(sheets, expected_value_columns):
    """Check Time and value cells. Row numbers are 1-based data rows (header excluded)."""
    errors = []
    for sheet_name, df in sheets.items():
        if sheet_name not in expected_value_columns:
            continue

        # Time must be present and valid on every row.
        if 'Time' in df:
            invalid_time_rows = [
                str(row_number)
                for row_number, time in enumerate(df['Time'], start=1)
                if not is_non_negative_float(time, isnan_check=True)
            ]
            if invalid_time_rows:
                row_description = _format_row_list_error(invalid_time_rows)
                errors.append(f"{sheet_name}: Missing or invalid time values on row(s) {row_description}")

        # Value columns accept non-negative numbers or blanks.
        value_columns = expected_value_columns[sheet_name].intersection(df.columns)
        for column in sorted(value_columns):
            invalid_rows = [
                str(row_number)
                for row_number, value in enumerate(df[column], start=1)
                if not is_non_negative_float(value, isnan_check=False)
            ]
            if invalid_rows:
                row_description = _format_row_list_error(invalid_rows)
                errors.append(f"{sheet_name}: Invalid values in column \"{column}\" on row(s) {row_description}")
    return errors
[docs]
def _save_study(db_session, submission_form):
    """Create or update the Study described by the submission.

    For the submission type ``'update_study'``, the existing Study is loaded
    by ``studyUniqueID`` and updated in place.

    For every other type, a new Study is built and:
    - it gets a freshly generated public ``studyId``;
    - ``publishableAt`` is set 24 hours from now (UTC);
    - a ``StudyUser`` row links it to the submitting user.

    Returns:
        The Study, which has been added to ``db_session``.
    """
    submission = submission_form.submission
    design = submission.studyDesign
    study_section = design['study']
    study_fields = Study.filter_keys({
        'studyId': submission_form.study_id,
        'studyName': study_section['name'],
        'studyDescription': study_section.get('description', ''),
        'studyURL': study_section.get('url', ''),
        'studyUniqueID': submission.studyUniqueID,
        'projectUniqueID': submission.projectUniqueID,
        'timeUnits': design['timeUnits'],
    })

    if submission_form.type == 'update_study':
        # NOTE(review): Session.get returns None for an unknown id, which would
        # surface here as an AttributeError.
        study = db_session.get(Study, submission.studyUniqueID)
        study.update(**study_fields)
    else:
        study = Study(**study_fields)
        study.studyId = Study.generate_public_id(db_session)
        # New studies become publishable after a 24-hour grace period.
        study.publishableAt = datetime.now(UTC) + timedelta(hours=24)
        db_session.add(StudyUser(
            studyUniqueID=submission.studyUniqueID,
            userUniqueID=submission.userUniqueID,
        ))

    db_session.add(study)
    return study
[docs]
def _save_project(db_session, submission_form):
    """Create or update the Project the submission belongs to.

    For the submission type ``'new_project'``, a new Project is built with a
    freshly generated public ``projectId``, and a ``ProjectUser`` row links it
    to the submitting user.

    For every other type, the existing Project is loaded by
    ``projectUniqueID`` and updated in place.

    Returns:
        The Project, which has been added to ``db_session``.
    """
    submission = submission_form.submission
    project_section = submission.studyDesign['project']
    project_fields = Project.filter_keys({
        'projectId': submission_form.project_id,
        'projectName': project_section['name'],
        'projectDescription': project_section.get('description', ''),
        'projectUniqueID': submission.projectUniqueID,
    })

    if submission_form.type != 'new_project':
        # NOTE(review): Session.get returns None for an unknown id, which would
        # surface here as an AttributeError.
        project = db_session.get(Project, submission.projectUniqueID)
        project.update(**project_fields)
    else:
        project = Project(**project_fields)
        project.projectId = Project.generate_public_id(db_session)
        db_session.add(ProjectUser(
            projectUniqueID=submission.projectUniqueID,
            userUniqueID=submission.userUniqueID,
        ))

    db_session.add(project)
    return project
[docs]
def _save_compartments(db_session, submission_form, study):
    """Build one Compartment per entry of ``studyDesign['compartments']``.

    The compartments replace ``study.compartments`` and are added to the session.

    Returns:
        The list of new Compartment records, in study-design order.
    """
    compartment_specs = submission_form.submission.studyDesign['compartments']
    new_compartments = [
        Compartment(**Compartment.filter_keys(spec))
        for spec in compartment_specs
    ]

    study.compartments = new_compartments
    db_session.add_all(new_compartments)
    return new_compartments
[docs]
def _save_communities(db_session, submission_form, study, user_uuid):
    """Build the study's communities and the strains they are made of.

    Each entry of ``studyDesign['communities']`` is copied, and its
    ``strainIdentifiers`` are resolved to Strain records built by
    ``_build_strain``. A distinct identifier yields exactly one Strain, shared
    by every community that lists it. Each new strain is flushed immediately
    so that its database ``id`` can be stored in ``community.strainIds``.

    Returns:
        The list of new Community records (also assigned to ``study.communities``).
    """
    submission = submission_form.submission
    strains_by_identifier = {}

    def resolve_strain(identifier):
        # Build (and flush) each distinct strain only once per submission.
        strain = strains_by_identifier.get(identifier)
        if strain is None:
            strain = _build_strain(db_session, identifier, submission, study, user_uuid)
            db_session.add(strain)
            db_session.flush()
            strains_by_identifier[identifier] = strain
        return strain

    new_communities = []
    for raw_community in submission.studyDesign['communities']:
        # Copy so that popping keys does not mutate the submission's design.
        fields = copy.deepcopy(raw_community)
        identifiers = fields.pop('strainIdentifiers')

        community = Community(**Community.filter_keys(fields))
        community.strainIds = [resolve_strain(identifier).id for identifier in identifiers]
        new_communities.append(community)

    study.communities = new_communities
    db_session.add_all(new_communities)
    return new_communities
[docs]
def _save_experiments(db_session, submission_form, study):
    """Create the study's experiments together with their dependent records.

    For every entry of ``studyDesign['experiments']`` this creates:
    - the Experiment, linked to its community by name;
    - one ExperimentCompartment per listed compartment;
    - its Bioreplicate rows;
    - its Perturbation rows.

    The design dicts are deep-copied, so the submission itself is not mutated.
    ``study.compartments`` and ``study.communities`` are expected to be
    populated already (by ``_save_compartments`` / ``_save_communities``).

    Returns:
        The list of new Experiment records (also assigned to ``study.experiments``).
    """
    submission = submission_form.submission

    # Perturbations reference compartments and communities by primary key.
    # Those records were only just added to the session, so flush explicitly
    # to make sure their ids are assigned. Otherwise we would depend on some
    # later query happening to trigger an autoflush.
    db_session.flush()

    communities_by_name = group_by_unique_name(study.communities)
    compartments_by_name = group_by_unique_name(study.compartments)

    # (Perturbation attribute, study-design key, records looked up by name)
    perturbation_links = (
        ('removedCompartmentId', 'removedCompartmentName', compartments_by_name),
        ('addedCompartmentId', 'addedCompartmentName', compartments_by_name),
        ('oldCommunityId', 'oldCommunityName', communities_by_name),
        ('newCommunityId', 'newCommunityName', communities_by_name),
    )

    experiments = []
    for experiment_data in submission.studyDesign['experiments']:
        # Work on a copy: the keys popped below must stay in the submission.
        experiment_data = copy.deepcopy(experiment_data)
        community_name = experiment_data.pop('communityName')
        compartment_names = experiment_data.pop('compartmentNames')
        bioreplicates = experiment_data.pop('bioreplicates')
        perturbations = experiment_data.pop('perturbations')

        experiment = Experiment(
            **Experiment.filter_keys(experiment_data),
            community=communities_by_name[community_name],
            publicId=Experiment.generate_public_id(db_session),
        )
        db_session.add(experiment)

        for compartment_name in compartment_names:
            db_session.add(ExperimentCompartment(
                study=study,
                experiment=experiment,
                compartment=compartments_by_name[compartment_name],
            ))

        for bioreplicate_data in bioreplicates:
            db_session.add(Bioreplicate(
                **Bioreplicate.filter_keys(bioreplicate_data),
                study=study,
                experiment=experiment,
            ))

        for perturbation_data in perturbations:
            perturbation = Perturbation(
                study=study,
                experiment=experiment,
                startTimepoint=perturbation_data['startTimepoint'],
                description=perturbation_data['description'],
            )
            for attribute, key, records_by_name in perturbation_links:
                # Optional references may be absent, empty or null.
                name = perturbation_data.get(key)
                if name:
                    setattr(perturbation, attribute, records_by_name[name].id)
            db_session.add(perturbation)

        experiments.append(experiment)

    study.experiments = experiments
    db_session.add_all(experiments)
    return experiments
[docs]
def _save_measurement_techniques(db_session, submission_form, study):
    """Attach the submission's measurement techniques to the study.

    Each technique returned by ``submission.build_techniques()`` is linked to
    ``study`` and added to the session. For every distinct metabolite ChEBI id
    measured by any technique, one StudyMetabolite link is added.

    Returns:
        The list of techniques, in the order they were built.
    """
    submission = submission_form.submission
    techniques = []
    # Several techniques may measure the same metabolite. Link each one to
    # the study only once, instead of adding duplicate StudyMetabolite rows.
    linked_chebi_ids = set()
    for technique in submission.build_techniques():
        technique.study = study
        for chebi_id in technique.metaboliteIds or ():
            if chebi_id in linked_chebi_ids:
                continue
            linked_chebi_ids.add(chebi_id)
            db_session.add(StudyMetabolite(
                chebi_id=chebi_id,
                study=study,
            ))
        techniques.append(technique)
    db_session.add_all(techniques)
    return techniques
[docs]
def _save_measurements(db_session, study, submission):
    """Import measurement rows from the uploaded spreadsheet.

    Each recognised data sheet is converted to CSV (without the DataFrame
    index) and passed to ``Measurement.insert_from_csv_string``, together
    with the subject type that sheet describes. Missing sheets are skipped.
    """
    workbook = pd.read_excel(io.BytesIO(submission.dataFile.content), sheet_name=None)

    subject_type_by_sheet = {
        'Growth data per community': 'bioreplicate',
        'Growth data per strain': 'strain',
        'Growth data per metabolite': 'metabolite',
    }
    for sheet_name, subject_type in subject_type_by_sheet.items():
        if sheet_name not in workbook:
            continue
        csv_string = workbook[sheet_name].to_csv(index=False)
        Measurement.insert_from_csv_string(db_session, study, csv_string, subject_type=subject_type)
[docs]
def _create_average_measurements(db_session, study, experiment):
    """Store averages of an experiment's measurements across its bioreplicates.

    A synthetic Bioreplicate named ``Average(<experiment name>)`` with
    ``calculationType='average'`` is created to own the averaged data.

    Then, for each (measurement technique, compartment) pair, the measurement
    contexts of the experiment's bioreplicates that have at least one non-null
    value are collected. When there is more than one, averaged contexts are
    created via ``_create_average_measurement_context``:

    - for bioreplicate-level techniques, a single context whose subject is
      the average bioreplicate;
    - otherwise, one context per (subjectId, subjectType) subject.
    """
    # Collected before the average bioreplicate below is created.
    bioreplicate_ids = [b.id for b in experiment.bioreplicates]

    # The averaged measurements are parented by a custom-generated bioreplicate.
    # Note: it is always created, even if nothing below ends up being averaged.
    average_bioreplicate = Bioreplicate(
        name=f"Average({experiment.name})",
        calculationType='average',
        experiment=experiment,
        study=study,
    )
    db_session.add(average_bioreplicate)
    # Its id is used as a subjectId below. Flush explicitly so the id is
    # assigned, rather than relying on the autoflush of the next query.
    db_session.flush()

    for technique in study.measurementTechniques:
        for compartment in experiment.compartments:
            # Values are averaged separately per technique and compartment.
            # The ordering keeps contexts of the same subject adjacent, which
            # itertools.groupby below requires.
            measurement_contexts = db_session.scalars(
                sql.select(MeasurementContext)
                .distinct()
                .join(Measurement)
                .where(
                    MeasurementContext.compartmentId == compartment.id,
                    MeasurementContext.bioreplicateId.in_(bioreplicate_ids),
                    MeasurementContext.techniqueId == technique.id,
                    Measurement.value.is_not(None),
                )
                .order_by(MeasurementContext.subjectType, MeasurementContext.subjectId)
            ).all()

            # With a single context for this cluster of measurements, there is
            # nothing to average.
            # NOTE(review): for strain/metabolite techniques, this counts contexts
            # across all subjects. With several subjects, a subject that has only
            # one context still gets an "average" of that single context.
            # Confirm this is intended.
            if len(measurement_contexts) <= 1:
                continue

            if technique.subjectType == 'bioreplicate':
                # A single averaged context for the whole group of bioreplicates.
                _create_average_measurement_context(
                    db_session,
                    parent_records=(study, technique, compartment),
                    measurement_contexts=measurement_contexts,
                    average_bioreplicate=average_bioreplicate,
                    subject_id=average_bioreplicate.id,
                    subject_type='bioreplicate',
                )
            else:
                grouped_contexts = itertools.groupby(
                    measurement_contexts,
                    lambda mc: (mc.subjectId, mc.subjectType),
                )
                for (subject_id, subject_type), subject_contexts in grouped_contexts:
                    # One averaged context for each subject.
                    _create_average_measurement_context(
                        db_session,
                        parent_records=(study, technique, compartment),
                        measurement_contexts=list(subject_contexts),
                        average_bioreplicate=average_bioreplicate,
                        subject_id=subject_id,
                        subject_type=subject_type,
                    )
[docs]
def _create_average_measurement_context(
    db_session,
    parent_records,
    measurement_contexts,
    average_bioreplicate,
    subject_id,
    subject_type,
):
    """Create one averaged MeasurementContext and its per-timepoint Measurements.

    All measurements of ``measurement_contexts`` are grouped by
    ``timeInSeconds``. For each time point, the SQL ``avg`` and ``std`` of
    their values become one new Measurement. Nothing is created when the
    source contexts have no measurements at all.

    Args:
        db_session: session used for the aggregate query and the new records.
        parent_records: ``(study, technique, compartment)`` that the new
            context belongs to.
        measurement_contexts: source contexts whose measurements are averaged.
        average_bioreplicate: synthetic bioreplicate that owns the new context.
        subject_id: subject id recorded on the new context.
        subject_type: subject type recorded on the new context.
    """
    study, technique, compartment = parent_records
    source_context_ids = [context.id for context in measurement_contexts]

    per_timepoint_query = (
        sql.select(
            Measurement.timeInSeconds,
            sql.func.avg(Measurement.value),
            # NOTE(review): `std` is not standard SQL. Presumably this is
            # MySQL/MariaDB's population standard deviation; confirm against
            # the target database.
            sql.func.std(Measurement.value),
        )
        .where(Measurement.contextId.in_(source_context_ids))
        .group_by(Measurement.timeInSeconds)
        .order_by(Measurement.timeInSeconds)
    )
    aggregated_rows = db_session.execute(per_timepoint_query).all()

    # Avoid creating an empty context when there is nothing to aggregate.
    if not aggregated_rows:
        return

    # Parent context for the averaged measurements.
    average_context = MeasurementContext(
        study=study,
        bioreplicate=average_bioreplicate,
        compartment=compartment,
        subjectId=subject_id,
        subjectType=subject_type,
        technique=technique,
        calculationType='average',
    )
    db_session.add(average_context)

    db_session.add_all([
        Measurement(
            timeInSeconds=time_in_seconds,
            value=mean_value,
            std=std_value,
            context=average_context,
            study=study,
        )
        for time_in_seconds, mean_value, std_value in aggregated_rows
    ])
[docs]
def _find_custom_strain(submission, identifier):
for custom_strain_data in submission.studyDesign['custom_strains']:
if custom_strain_data['name'] == identifier:
return custom_strain_data
else:
raise IndexError(f"New strain with name {repr(identifier)} not found in submission")
[docs]
def _get_expected_column_names(submission_form):
submission = submission_form.submission
community_columns = set()
strain_columns = set()
metabolite_columns = set()
# Validate column presence:
for technique in submission.build_techniques():
if technique.subjectType == 'bioreplicate':
column = technique.csv_column_name()
community_columns.add(column)
if technique.includeStd:
community_columns.add(f"{column} STD")
elif technique.subjectType == 'strain':
for taxon in submission_form.fetch_taxa():
column = technique.csv_column_name(taxon.name)
strain_columns.add(column)
if technique.includeStd:
strain_columns.add(f"{column} STD")
for strain in submission.studyDesign['custom_strains']:
column = technique.csv_column_name(strain['name'])
strain_columns.add(column)
if technique.includeStd:
strain_columns.add(f"{column} STD")
elif technique.subjectType == 'metabolite':
for metabolite in submission_form.fetch_all_metabolites():
column = technique.csv_column_name(metabolite.name)
metabolite_columns.add(column)
if technique.includeStd:
metabolite_columns.add(f"{column} STD")
else:
raise ValueError(f"Unexpected technique subjectType: {technique.subjectType}")
return {
'Growth data per community': community_columns,
'Growth data per strain': strain_columns,
'Growth data per metabolite': metabolite_columns,
}
[docs]
def _build_strain(db_session, identifier, submission, study, user_uuid):
    """Build (without adding to the session) a Strain for a community strain identifier.

    Identifiers take one of two forms:

    - ``'existing|<ncbi id>'``: the strain copies its name and NCBI id from
      the matching Taxon row and is marked ``defined=True``;
    - ``'custom|<name>'``: the strain is built from the matching entry of
      ``studyDesign['custom_strains']`` and is marked ``defined=False``.

    Every strain is attached to ``study`` and owned by ``user_uuid``.

    Raises:
        ValueError: if the identifier has neither prefix.
        IndexError: if a custom strain name is not in the submission.
        Any error raised by ``.one()`` when no Taxon matches an existing id.
    """
    ownership = {'study': study, 'userUniqueID': user_uuid}
    kind, separator, reference = identifier.partition('|')

    if separator and kind == 'existing':
        taxon = db_session.scalars(
            sql.select(Taxon)
            .where(Taxon.ncbiId == reference)
            .limit(1)
        ).one()
        return Strain(
            name=taxon.name,
            NCBId=taxon.ncbiId,
            defined=True,
            **ownership,
        )

    if separator and kind == 'custom':
        custom_strain = _find_custom_strain(submission, reference)
        return Strain(
            name=custom_strain['name'],
            NCBId=custom_strain['species'],
            description=custom_strain['description'],
            defined=False,
            **ownership,
        )

    raise ValueError(f"Strain identifier {repr(identifier)} has an unexpected prefix")