Plot Input (sketch.plot_input)
This module provides functions for plotting input data characteristics, such as metrics over time and correlation matrices.
Key functions
plot_all_metrics(input_data, output_dir, suffix)
Plots, in a single vertical figure, the time series for all media volumes, media costs, extra features, and the target.
Inputs: InputData object, output directory, filename suffix.
Saves: metrics_{suffix}.png.
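To illustrate the "single vertical figure" layout described above, here is a minimal sketch that stacks one panel per series and saves the result as metrics_{suffix}.png. It assumes matplotlib as the plotting backend, which this page does not state; the helper name plot_series_stack and the dictionary input are hypothetical stand-ins and not part of sketch.plot_input.

```python
import os

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_series_stack(series_by_name, output_dir, suffix):
    """Hypothetical helper: one stacked panel per series, saved as a PNG.

    `series_by_name` maps a label to a sequence of values; this stands in
    for the media volumes, costs, extra features, and target held by the
    real InputData object, whose structure is not shown on this page.
    """
    names = list(series_by_name)
    fig, axes = plt.subplots(
        len(names), 1,
        figsize=(10, 2.5 * len(names)),
        sharex=True,
        squeeze=False,  # always return a 2-D array of axes, even for one panel
    )
    for ax, name in zip(axes[:, 0], names):
        ax.plot(series_by_name[name])
        ax.set_title(name)
    fig.tight_layout()

    path = os.path.join(output_dir, f"metrics_{suffix}.png")
    fig.savefig(path)
    plt.close(fig)  # release the figure's memory once it is on disk
    return path
```

Sharing the x-axis keeps the time index aligned across panels, which makes it easier to compare movements in volumes, costs, and the target at a glance.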
plot_correlation_matrix(input_data, per_observation_df)
Returns a Plotly heatmap Figure and a DataFrame holding the correlation matrix of the columns whose names include “volume”, plus the target column.
Inputs: InputData and a per-observation DataFrame (typically from preprocessing).
Returns: (fig: plotly.graph_objects.Figure, corr_df: pd.DataFrame).
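The column selection and correlation step described above could look roughly like the sketch below. It is an illustration under stated assumptions, not the module's implementation: the function names, the target_column parameter, the case-sensitive name match, and the demo column names are all hypothetical.

```python
import pandas as pd


def volume_target_correlation(per_observation_df, target_column):
    """Hypothetical helper: correlate the 'volume' columns with the target.

    Assumes a case-sensitive substring match on column names; the real
    module may select columns differently.
    """
    columns = [c for c in per_observation_df.columns if "volume" in str(c)]
    if target_column not in columns:
        columns.append(target_column)
    # DataFrame.corr() computes pairwise Pearson correlation by default
    return per_observation_df[columns].corr()


def correlation_heatmap(corr_df):
    """Hypothetical helper: render a correlation DataFrame as a Plotly heatmap."""
    import plotly.graph_objects as go  # imported lazily; only needed for plotting

    return go.Figure(
        go.Heatmap(
            z=corr_df.values,
            x=list(corr_df.columns),
            y=list(corr_df.index),
            zmin=-1,
            zmax=1,
            colorscale="RdBu",
        )
    )


# Made-up demo data; "tv_cost" is excluded because its name lacks "volume"
demo = pd.DataFrame(
    {
        "tv_volume": [1.0, 2.0, 3.0, 4.0],
        "tv_cost": [10.0, 9.0, 12.0, 11.0],
        "target": [2.0, 4.0, 6.0, 8.0],
    }
)
corr_df = volume_target_correlation(demo, target_column="target")
```

Pinning zmin and zmax to -1 and 1 keeps the colour scale comparable between runs, so a pale cell always means a weak correlation regardless of the data set.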
plot_all_media_spend(input_data, per_observation_df)
Returns a Plotly line chart of the target series over time (useful for quick trend inspection).
Inputs: InputData and a per-observation DataFrame.
Returns: fig: plotly.graph_objects.Figure.
Usage example
from src.sketch.plot_input import plot_all_metrics, plot_correlation_matrix, plot_all_media_spend
# Plot all metrics to a single PNG
plot_all_metrics(input_data, output_dir=results_dir, suffix="train")
# Correlation heatmap (volumes + target)
corr_fig, corr_df = plot_correlation_matrix(input_data, per_observation_df)
corr_fig.write_image(f"{results_dir}/correlation.png") # optional; static image export requires the kaleido package
# Target over time (Plotly figure)
spend_fig = plot_all_media_spend(input_data, per_observation_df)
spend_fig.write_html(f"{results_dir}/target_over_time.html") # optional