import tempfile
from datetime import datetime, UTC
import sqlalchemy as sql
from celery import shared_task
from celery.utils.log import get_task_logger
from db import FLASK_DB
from app.model.lib.r_script import RScript
from app.model.orm import (
ModelingRequest,
ModelingResult,
MeasurementContext,
)
# Module-level logger obtained from Celery's task-logger factory; the
# modeling code below uses it for R script output and failure reports.
LOGGER = get_task_logger(__name__)
@shared_task
def process_modeling_request(modeling_request_id, measurement_context_ids, args):
    """Celery task entry point for processing a modeling request.

    Takes the session from ``FLASK_DB`` and delegates all work to
    ``_process_modeling_request``, so the implementation can be called with
    any other session as well.

    Args:
        modeling_request_id: Id of the ``ModelingRequest`` to process.
        measurement_context_ids: Ids of the ``MeasurementContext`` rows to model.
        args: Request parameters, passed through unchanged.

    Returns:
        Whatever ``_process_modeling_request`` returns.
    """
    return _process_modeling_request(
        FLASK_DB.session,
        modeling_request_id,
        measurement_context_ids,
        args,
    )
def _process_modeling_request(db_session, modeling_request_id, measurement_context_ids, args=None):
    """Fit the requested model to each measurement context and store the results.

    The request is first marked ``in_progress``. For every selected
    measurement context the matching ``ModelingResult`` is looked up (or
    created), the R script ``scripts/modeling/<request type>.R`` is run on the
    context's data, and the coefficients, fit and R summary are recorded on
    the result. At the end the request is marked ``ready``, or ``error`` if
    any single result failed.

    Args:
        db_session: SQLAlchemy session used for all reads and writes. It is
            committed once after the state change and once at the end.
        modeling_request_id: Id of the ``ModelingRequest`` to process.
        measurement_context_ids: Ids of the ``MeasurementContext`` rows to model.
        args: Optional mapping of request parameters. ``pointCount``
            (default ``'5'``) is used for ``easy_linear`` requests;
            ``endTime`` (default ``''``, meaning no cut-off) is used for
            ``logistic`` and ``baranyi_roberts`` requests.
    """
    # A ``None`` default avoids sharing one mutable dict between calls and
    # lets callers pass ``None`` explicitly.
    if args is None:
        args = {}

    modeling_request = db_session.get(ModelingRequest, modeling_request_id)
    modeling_request.state = 'in_progress'
    # Persist the state change before starting the model fits.
    db_session.commit()

    measurement_contexts = db_session.scalars(
        sql.select(MeasurementContext)
        .where(MeasurementContext.id.in_(measurement_context_ids))
    ).all()

    has_error = False
    point_count = int(args.get('pointCount', '5'))
    end_time = args.get('endTime', '')

    # The request type does not change while processing, so read it once.
    request_type = modeling_request.type
    uses_end_time = request_type in ('logistic', 'baranyi_roberts')

    for measurement_context in measurement_contexts:
        # Reuse the result row of a previous run of this request, if any.
        modeling_result = db_session.scalars(
            sql.select(ModelingResult)
            .where(
                ModelingResult.requestId == modeling_request.id,
                ModelingResult.measurementContextId == measurement_context.id,
            )
        ).one_or_none()

        if modeling_result is None:
            modeling_result = ModelingResult(
                type=request_type,
                request=modeling_request,
                measurementContext=measurement_context,
            )

        # Record the inputs that this run is going to use.
        if request_type == 'easy_linear':
            modeling_result.inputs = {'pointCount': point_count}
        elif uses_end_time:
            modeling_result.inputs = {'endTime': end_time}

        db_session.add(modeling_result)
        modeling_request.results.append(modeling_result)

        data = measurement_context.get_df(db_session)

        if uses_end_time and end_time != '':
            data = data[data['time'] <= float(end_time)]

        # We don't need standard deviation for modeling:
        data = data.drop(columns=['std'])

        # Remove rows with NA values, if any
        data = data.dropna()

        # A fresh directory per measurement context ensures that output files
        # left behind by a previous fit can never be read back as this
        # context's result.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            try:
                rscript = RScript(root_path=tmp_dir_name)
                rscript.write_csv('input.csv', data)

                if request_type == 'easy_linear':
                    rscript.write_json('input.json', {'pointCount': point_count})

                script_name = f"scripts/modeling/{request_type}.R"
                output = rscript.run(script_name)
                LOGGER.info(output)

                fit = rscript.read_flat_json('fit.json', discard_keys="_row")
                coefficients = rscript.read_key_value_json(
                    'coefficients.json',
                    key_name="_row",
                    value_name="coefficients",
                )

                if coefficients is None or fit is None:
                    modeling_result.state = 'error'
                    modeling_result.error = 'No coefficients and/or fit were generated by the R script'
                    has_error = True
                else:
                    modeling_result.coefficients = coefficients
                    modeling_result.rSummary = _extract_r_summary(output)
                    modeling_result.fit = fit
                    modeling_result.state = 'ready'
                    modeling_result.error = None
                    modeling_result.calculatedAt = datetime.now(UTC)

                    modeling_request.error = None
            except Exception:
                # Best effort: one failing context must not abort the others,
                # but the request as a whole has to be reported as failed.
                modeling_result.state = 'error'
                modeling_result.error = 'RScript error'
                has_error = True
                LOGGER.exception(
                    "R script failed for modeling request %s, measurement context %s",
                    modeling_request_id,
                    measurement_context.id,
                )

    if has_error:
        modeling_request.state = 'error'
        modeling_request.error = 'One or more requests errored out'
    else:
        modeling_request.state = 'ready'

    db_session.add(modeling_request)
    db_session.commit()