def _facet_grid_shape(n_facets, facet_layout=None):
    """Return ``(nrows, ncols)`` for a subplot grid that holds ``n_facets`` panels.

    'horizontal' puts every facet in one row and 'vertical' in one column. Any other
    value (including None) uses two rows (one row for a single facet) and as many
    columns as needed for every facet to get its own slot.
    """
    n_facets = max(1, n_facets)
    if facet_layout == 'horizontal':
        return 1, n_facets
    if facet_layout == 'vertical':
        return n_facets, 1
    nrows = 2 if n_facets > 1 else 1
    # Integer ceiling division. The previous round(n / 2) used banker's rounding
    # (round(2.5) == 2), which left too few slots for e.g. 5 facets.
    ncols = (n_facets + nrows - 1) // nrows
    return nrows, ncols


def _facet_title(facet_var, facet_name, value):
    """Return the subplot title for the facet whose ``facet_var`` equals ``value``."""
    if facet_var == 'models':
        return str(value)
    if facet_var == 'target_output_values':
        return f"{get_label(facet_name)} = {value}"
    return f"{facet_name} = {value}"


def age_curves(fdir, df, channel, sweepvar='models', facet_var='target_output_values',
               age_groups_aggregates=None, facet_name=None, exp=None, facet_layout=None,
               width_per_col=8, height_per_row=8, space=0.4, file_format='png'):
    """
    Plots age curves based on model results and saves the figure to the specified directory.

    Args:
        fdir (str): Directory where the plot will be saved.
        df (pandas.DataFrame): DataFrame that includes combined model results.
        channel (str): Variable representing the y-axis data to be plotted.
        sweepvar (str, optional): Variable to group the data and create multiple lines on the plot.
            Defaults to 'models'.
        facet_var (str, optional): Variable used to create subplots based on its unique values.
            Defaults to 'target_output_values'.
        age_groups_aggregates (list, optional): Age group labels to plot, in x-axis order.
            Age groups not in this list are dropped. If None, default age groups are used.
        facet_name (str, optional): Friendly name for facet_var in plot titles. If None, uses facet_var.
        exp (object, optional): Optional experiment object used for legend labeling.
        facet_layout (str or None, optional): Layout of subplots. One of:
            - 'horizontal': all facets in a single row
            - 'vertical': all facets in a single column
            - None: automatic grid with two rows (one row for a single facet) (default)
        width_per_col (int or float, optional): Width of each subplot column in inches. Default is 8.
        height_per_row (int or float, optional): Height of each subplot row in inches. Default is 8.
        space (float, optional): Space between subplots (both hspace and wspace). Default is 0.4.
        file_format (str, optional): File format for saved figure, e.g. 'png', 'pdf', 'jpg'.
            Default is 'png'.

    Raises:
        ValueError: If `facet_var` is set to 'ageGroup', or if no facet values remain
            after subsetting the data.

    Returns:
        None: The function saves the generated plot to disk in the specified format.
    """
    if facet_name is None:
        facet_name = facet_var
    if facet_var == 'ageGroup':
        raise ValueError('Age curves designed to have age on the x-axis, ageGroup facets are not supported')
    if age_groups_aggregates is None:
        age_groups_aggregates = ['0-0.5', '0.5-1', '1-2', '2-5', '5-10', '10-15', '15-20', '20-100']

    df, _ = subset_dataframe_for_plot(df, [channel, sweepvar, facet_var])
    unique_facets = sorted_list(df[facet_var])
    unique_groups = sorted_list(df[sweepvar])
    if len(unique_facets) == 0:
        raise ValueError(f"No data to plot: no values of {facet_var!r} remain for channel {channel!r}")

    # Index every sweep group against the full group list so that a group keeps the
    # same colour in every facet, even when some facets lack other groups.
    # NOTE(review): assumes color_selector's first argument indexes into a palette of
    # n_colors = len(unique_groups) entries; the per-facet position is the fallback.
    group_index = {group: idx for idx, group in enumerate(unique_groups)}

    nrows, ncols = _facet_grid_shape(len(unique_facets), facet_layout)
    fig = plt.figure(figsize=(width_per_col * ncols, height_per_row * nrows))
    legend = None
    try:
        for panel, facet_value in enumerate(unique_facets, start=1):
            fdf = df[(df[facet_var] == facet_value) & df[channel].notnull()]
            ax = fig.add_subplot(nrows, ncols, panel)
            ax.set_title(_facet_title(facet_var, facet_name, facet_value))

            ylim_tops = []
            # Group by the column name, not a one-element list: pandas >= 2 yields
            # 1-tuple keys for list groupers, which leaked into labels and colours.
            for i, (group, sdf) in enumerate(fdf.groupby(sweepvar)):
                color = color_selector(group_index.get(group, i), group, sweepvar=sweepvar,
                                       n_colors=len(unique_groups))
                ylim_up = np.max(sdf[channel]) * 1.1
                if channel in ('prevalence', 'prevalence_2to10'):
                    ylim_up = min(ylim_up, 1)
                ylim_tops.append(ylim_up)

                xmean, ymean = get_x_y(sdf, 'ageGroup', channel, channel)
                merge_df = pd.merge(left=xmean['ageGroup'], right=ymean, on='ageGroup')
                # Age groups outside age_groups_aggregates become NaN and are dropped;
                # sorting the categorical orders rows as listed in age_groups_aggregates.
                merge_df['ageGroup'] = pd.Categorical(merge_df['ageGroup'], categories=age_groups_aggregates)
                merge_df = merge_df.dropna().sort_values('ageGroup')

                ax.plot(merge_df['ageGroup'], merge_df[channel], '-', linewidth=1, label=f"{group}", color=color)
                ax.fill_between(merge_df['ageGroup'], merge_df[f'{channel}_min'], merge_df[f'{channel}_max'],
                                alpha=0.1, color=color)

            # Only one legend for the whole figure, placed on the first facet that has lines.
            if legend is None and ylim_tops:
                legend = ax.legend(title=get_legend_title(sweepvar, exp), loc='upper left',
                                   bbox_to_anchor=(0, 1))
            if ylim_tops:
                ax.set_ylim(0, max(ylim_tops))
            ax.tick_params(axis='x', labelrotation=45)
            ax.set_xlabel('Age (years)')
            ax.set_ylabel(get_label(channel))

        fname = f'agecurves_{channel}_{sweepvar}_{facet_var}'
        fname = clean_fname(fname, sweepvar, unique_groups, facet_var, unique_facets)
        fig.tight_layout()
        fig.subplots_adjust(hspace=space, wspace=space)
        fig.savefig(os.path.join(fdir, f'{fname}.{file_format}'), format=file_format,
                    bbox_extra_artists=(legend,) if legend is not None else None,
                    bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)