Source code for modeling_tasks

import tempfile
from datetime import datetime, UTC

import sqlalchemy as sql
from celery import shared_task
from celery.utils.log import get_task_logger

from db import FLASK_DB
from app.model.lib.r_script import RScript
from app.model.orm import (
    ModelingRequest,
    ModelingResult,
    MeasurementContext,
)

LOGGER = get_task_logger(__name__)


@shared_task
def process_modeling_request(modeling_request_id, measurement_context_ids, args):
    """Celery entry point: run a modeling request using the Flask-SQLAlchemy session.

    All work is delegated to ``_process_modeling_request``; this wrapper only
    supplies ``FLASK_DB.session`` as the database session and returns whatever
    the helper returns.

    :param modeling_request_id: primary key of the ``ModelingRequest`` to process.
    :param measurement_context_ids: ids of the ``MeasurementContext`` records to model.
    :param args: model parameters, passed through unchanged.
    """
    return _process_modeling_request(
        FLASK_DB.session,
        modeling_request_id,
        measurement_context_ids,
        args,
    )
def _process_modeling_request(db_session, modeling_request_id, measurement_context_ids, args=None):
    """Fit the requested model to each measurement context and store the results.

    The ``ModelingRequest`` is set to ``'in_progress'`` and committed first.
    For every requested ``MeasurementContext``:

    - an existing ``ModelingResult`` (matched on request id and measurement
      context id) is reused, or a new one is created;
    - the context's data frame is loaded, optionally cut off at ``endTime``,
      stripped of its ``std`` column and of rows containing NA values;
    - the R script ``scripts/modeling/<request type>.R`` is run on it, and the
      resulting fit, coefficients and R summary are stored on the result.

    At the end the request is set to ``'error'`` if any result failed
    (missing R output or an exception while running the script), otherwise
    to ``'ready'``, and the session is committed.

    :param db_session: SQLAlchemy session used for all reads and writes.
    :param modeling_request_id: primary key of the ``ModelingRequest``.
    :param measurement_context_ids: ids of the ``MeasurementContext`` records to model.
    :param args: optional dict of model parameters. ``'pointCount'`` (default
        ``'5'``, converted with ``int``) is used by ``easy_linear``;
        ``'endTime'`` (default ``''``, meaning no cutoff) limits the ``time``
        column for ``logistic`` and ``baranyi_roberts``.
    """
    # Avoid a shared mutable default; an absent mapping behaves like {}.
    if args is None:
        args = {}

    modeling_request = db_session.get(ModelingRequest, modeling_request_id)
    modeling_request.state = 'in_progress'
    db_session.commit()

    measurement_contexts = db_session.scalars(
        sql.select(MeasurementContext)
        .where(MeasurementContext.id.in_(measurement_context_ids))
    ).all()

    has_error = False
    point_count = int(args.get('pointCount', '5'))
    end_time = args.get('endTime', '')

    for measurement_context in measurement_contexts:
        modeling_result = db_session.scalars(
            sql.select(ModelingResult)
            .where(
                ModelingResult.requestId == modeling_request.id,
                ModelingResult.measurementContextId == measurement_context.id,
            )
        ).one_or_none()

        if not modeling_result:
            modeling_result = ModelingResult(
                type=modeling_request.type,
                request=modeling_request,
                measurementContext=measurement_context,
            )

        # Record the parameters this result is computed with.
        if modeling_request.type == 'easy_linear':
            modeling_result.inputs = {'pointCount': point_count}
        elif modeling_request.type in ('logistic', 'baranyi_roberts'):
            modeling_result.inputs = {'endTime': end_time}

        db_session.add(modeling_result)
        modeling_request.results.append(modeling_result)

        data = measurement_context.get_df(db_session)
        if modeling_request.type in ('logistic', 'baranyi_roberts') and end_time != '':
            data = data[data['time'] <= float(end_time)]

        # We don't need standard deviation for modeling:
        data = data.drop(columns=['std'])
        # Remove rows with NA values, if any
        data = data.dropna()

        try:
            # A fresh directory per context: the output files read below have
            # fixed names, so a shared directory could hand this context the
            # fit.json/coefficients.json left behind by a previous context.
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                rscript = RScript(root_path=tmp_dir_name)
                rscript.write_csv('input.csv', data)
                if modeling_request.type == 'easy_linear':
                    rscript.write_json('input.json', {'pointCount': point_count})

                script_name = f"scripts/modeling/{modeling_request.type}.R"
                output = rscript.run(script_name)
                LOGGER.info(output)

                fit = rscript.read_flat_json('fit.json', discard_keys="_row")
                coefficients = rscript.read_key_value_json(
                    'coefficients.json',
                    key_name="_row",
                    value_name="coefficients",
                )

            if coefficients is None or fit is None:
                modeling_result.state = 'error'
                modeling_result.error = 'No coefficients and/or fit were generated by the R script'
                has_error = True
            else:
                modeling_result.coefficients = coefficients
                modeling_result.rSummary = _extract_r_summary(output)
                modeling_result.fit = fit
                modeling_result.state = 'ready'
                modeling_result.error = None
                modeling_result.calculatedAt = datetime.now(UTC)
                modeling_request.error = None
        except Exception as e:
            modeling_result.state = 'error'
            modeling_result.error = 'RScript error'
            # A failed result must also fail the request, exactly like the
            # missing-output case above.
            has_error = True
            # Logs the same message as before, plus the traceback.
            LOGGER.exception(e)

    if has_error:
        modeling_request.state = 'error'
        modeling_request.error = 'One or more requests errored out'
    else:
        modeling_request.state = 'ready'

    db_session.add(modeling_request)
    db_session.commit()
[docs] def _extract_r_summary(text): output_lines = [] in_summary = False for line in text.splitlines(): if '## SUMMARY START' in line: in_summary = True continue elif '## SUMMARY END' in line: in_summary = False if in_summary: output_lines.append(line) if output_lines: return "\n".join(output_lines) else: return None