import math
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from app.model.lib.conversion import (
convert_df_units,
CELL_COUNT_UNITS,
CFU_COUNT_UNITS,
METABOLITE_UNITS,
)
from app.model.lib.util import hex_to_rgba
# Name of the Plotly layout template applied to every chart rendered by this
# module (passed as ``template=`` to ``fig.update_layout``).
PLOTLY_TEMPLATE = 'plotly_white'
"List of templates can be found at plotly.com/python/templates"
class Chart:
    """
    An object that encapsulates the common properties of Plotly charts across
    the site.

    Dataframes are attached to the left or right y-axis with ``add_df`` and
    ``add_model_df``, shaded regions are attached with ``add_region``, and
    ``to_html`` renders everything into an HTML fragment.
    """

    # Axis title used when the dataframes on one axis end up in different units
    _MIXED_UNITS_LABEL = '[mixed units]'

    def __init__(
        self,
        time_units,
        cell_count_units='Cells/mL',
        cfu_count_units='CFUs/mL',
        metabolite_units='mM',
        log_left=False,
        log_right=False,
        width=None,
        clamp_x_data=False,
        show_std=True,
    ):
        """
        Store the display settings and start with no data and no regions.

        ``cell_count_units``, ``cfu_count_units`` and ``metabolite_units`` are
        the target units handed to ``convert_df_units`` at render time.
        ``log_left`` / ``log_right`` switch the matching y-axis to a log
        scale, ``width`` is the rendered width in pixels (``None`` renders at
        100%), ``clamp_x_data`` restricts the x-axis to the span shared by all
        dataframes, and ``show_std`` toggles error bars and bands.
        """
        # TODO (2025-06-25) Unused, should consider conversion, but handle
        # units during modeling:
        self.time_units = time_units

        # Target units for each family of measurements:
        self.cell_count_units = cell_count_units
        self.cfu_count_units = cfu_count_units
        self.metabolite_units = metabolite_units

        # Display options:
        self.width = width
        self.clamp_x_data = clamp_x_data
        self.show_std = show_std
        self.log_left = log_left
        self.log_right = log_right

        # Set to True by `to_html` when an axis mixes several units:
        self.mixed_units_left = False
        self.mixed_units_right = False

        # Entries are (df, units, label, metabolite_mass) tuples, one list
        # per y-axis:
        self.data_left = []
        self.data_right = []

        # Positions inside data_left/data_right that hold model predictions:
        self.model_df_left_indices = []
        self.model_df_right_indices = []

        # Entries are (start_time, end_time, label, text) tuples:
        self.regions = []

        self.colors = px.colors.qualitative.Plotly
        # Position in `self.colors` of the next trace, reset by `to_html`:
        self.color_index = 0

    def _storage_for_axis(self, axis):
        """
        Return ``(entries, model_indices)`` for the given axis name.

        Raises ``ValueError`` for anything other than 'left' or 'right'.
        """
        if axis == 'left':
            return self.data_left, self.model_df_left_indices
        if axis == 'right':
            return self.data_right, self.model_df_right_indices
        raise ValueError(f"Unexpected axis: {axis}")

    def add_df(self, df, *, units, label=None, axis='left', metabolite_mass=None):
        """
        Register a measurement dataframe on the 'left' or 'right' axis.

        The dataframe is read through its 'time', 'value' and 'std' columns.
        If it has no 'std' column, one filled with NaN is added to ``df`` in
        place. Raises ``ValueError`` on an unknown ``axis``.
        """
        # Validate the axis before touching the caller's dataframe:
        entries, _ = self._storage_for_axis(axis)

        if 'std' not in df:
            df['std'] = [float('nan')] * df['value'].size

        entries.append((df, units, label, metabolite_mass))

    def add_model_df(self, df, *, units, label=None, axis='left'):
        """
        Register a model dataframe on the 'left' or 'right' axis.

        Its position is remembered so that the y-range calculation can skip
        it. Raises ``ValueError`` on an unknown ``axis``.
        """
        entries, model_indices = self._storage_for_axis(axis)

        model_indices.append(len(entries))
        entries.append((df, units, label, None))

    def add_region(self, start_time, end_time, label, text):
        """
        Register a shaded region between ``start_time`` and ``end_time``.

        ``label`` becomes the trace name and ``text`` its hover template.
        """
        self.regions.append((start_time, end_time, label, text))

    def to_html(self):
        """
        Render the chart and return it as an HTML fragment.

        The fragment is not a full page and does not include plotly.js.
        """
        self.color_index = 0
        fig = make_subplots(specs=[[{"secondary_y": True}]])

        # NOTE(review): `convert_df_units` appears to convert the stored
        # dataframes in place, so calling `to_html` twice may convert twice --
        # confirm against the helper.
        converted_data_left, left_units_label = self._convert_units(self.data_left)
        converted_data_right, right_units_label = self._convert_units(self.data_right)

        if left_units_label == self._MIXED_UNITS_LABEL:
            self.mixed_units_left = True
        if right_units_label == self._MIXED_UNITS_LABEL:
            self.mixed_units_right = True

        for df, label in converted_data_left:
            for scatter_params in self._get_scatter_params(df, label, log=self.log_left):
                fig.add_trace(go.Scatter(**scatter_params), secondary_y=False)

        for df, label in converted_data_right:
            for scatter_params in self._get_scatter_params(df, label, log=self.log_right):
                # Right-axis traces are dotted. Merge instead of replacing so
                # that the zero-width lines of the std band stay invisible:
                scatter_params['line'] = {**scatter_params.get('line', {}), 'dash': 'dot'}
                fig.add_trace(go.Scatter(**scatter_params), secondary_y=True)

        if self.clamp_x_data:
            # Fit the x-axis of the shortest chart:
            xaxis_range = self._calculate_x_range(converted_data_left + converted_data_right)
        else:
            xaxis_range = None

        left_yaxis_range = self._calculate_y_range(
            converted_data_left,
            model_df_indices=self.model_df_left_indices,
            log=self.log_left,
        )
        right_yaxis_range = self._calculate_y_range(
            converted_data_right,
            model_df_indices=self.model_df_right_indices,
            log=self.log_right,
        )

        if self.regions:
            # No log-transformation applied for the region-drawing:
            region_range = self._calculate_y_range(converted_data_left, self.model_df_left_indices)
            # Without measurable data the region has no height to take:
            y0, y1 = region_range if region_range is not None else (None, None)

            for x0, x1, label, text in self.regions:
                fig.add_trace(
                    go.Scatter(
                        name=label,
                        x=[x0, x0, x1, x1, x0],
                        y=[y0, y0, y0, y1, y1],
                        opacity=0.15,
                        line_width=0,
                        fill="toself",
                        hovertemplate=text,
                        mode="text",
                    ),
                )

        left_yaxis = dict(
            side="left",
            title_text=left_units_label,
            exponentformat="power",
            range=left_yaxis_range,
        )
        if self.log_left:
            left_yaxis['type'] = 'log'

        right_yaxis = dict(
            side="right",
            title_text=right_units_label,
            exponentformat="power",
            range=right_yaxis_range,
        )
        if self.log_right:
            right_yaxis['type'] = 'log'

        fig.update_layout(
            showlegend=True,
            template=PLOTLY_TEMPLATE,
            margin=dict(l=0, r=0, t=60, b=40),
            title=dict(x=0),
            hovermode='x unified',
            legend=dict(
                yanchor="bottom",
                y=1,
                xanchor="left",
                x=0,
                orientation='h',
                maxheight=0.25,
            ),
            modebar=dict(orientation='v'),
            font_family="Public Sans",
            yaxis=left_yaxis,
            yaxis2=right_yaxis,
            xaxis=dict(
                title=dict(text='Time (h)'),
                range=xaxis_range,
            )
        )

        return fig.to_html(
            full_html=False,
            include_plotlyjs=False,
            default_width=(f"{self.width}px" if self.width is not None else '100%'),
            config={
                'toImageButtonOptions': {
                    'format': 'svg',
                    'filename': 'mgrowth_chart',
                    # Force width and height to be the same as the visible dimensions on screen
                    # Reference: https://github.com/plotly/plotly.js/pull/3746
                    'height': None,
                    'width': None,
                },
            },
        )

    def _convert_units(self, data):
        """
        Convert every dataframe in ``data`` to the chart's target units.

        Returns ``(converted_data, units_label)`` where ``converted_data`` is
        a list of ``(df, label)`` pairs. The label is ``None`` for empty
        input, the common unit when all entries agree, and '[mixed units]'
        otherwise. Units outside the known families are kept as they are.
        """
        if not data:
            return [], None

        converted_data = [(df, label) for (df, _, label, _) in data]
        converted_units = set()

        for df, units, _label, metabolite_mass in data:
            if units in CELL_COUNT_UNITS:
                result_units = convert_df_units(df, units, self.cell_count_units)
            elif units in CFU_COUNT_UNITS:
                result_units = convert_df_units(df, units, self.cfu_count_units)
            elif units in METABOLITE_UNITS:
                result_units = convert_df_units(df, units, self.metabolite_units, metabolite_mass)
            else:
                result_units = units

            converted_units.add(result_units)

        if len(converted_units) > 1:
            return converted_data, self._MIXED_UNITS_LABEL

        return converted_data, next(iter(converted_units))

    def _get_scatter_params(self, df, label, log=False):
        """
        Build the keyword arguments of the ``go.Scatter`` traces for ``df``.

        Always returns the main trace first. When standard deviations are
        shown it carries error bars, and from 100 points upwards the bars are
        hidden and two extra traces draw a shaded band instead. Takes the next
        color from ``self.colors`` and advances ``self.color_index``.
        ``log`` is accepted but not used here.
        """
        value = df['value']
        color = self.colors[self.color_index % len(self.colors)]
        self.color_index += 1

        main_scatter_params = dict(
            x=df['time'],
            y=value,
            name=label,
            fillcolor=color,
        )

        has_std = self.show_std and 'std' in df and not df['std'].isnull().all()
        if not has_std:
            return [main_scatter_params]

        # We want to clip negative error bars to 0
        positive_err = df['std']
        negative_err = positive_err.clip(upper=value)

        # Use error bars, add them to the main trace:
        if (positive_err == negative_err).all():
            error_y = go.scatter.ErrorY(array=positive_err)
        else:
            error_y = go.scatter.ErrorY(array=positive_err, arrayminus=negative_err)
        main_scatter_params['error_y'] = error_y

        scatter_param_list = [main_scatter_params]

        if value.size >= 100:
            # We have many points, let's hide the error bars and show a band:
            error_y.thickness = 0
            error_y.width = 0

            band_color = hex_to_rgba(color, 0.25)

            # Upper bound:
            scatter_param_list.append(dict(
                name=f"{label} upper bound",
                x=df['time'],
                y=value + positive_err,
                mode='lines',
                line=dict(width=0),
                showlegend=False,
                hoverinfo='skip',
                fillcolor=band_color,
            ))

            # Lower bound, filled up to the upper bound trace:
            scatter_param_list.append(dict(
                name=f"{label} lower bound",
                x=df['time'],
                y=value - negative_err,
                mode='lines',
                line=dict(width=0),
                showlegend=False,
                hoverinfo='skip',
                fill='tonexty',
                fillcolor=band_color,
            ))

        return scatter_param_list

    def _calculate_x_range(self, data):
        """
        Return the padded ``[min, max]`` x-range shared by all dataframes.

        The upper limit is the smallest of the maximum times and the lower
        limit is the largest of the minimum times, never below 0. Returns
        ``None`` when no dataframe provides a maximum time.
        """
        # With multiple charts, fit the x-axis of the shortest one:
        global_max_x = math.inf
        global_min_x = 0

        for df, _ in data:
            max_x = df['time'].max()
            min_x = df['time'].min()

            if max_x < global_max_x:
                global_max_x = max_x
            if min_x > global_min_x:
                global_min_x = min_x

        if math.isinf(global_max_x):
            # Nothing to fit to, leave the axis unconstrained:
            return None

        # The range of the chart is given a padding depending on the data range
        # to make sure the content is visible:
        padding = (global_max_x - global_min_x) * 0.05

        return [global_min_x - padding, global_max_x + padding]

    def _calculate_y_range(self, data, model_df_indices, log=False):
        """
        Find the limit for the y axis, ignoring model dataframes, since they
        might have exponentials that shoot up.

        Returns a padded ``[lower, upper]`` list computed from value +/- std,
        in log10 units when ``log`` is true. Returns ``None`` when there is
        nothing to measure, or when ``log`` is true and no upper bound is
        positive.
        """
        global_max_y = 0
        global_min_y = math.inf
        global_positive_min_y = math.inf
        has_points = False

        for index, (df, _) in enumerate(data):
            if index in model_df_indices:
                # A model's data might shoot up exponentially, so we don't
                # consider it for the chart range
                continue

            # We look for the min and max values + std in the dataframe:
            for value, std in zip(df['value'], df['std']):
                if value is None or math.isnan(value):
                    # A missing measurement says nothing about the range
                    continue

                # For some reason, pandas might give us a None here, or it might
                # give us a NaN
                if std is None or math.isnan(std):
                    std = 0

                has_points = True
                upper = value + std
                lower = value - std
                if not log:
                    # On a linear axis the lower end never goes below zero
                    lower = max(lower, 0)

                if upper > global_max_y:
                    global_max_y = upper
                if lower < global_min_y:
                    global_min_y = lower
                if 0 < lower < global_positive_min_y:
                    global_positive_min_y = lower

        if not has_points:
            return None

        if log:
            if global_max_y <= 0:
                # Nothing positive to show on a log axis
                return None

            # Fifth of an order of magnitude of padding:
            padding = 0.2

            # Non-positive lower bounds can't be log-transformed, so fall back
            # to the smallest positive one. If there is none either, the lower
            # limit stays infinite, as it did before.
            # NOTE(review): confirm how Plotly treats a non-finite range bound.
            if global_min_y <= 0.0:
                global_min_y = np.log10(global_positive_min_y)
            else:
                global_min_y = np.log10(global_min_y)
            global_max_y = np.log10(global_max_y)
        else:
            # The range of the chart is given a padding depending on the data
            # range to make sure the content is visible:
            padding = (global_max_y - global_min_y) * 0.05

        return [global_min_y - padding, global_max_y + padding]