def age_curves(fdir, df, channel, sweepvar='modelname', facet_var='output_target', age_groups_aggregates=None):
    """
    Plot age curves of ``channel`` and save the figure as a PNG in ``fdir``.

    One subplot is drawn per unique value of ``facet_var`` (sorted), laid out on a
    two-row grid. Within each subplot, one line per unique value of ``sweepvar``
    shows ``channel`` against age group. A shaded band spans the
    ``f'{channel}_min'`` and ``f'{channel}_max'`` columns produced by ``get_x_y``.

    Args:
        fdir (str): Directory where the plot will be saved.
        df (pandas.DataFrame): DataFrame that includes combined model results.
        channel (str): Column plotted on the y-axis.
        sweepvar (str, optional): Column whose values become separate lines in each
            subplot. Defaults to 'modelname'.
        facet_var (str, optional): Column whose unique values become separate
            subplots. Defaults to 'output_target'.
        age_groups_aggregates (list, optional): Age-group labels to keep, in x-axis
            order. Age groups not in this list are dropped from the plot.
            If None, a default list of age groups is used.

    Raises:
        ValueError: If ``facet_var`` is 'ageGroup', or if no facet values remain
            after ``subset_dataframe_for_plot``.

    Returns:
        None: The figure is written to ``<fdir>/agecurves_<channel>_by_<sweepvar>.png``.
        The substring 'modelname' in that file name is shortened to 'model'.

    Notes:
        - Only the first subplot carries a legend.
        - Each subplot's y-axis runs from 0 to 110% of the largest plotted value.
          For the 'prevalence' and 'prevalence_2to10' channels, it is capped at 1.
    """
    if facet_var == 'ageGroup':
        raise ValueError('Age curves designed to have age on the x-axis, ageGroup facets are not supported')
    if age_groups_aggregates is None:
        age_groups_aggregates = ['0-0.5', '0.5-1', '1-2', '2-5', '5-10', '10-15', '15-20', '20-100']

    figure_vars = [channel, sweepvar, facet_var]
    df, _ = subset_dataframe_for_plot(df, figure_vars)

    unique_facets = sorted(df[facet_var].unique())
    if not unique_facets:
        raise ValueError(f'No data to plot for {channel} faceted by {facet_var}')
    color_palette = sns.color_palette('colorblind', max(len(df[sweepvar].unique()), 4))

    # Two rows, and enough columns to hold every facet. Ceiling division is required:
    # round() uses banker's rounding (round(0.5) == 0, round(2.5) == 2). That left
    # too few subplot slots for 1, 5, 9, ... facets.
    ny = 2
    nx = (len(unique_facets) + 1) // 2
    fig = plt.figure(figsize=(10 * nx, 10 * ny))
    try:
        lg = None
        for f, fi in enumerate(unique_facets, start=1):
            fdf = df[df[facet_var] == fi]
            fdf = fdf[fdf[channel].notnull()]
            ax = fig.add_subplot(ny, nx, f)
            max_ylim = []
            # Group by the column name rather than a one-element list, so that `s`
            # is the scalar value. In pandas >= 2.0, a list key yields 1-tuples
            # such as ('a',), which would leak into legend labels and color_selector.
            for i, (s, sdf) in enumerate(fdf.groupby(sweepvar)):
                color = color_palette[color_selector(i, s)]
                ylim_up = np.max(sdf[channel]) * 1.1
                if channel in ('prevalence', 'prevalence_2to10'):
                    # Cap the axis at 1 for prevalence channels (presumably proportions).
                    ylim_up = min(ylim_up, 1)
                max_ylim.append(ylim_up)

                xmean, ymean = get_x_y(sdf, 'ageGroup', channel, channel)
                merge_df = pd.merge(left=xmean['ageGroup'], right=ymean, on='ageGroup')
                # Age groups not listed in age_groups_aggregates (e.g. '0-5', '0-100')
                # become NaN here and are removed by dropna(). This appears to be
                # intended. TODO confirm.
                # Sorting the categorical orders rows by age_groups_aggregates.
                merge_df['ageGroup'] = pd.Categorical(merge_df['ageGroup'], age_groups_aggregates)
                merge_df = merge_df.dropna().sort_values('ageGroup')

                ax.plot(merge_df['ageGroup'], merge_df[channel], '-', linewidth=0.8, label=f"{s}",
                        color=color)
                ax.fill_between(merge_df['ageGroup'], merge_df[f'{channel}_min'], merge_df[f'{channel}_max'],
                                alpha=0.1, color=color)

            if lg is None:
                # Only the first subplot gets a legend.
                lg = ax.legend(loc='upper left', bbox_to_anchor=(0, 1))
            # A facet with no non-null rows plots nothing; keep the default limits then.
            if max_ylim:
                ax.set_ylim(0, max(max_ylim))
            ax.tick_params(axis='x', labelrotation=45)
            ax.set_xlabel('Age (years)')
            ax.set_ylabel(get_ylab(channel))
            ax.set_title(f"{facet_var} = {fi}")

        fname = f'agecurves_{channel}_by_{sweepvar}'.replace('modelname', 'model')
        # The legend is passed as an extra artist so bbox_inches='tight' keeps it in frame.
        fig.savefig(os.path.join(fdir, f'{fname}.png'), bbox_extra_artists=(lg,), bbox_inches='tight')
    finally:
        # Close this specific figure, even on error, so pyplot does not keep it alive.
        plt.close(fig)