# Source code for app.model.lib.submission_process

import io
import copy
import itertools
from datetime import datetime, timedelta, time, UTC
from db import get_session, get_transaction

import pandas as pd
import sqlalchemy as sql

from app.model.orm import (
    Bioreplicate,
    Community,
    CommunityStrain,
    Compartment,
    Experiment,
    ExperimentCompartment,
    Measurement,
    MeasurementContext,
    Perturbation,
    Project,
    ProjectUser,
    Study,
    StudyMetabolite,
    StudyStrain,
    StudyUser,
    Taxon,
)
from app.model.lib.util import group_by_unique_name, is_non_negative_float, find_duplicates
from app.model.lib.conversion import convert_time
from app.model.tasks.submissions import export_submission_data


def persist_submission_to_database(submission_form):
    """Write the submission held by `submission_form` to the database.

    Nothing is written when no data file has been uploaded. Otherwise the
    project, study and all study children (techniques, compartments,
    communities, experiments, measurements and averaged measurements) are
    rebuilt inside a single transaction. After that, a data export task is
    queued if the study is published.

    :param submission_form: form object wrapping the submission to persist.
    :return: list of error strings; empty on success.
    """
    submission = submission_form.submission
    user_uuid = submission.userUniqueID

    # Guard clause: the data file is a hard requirement for persisting.
    if submission_form.submission.dataFile is None:
        return ["Data file has not been uploaded"]

    with get_transaction() as transaction:
        session = get_session(transaction)

        project = _save_project(session, submission_form, user_uuid)
        study = _save_study(session, submission_form, user_uuid)

        # Drop the study's existing children and recreate them from the form:
        _clear_study(study)
        _save_study_techniques(session, submission_form, study)
        _save_compartments(session, submission_form, study)
        _save_communities(session, submission_form, study, user_uuid)
        _save_experiments(session, submission_form, study)

        # Records must have IDs before measurements reference them:
        session.flush()

        _save_measurements(session, study, submission_form)
        for experiment in study.experiments:
            _create_average_measurements(session, study, experiment)

        _finalize_submission(session, submission_form, study, project)
        session.commit()

    if study.isPublished:
        export_submission_data.delay(submission.id)

    return []
# TODO (2025-05-11) Test (separate read_data_file, operate on dfs) #
def validate_data_file(submission_form, data_file=None):
    """Validate an uploaded Excel data file against the submission's study design.

    Every sheet of the workbook is read with pandas and checked for:
    expected columns (across all sheets), expected/unexpected biological
    replicates, unexpected compartments, valid numeric values and unique time
    points per (biological replicate, compartment). Blocking problems are
    returned as human-readable error strings. Non-blocking problems (missing
    or invalid time values) are collected as warnings but are not returned yet.

    :param submission_form: form wrapping the submission; ``submission.dataFile``
        and ``submission.studyDesign`` are read.
    :param data_file: object whose ``content`` holds the raw workbook bytes;
        defaults to ``submission.dataFile``.
    :return: list of error strings; empty when there is no file to validate
        or no errors were found.
    """
    submission = submission_form.submission
    data_file = data_file or submission.dataFile
    if not data_file:
        return []

    sheets = pd.read_excel(io.BytesIO(data_file.content), sheet_name=None)

    # TODO (2025-10-01) Show warnings that don't stop the upload process
    errors, _warnings = _validate_sheets(submission_form, sheets)
    return errors


def _validate_sheets(submission_form, sheets):
    """Validate already-parsed workbook sheets; return ``(errors, warnings)``.

    :param sheets: mapping of sheet name to DataFrame, as returned by
        ``pd.read_excel(..., sheet_name=None)``.
    """
    submission = submission_form.submission
    errors = []
    warnings = []

    # Validate columns. A column only needs to appear in one of the sheets:
    community_columns, strain_columns, metabolite_columns = _get_expected_column_names(submission_form)
    expected_value_columns = {*community_columns, *strain_columns, *metabolite_columns}
    expected_columns = {'Biological Replicate', 'Compartment', 'Time', *expected_value_columns}

    found_columns = set()
    for df in sheets.values():
        found_columns |= set(df.columns)
    # Sorted so that the error order is stable between runs:
    for missing_column in sorted(expected_columns - found_columns, key=str):
        errors.append(f"Missing column: {missing_column}")

    # Validate row keys:
    expected_bioreplicates = {
        bioreplicate['name']
        for experiment in submission.studyDesign['experiments']
        for bioreplicate in experiment['bioreplicates']
    }
    expected_compartments = {c['name'] for c in submission.studyDesign['compartments']}
    for sheet_name, df in sheets.items():
        errors.extend(_validate_sheet_row_keys(sheet_name, df, expected_bioreplicates, expected_compartments))

    # Validate values:
    for sheet_name, df in sheets.items():
        sheet_errors, sheet_warnings = _validate_sheet_values(sheet_name, df, expected_value_columns)
        errors.extend(sheet_errors)
        warnings.extend(sheet_warnings)

    return errors, warnings


def _describe_values(values):
    """Join cell values into a sorted, comma-separated string.

    Cells are stringified first, so numeric or blank (NaN) cells do not
    make ``str.join`` raise.
    """
    return ', '.join(sorted(str(value) for value in values))


def _validate_sheet_row_keys(sheet_name, df, expected_bioreplicates, expected_compartments):
    """Report missing/unexpected biological replicates and unexpected compartments in one sheet."""
    errors = []

    if 'Biological Replicate' in df:
        uploaded_bioreplicates = set(df['Biological Replicate'])
        if missing_bioreplicates := expected_bioreplicates - uploaded_bioreplicates:
            errors.append(
                f"{sheet_name}: Missing biological replicate(s): {_describe_values(missing_bioreplicates)}"
            )
        if extra_bioreplicates := uploaded_bioreplicates - expected_bioreplicates:
            errors.append(
                f"{sheet_name}: Unexpected biological replicate(s): {_describe_values(extra_bioreplicates)}"
            )

    if 'Compartment' in df:
        uploaded_compartments = set(df['Compartment'])
        if extra_compartments := uploaded_compartments - expected_compartments:
            errors.append(
                f"{sheet_name}: Unexpected compartment(s): {_describe_values(extra_compartments)}"
            )

    return errors


def _validate_sheet_values(sheet_name, df, expected_value_columns):
    """Check time and measurement values in one sheet; return ``(errors, warnings)``.

    Row numbers in messages are 1-based positions among the sheet's data rows.
    """
    errors = []
    warnings = []

    # Missing or invalid time values are only a warning:
    if 'Time' in df:
        invalid_time_rows = [
            str(index + 1)
            for index, value in enumerate(df['Time'])
            if not is_non_negative_float(value, isnan_check=True)
        ]
        if invalid_time_rows:
            row_description = _format_row_list_error(invalid_time_rows)
            # TODO (2025-10-01) Show warnings in UI
            warnings.append(f"{sheet_name}: Missing or invalid time values on row(s) {row_description}")

    # For the other columns, we're looking for non-negative numbers or blanks:
    for column in sorted(expected_value_columns.intersection(df.columns), key=str):
        invalid_rows = [
            str(index + 1)
            for index, value in enumerate(df[column])
            if not is_non_negative_float(value, isnan_check=False)
        ]
        if invalid_rows:
            row_description = _format_row_list_error(invalid_rows)
            errors.append(f"{sheet_name}: Invalid values in column \"{column}\" on row(s) {row_description}")

    errors.extend(_validate_unique_timepoints(sheet_name, df))
    return errors, warnings


def _validate_unique_timepoints(sheet_name, df):
    """Report time points repeated within one (biological replicate, compartment) pair."""
    key_columns = ['Biological Replicate', 'Compartment', 'Time']
    if not all(column in df for column in key_columns):
        # Absent columns are already reported as "Missing column" errors;
        # indexing them here would raise a KeyError.
        return []

    grouped_timepoints = {}
    for bioreplicate, compartment, time_value in df[key_columns].itertuples(index=False, name=None):
        # Skip blank cells only. A time point of 0 is valid and must be checked too.
        if pd.isna(time_value):
            continue
        grouped_timepoints.setdefault((bioreplicate, compartment), []).append(time_value)

    errors = []
    for (b, c), timepoints in grouped_timepoints.items():
        if duplicates := find_duplicates(timepoints):
            duplicates_description = ', '.join([str(d) for d in duplicates])
            errors.append(
                f"{sheet_name}: Time points are not unique for bioreplicate {b}, compartment {c}: {duplicates_description}"
            )
    return errors
def _save_study(db_session, submission_form, user_uuid=None): submission = submission_form.submission if embargo_string := submission.studyDesign['study'].get('embargoExpiresAt', None): embargo_date = datetime.fromisoformat(embargo_string) embargo_datetime = datetime.combine(embargo_date, time(hour=23, minute=59, tzinfo=UTC)) else: embargo_datetime = None params = { 'publicId': submission_form.study_id, 'name': submission.studyDesign['study']['name'].strip(), 'description': submission.studyDesign['study'].get('description', '').strip(), 'url': submission.studyDesign['study'].get('url', '').strip(), 'authors': submission.studyDesign['study'].get('authors', []), 'authorCache': submission.studyDesign['study'].get('authorCache', ''), 'uuid': submission.studyUniqueID, 'projectUuid': submission.projectUniqueID, 'timeUnits': submission.studyDesign['timeUnits'], 'embargoExpiresAt': embargo_datetime, } if submission_form.study_id is None: params['ownerUuid'] = user_uuid study = Study(**Study.filter_keys(params)) study.publicId = Study.generate_public_id(db_session) db_session.add(StudyUser( studyUniqueID=submission.studyUniqueID, userUniqueID=submission.userUniqueID, )) else: study = db_session.get(Study, submission_form.study_id) study.update(**Study.filter_keys(params)) tomorrow = datetime.now(UTC) + timedelta(hours=24) if embargo_datetime and embargo_datetime > tomorrow: study.publishableAt = embargo_datetime else: study.publishableAt = tomorrow db_session.add(study) return study def _save_project(db_session, submission_form, user_uuid=None): submission = submission_form.submission params = { 'publicId': submission_form.project_id, 'name': submission.studyDesign['project']['name'].strip(), 'description': submission.studyDesign['project'].get('description', '').strip(), 'uuid': submission.projectUniqueID, } if submission_form.project_id is None: params['ownerUuid'] = user_uuid project = Project(**Project.filter_keys(params)) project.publicId = 
Project.generate_public_id(db_session) db_session.add(ProjectUser( projectUniqueID=submission.projectUniqueID, userUniqueID=submission.userUniqueID, )) else: project = db_session.get(Project, submission_form.project_id) project.update(**Project.filter_keys(params)) db_session.add(project) return project def _clear_study(study): for experiment in study.experiments: experiment.experimentCompartments = [] experiment.perturbations = [] experiment.bioreplicates = [] study.measurements = [] study.measurementContexts = [] study.studyTechniques = [] study.strains = [] study.compartments = [] study.communities = [] study.studyMetabolites = [] def _save_compartments(db_session, submission_form, study): submission = submission_form.submission compartments = [] for compartment_data in submission.studyDesign['compartments']: compartment = Compartment(**Compartment.filter_keys(compartment_data)) compartments.append(compartment) study.compartments = compartments db_session.add_all(compartments) return compartments def _save_communities(db_session, submission_form, study, user_uuid): submission = submission_form.submission communities = [] identifier_cache = {} for community_data in submission.studyDesign['communities']: community_data = copy.deepcopy(community_data) strain_identifiers = community_data.pop('strainIdentifiers') community = Community(**Community.filter_keys(community_data)) db_session.add(community) db_session.flush() for identifier in strain_identifiers: if identifier not in identifier_cache: if strain := _build_strain(db_session, identifier, submission, study, user_uuid): identifier_cache[identifier] = strain db_session.add(strain) db_session.flush() if identifier in identifier_cache: community_strain = CommunityStrain( community=community, strain=identifier_cache[identifier], ) db_session.add(community_strain) communities.append(community) study.communities = communities # If any of the techniques have an "unknown" column, create an appropriate # strain: if 
any([st.includeUnknown for st in study.studyTechniques]): strain = _build_strain(db_session, "unknown", submission, study, user_uuid) db_session.add(strain) db_session.flush() return communities def _save_experiments(db_session, submission_form, study): submission = submission_form.submission experiments = [] time_units = submission.studyDesign['timeUnits'] communities_by_name = group_by_unique_name(study.communities) compartments_by_name = group_by_unique_name(study.compartments) for experiment_data in submission.studyDesign['experiments']: experiment_params = copy.deepcopy(experiment_data) community_name = experiment_params.pop('communityName') compartment_names = experiment_params.pop('compartmentNames') bioreplicates = experiment_params.pop('bioreplicates') perturbations = experiment_params.pop('perturbations') if publicId := experiment_params.pop('publicId', None): experiment = db_session.get(Experiment, publicId) if experiment.studyId != study.publicId: raise ValueError(f"Experiment with ID {publicId} does not belong to study {study.publicId}") else: experiment = Experiment(publicId=Experiment.generate_public_id(db_session)) experiment_data['publicId'] = experiment.publicId experiment.update( **Experiment.filter_keys(experiment_params), community=communities_by_name[community_name], ) db_session.add(experiment) for compartment_name in compartment_names: experiment_compartment = ExperimentCompartment( experiment=experiment, compartment=compartments_by_name[compartment_name], ) db_session.add(experiment_compartment) for bioreplicate_data in bioreplicates: bioreplicate = Bioreplicate( **Bioreplicate.filter_keys(bioreplicate_data), experiment=experiment, ) db_session.add(bioreplicate) for perturbation_data in perturbations: perturbation_data = copy.deepcopy(perturbation_data) start_time = perturbation_data.pop('startTime', '0') start_time_in_seconds = convert_time(int(start_time), time_units, 's') end_time = perturbation_data.pop('endTime', None) if end_time: 
end_time_in_seconds = convert_time(int(end_time), time_units, 's') else: end_time_in_seconds = None perturbation = Perturbation( experiment=experiment, startTimeInSeconds=start_time_in_seconds, endTimeInSeconds=end_time_in_seconds, description=perturbation_data.pop('description'), ) name = perturbation_data.pop('removedCompartmentName', '') if name != '': perturbation.removedCompartmentId = compartments_by_name[name].id name = perturbation_data.pop('addedCompartmentName', '') if name != '': perturbation.addedCompartmentId = compartments_by_name[name].id name = perturbation_data.pop('oldCommunityName', '') if name != '': perturbation.oldCommunityId = communities_by_name[name].id name = perturbation_data.pop('newCommunityName', '') if name != '': perturbation.newCommunityId = communities_by_name[name].id db_session.add(perturbation) experiments.append(experiment) study.experiments = experiments db_session.add_all(experiments) return experiments def _save_study_techniques(db_session, submission_form, study): submission = submission_form.submission techniques = [] for study_technique in submission.build_techniques(): study_technique.study = study for measurement_technique in study_technique.measurementTechniques: if measurement_technique.metaboliteIds: for chebiId in measurement_technique.metaboliteIds: db_session.add(StudyMetabolite( chebiId=chebiId, study=study, )) techniques.append(study_technique) db_session.add_all(techniques) return techniques def _save_measurements(db_session, study, submission_form): submission = submission_form.submission data_xls = submission.dataFile.content sheets = pd.read_excel(io.BytesIO(data_xls), sheet_name=None) for _, df in sheets.items(): Measurement.insert_from_csv_string(db_session, study, df.to_csv(index=False)) def _create_average_measurements(db_session, study, experiment): bioreplicate_ids = [b.id for b in experiment.bioreplicates if not b.calculationType] # The averaged measurements will be parented by a custom-generated 
bioreplicate: average_bioreplicate = Bioreplicate( name=f"Average({experiment.name})", calculationType='average', experiment=experiment, ) db_session.add(average_bioreplicate) has_measurements = False for technique in study.measurementTechniques: for compartment in experiment.compartments: # We'll average values separately over techniques and compartments: measurement_contexts = db_session.scalars( sql.select(MeasurementContext) .distinct() .join(Measurement) .where( MeasurementContext.compartmentId == compartment.id, MeasurementContext.bioreplicateId.in_(bioreplicate_ids), MeasurementContext.techniqueId == technique.id, Measurement.value.is_not(None), ) .order_by(MeasurementContext.subjectType, MeasurementContext.subjectId) ).all() # If there is a single context for this cluster of measurements, there is nothing to average: if len(measurement_contexts) <= 1: continue # If measurement time points don't match, don't average them: time_point_sets = set() for measurement_context in measurement_contexts: time_points = [m.timeInSeconds for m in measurement_context.measurements] time_point_sets.add(frozenset(time_points)) if len(time_point_sets) > 1: continue if technique.subjectType == 'bioreplicate': # A single context for a group of bioreplicates _create_average_measurement_context( db_session, parent_records=(study, technique, compartment), measurement_contexts=measurement_contexts, average_bioreplicate=average_bioreplicate, subject_id=average_bioreplicate.id, subject_type='bioreplicate', subject_name=average_bioreplicate.name, subject_external_id=None, ) else: grouped_contexts = itertools.groupby( measurement_contexts, lambda mc: (mc.subjectId, mc.subjectType, mc.subjectName, mc.subjectExternalId), ) for key, subject_contexts in grouped_contexts: ( subject_id, subject_type, subject_name, subject_external_id, ) = key # One context for each subject: _create_average_measurement_context( db_session, parent_records=(study, technique, compartment), 
measurement_contexts=list(subject_contexts), average_bioreplicate=average_bioreplicate, subject_id=subject_id, subject_type=subject_type, subject_name=subject_name, subject_external_id=subject_external_id, ) has_measurements = True if not has_measurements: db_session.delete(average_bioreplicate) def _create_average_measurement_context( db_session, parent_records, measurement_contexts, average_bioreplicate, subject_id, subject_type, subject_name, subject_external_id=None, ): (study, technique, compartment) = parent_records # Collect average measurement values for the given contexts: measurement_rows = db_session.execute( sql.select( Measurement.timeInSeconds, sql.func.avg(Measurement.value), sql.func.std(Measurement.value), ) .where(Measurement.contextId.in_([mc.id for mc in measurement_contexts])) .group_by(Measurement.timeInSeconds) .order_by(Measurement.timeInSeconds) ).all() if len(measurement_rows) == 0: # We do not want to create unnecessary contexts return # Create a parent context for the individual measurements: average_context = MeasurementContext( study=study, bioreplicate=average_bioreplicate, compartment=compartment, subjectId=subject_id, subjectType=subject_type, subjectName=subject_name, subjectExternalId=subject_external_id, technique=technique, calculationType='average', ) db_session.add(average_context) # Create individual measurements for (t, value, std) in measurement_rows: measurement = Measurement( timeInSeconds=t, value=value, std=std, context=average_context, study=study, ) db_session.add(measurement) def _finalize_submission(db_session, submission_form, study, project): study.lastSubmissionId = submission_form.submission.id if study.isPublished: submission_form.submission.publishedAt = study.publishedAt submission_form.save() submission_form.save_backup(study_id=study.publicId, project_id=project.publicId) def _find_custom_strain(submission, identifier): for custom_strain_data in submission.studyDesign['custom_strains']: if 
custom_strain_data['name'] == identifier: return custom_strain_data else: return None def _get_expected_column_names(submission_form): submission = submission_form.submission community_columns = set() strain_columns = set() metabolite_columns = set() # Validate column presence: for index, study_technique in enumerate(submission.build_techniques()): for measurement_technique in study_technique.measurementTechniques: if measurement_technique.subjectType == 'bioreplicate': column = measurement_technique.csv_column_name() community_columns.add(column) if study_technique.includeStd: community_columns.add(f"{column} STD") elif measurement_technique.subjectType == 'strain': for taxon in submission_form.fetch_taxa(): column = measurement_technique.csv_column_name(taxon.name) strain_columns.add(column) if study_technique.includeStd: strain_columns.add(f"{column} STD") if study_technique.includeUnknown: column = measurement_technique.csv_column_name("Unknown") strain_columns.add(column) if study_technique.includeStd: strain_columns.add(f"{column} STD") for strain in submission.studyDesign['custom_strains']: column = measurement_technique.csv_column_name(strain['name']) strain_columns.add(column) if study_technique.includeStd: strain_columns.add(f"{column} STD") elif measurement_technique.subjectType == 'metabolite': for metabolite in submission_form.fetch_metabolites_for_technique(index): column = measurement_technique.csv_column_name(metabolite.name) metabolite_columns.add(column) if study_technique.includeStd: metabolite_columns.add(f"{column} STD") else: raise ValueError(f"Unexpected measurement_technique subjectType: {measurement_technique.subjectType}") return community_columns, strain_columns, metabolite_columns def _build_strain(db_session, identifier, submission, study, user_uuid): strain_params = {'study': study, 'userUniqueID': user_uuid} if identifier.startswith('existing|'): taxon_id = identifier.removeprefix('existing|') taxon = db_session.scalars( 
sql.select(Taxon) .where(Taxon.ncbiId == taxon_id) .limit(1) ).one() strain_params = { 'name': taxon.name, 'ncbiId': taxon.ncbiId, 'defined': True, **strain_params, } elif identifier.startswith('custom|'): identifier = identifier.removeprefix('custom|') custom_strain_data = _find_custom_strain(submission, identifier) if custom_strain_data is None: # Missing strain due to renames return None strain_params = { 'name': custom_strain_data['name'], 'ncbiId': custom_strain_data['species'], 'description': custom_strain_data['description'], 'defined': False, **strain_params, } elif identifier == 'unknown': strain_params = { 'name': "Unknown", 'ncbiId': 0, 'description': "Unknown measurements", 'defined': False, **strain_params, } else: raise ValueError(f"Strain identifier {repr(identifier)} has an unexpected prefix") return StudyStrain(**strain_params) def _format_row_list_error(row_list): description = ', '.join(row_list[0:3]) if len(row_list) > 3: description += f", and {len(row_list) - 3} more" return description