import tempfile
from datetime import datetime, UTC
import sqlalchemy as sql
from sqlalchemy.orm.attributes import flag_modified
from celery import shared_task
from celery.utils.log import get_task_logger
from db import FLASK_DB
from app.model.lib.r_script import RScript
from app.model.orm import (
ModelingResult,
MeasurementContext,
WorkspaceEntry,
)
# Module logger obtained through Celery's task-logging helper, so that messages
# emitted from inside the task below go through Celery's task logger.
_LOGGER = get_task_logger(__name__)
@shared_task
def process_modeling_request(modeling_result_id, *, target_type, target_id, args):
    """Celery task entry point for a modeling request.

    Thin wrapper that runs ``_process_modeling_request`` with the
    Flask-SQLAlchemy session (``FLASK_DB.session``). All other arguments are
    forwarded unchanged; see ``_process_modeling_request`` for their meaning.
    Nothing is returned, because a task result would have to be serialized by
    Celery and an ORM object is not suitable for that.
    """
    db_session = FLASK_DB.session
    _process_modeling_request(db_session, modeling_result_id, target_type, target_id, args)
def _process_modeling_request(db_session, modeling_result_id, target_type, target_id, args=None):
    """Run the R modeling script for one ModelingResult and persist the outcome.

    Loads the ModelingResult and its target (a MeasurementContext or a
    WorkspaceEntry), passes the target's ``time``/``value`` data to
    ``scripts/modeling/<modeling type>.R``, and stores the fit, the
    coefficients and the extracted R summary on the result. Any exception
    raised while running the script or reading its output, as well as missing
    fit/coefficient output, marks the result with ``state='error'`` instead of
    propagating.

    :param db_session: SQLAlchemy session used to load and commit the records.
    :param modeling_result_id: primary key of the ModelingResult to fill in.
    :param target_type: ``'MeasurementContext'`` or ``'WorkspaceEntry'``.
    :param target_id: primary key of the target record.
    :param args: optional request parameters. ``pointCount`` is used by
        ``easy_linear`` (default ``5``). ``endTime`` is used by ``logistic``
        and ``baranyi_roberts``; an empty or ``None`` value means no cut-off.
    :raises ValueError: if ``target_type`` is not one of the supported names.
    :return: the updated ModelingResult (returned for tests only).
    """
    # Use a None sentinel instead of a shared mutable ``{}`` default.
    if args is None:
        args = {}

    modeling_result = db_session.get(ModelingResult, modeling_result_id)

    if target_type == 'MeasurementContext':
        target = db_session.get(MeasurementContext, target_id)
    elif target_type == 'WorkspaceEntry':
        target = db_session.get(WorkspaceEntry, target_id)
    else:
        raise ValueError(f"Unexpected target type: {target_type}")

    modeling_type = modeling_result.type
    point_count = int(args.get('pointCount', '5'))
    end_time = args.get('endTime', '')
    # A JSON ``null`` end time is treated like an absent one; previously it
    # slipped past the ``!= ''`` check and crashed in ``float(None)``.
    has_end_time = end_time not in ('', None)

    # The user-facing inputs that apply to this model type, stored with the result.
    inputs = {}
    if modeling_type == 'easy_linear':
        inputs = {'pointCount': point_count}
    elif modeling_type in ('logistic', 'baranyi_roberts'):
        inputs = {'endTime': end_time}

    data = target.get_df(db_session)
    if modeling_type in ('logistic', 'baranyi_roberts') and has_end_time:
        # The end-time cut-off is applied here; the R script only receives
        # the already-truncated CSV.
        data = data[data['time'] <= float(end_time)]
    # We don't need error columns for modeling:
    data = data[["time", "value"]]
    # Remove rows with NA values, if any
    data = data.dropna()

    with tempfile.TemporaryDirectory() as tmp_dir_name:
        try:
            rscript = RScript(root_path=tmp_dir_name)
            rscript.write_csv('input.csv', data)
            if modeling_type == 'easy_linear':
                rscript.write_json('input.json', {'pointCount': point_count})

            script_name = f"scripts/modeling/{modeling_type}.R"
            output = rscript.run(script_name)
            _LOGGER.info(output)

            fit = rscript.read_flat_json('fit.json', discard_keys="_row")
            coefficients = rscript.read_key_value_json(
                'coefficients.json',
                key_name="_row",
                value_name="coefficients",
            )

            if coefficients is None or fit is None:
                modeling_result.state = 'error'
                modeling_result.error = 'No coefficients and/or fit were generated by the R script'
            else:
                modeling_result.update(
                    rSummary=_extract_r_summary(output),
                    params={
                        'inputs': inputs,
                        'coefficients': coefficients,
                        'fit': fit,
                        'r_version': rscript.get_r_version(),
                        'growthrates_version': rscript.get_growthrates_version(),
                    },
                    state='ready',
                    error=None,
                    calculatedAt=datetime.now(UTC),
                )
                # Mark the column as changed explicitly so the new ``params``
                # value is flushed.
                flag_modified(modeling_result, 'params')
        except Exception as e:
            modeling_result.state = 'error'
            modeling_result.error = 'RScript error'
            # ``exception`` logs the traceback too; ``error(e)`` only kept the message.
            _LOGGER.exception(e)

    db_session.add(modeling_result)
    db_session.commit()

    # Returning the object for testing purposes, not used
    return modeling_result
def _extract_r_summary(text):
output_lines = []
in_summary = False
for line in text.splitlines():
if '## SUMMARY START' in line:
in_summary = True
continue
elif '## SUMMARY END' in line:
in_summary = False
if in_summary:
output_lines.append(line)
if output_lines:
return "\n".join(output_lines)
else:
return None