Plot Diagnostics (sketch.plot_diagnostics)

This module provides functions for plotting model diagnostics, such as posterior predictions, parameter traces, posterior distributions, and model structure graphs.

Key functions

  • plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir, n_points=52, show_oos_r2=False)

    • Two-panel figure:

      • Top: actuals vs. the posterior predictive median with 50% and 90% HDI fans; optionally, the out-of-sample (OOS) median and HDI, with the holdout period shaded.

      • Bottom: Residuals over time (±2σ band).

    • Inputs: in-sample (X, y), optional OOS (X_test, y_test), fitted model mmm, config (uses date_col, target_col), output directory.

    • Saves: model_fit_predictions.png.

  • plot_model_trace(model, results_dir)

    • Parameter trace plots via ArviZ for intercept, beta_channel, alpha, and lam, plus likelihood_sigma, gamma_control, and gamma_fourier when the model defines them.

    • Inputs: fitted model exposing fit_result (an ArviZ InferenceData), output directory.

    • Saves: model_trace.png.

  • plot_posterior_distributions(idata, results_dir, filename='posterior_distributions.png')

    • Small-multiples grid of posterior distributions for all parameters in idata.posterior.

    • Inputs: ArviZ InferenceData, output directory, optional filename.

    • Saves: posterior_distributions.png (default).

  • plot_model_structure(model)

    • Builds a Graphviz graph of the PyMC model; requires Graphviz to be installed.

    • Input: object with a .model attribute of type pymc.Model.

    • Returns: graphviz.Digraph or None if Graphviz is unavailable.
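The 50%/90% HDI fans and the ±2σ residual band drawn by plot_posterior_predictions can be sketched with NumPy alone. This is a minimal sketch, not the module's code: it assumes posterior predictive draws already on the original scale with shape (n_draws, n_time), takes σ to be the sample standard deviation of the in-sample residuals (the module may define it differently), and uses hypothetical helper names (hdi_interval, prediction_bands). The HDI is the narrowest interval holding the requested fraction of draws, the same idea ArviZ's az.hdi uses for unimodal data.

```python
import numpy as np


def hdi_interval(samples: np.ndarray, prob: float) -> tuple[float, float]:
    """Narrowest interval containing a `prob` fraction of 1-D `samples`."""
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must be in (0, 1)")
    s = np.sort(samples)
    n = s.size
    k = int(np.floor(prob * n))  # index offset spanned by each candidate window
    widths = s[k:] - s[: n - k]  # width of every window [s[i], s[i + k]]
    i = int(np.argmin(widths))   # keep the narrowest window
    return float(s[i]), float(s[i + k])


def prediction_bands(draws: np.ndarray, y_obs: np.ndarray) -> dict:
    """Per-step median, 50%/90% HDIs, and residuals with a ±2σ band.

    draws: posterior predictive samples, shape (n_draws, n_time).
    y_obs: observed target, shape (n_time,).
    """
    median = np.median(draws, axis=0)
    out = {"median": median}
    for prob in (0.5, 0.9):
        bounds = np.array(
            [hdi_interval(draws[:, t], prob) for t in range(draws.shape[1])]
        )
        out[f"hdi_{round(prob * 100)}"] = (bounds[:, 0], bounds[:, 1])
    residuals = y_obs - median
    sigma = residuals.std(ddof=1)  # assumed σ: sample std of the residuals
    out["residuals"] = residuals
    out["band"] = (-2.0 * sigma, 2.0 * sigma)
    return out


# Synthetic example: 2000 draws around a linear trend over 10 time steps.
rng = np.random.default_rng(0)
n_time = 10
draws = rng.normal(loc=np.arange(n_time, dtype=float), scale=1.0, size=(2000, n_time))
y_obs = np.arange(n_time, dtype=float) + rng.normal(scale=0.5, size=n_time)
bands = prediction_bands(draws, y_obs)
lo90, hi90 = bands["hdi_90"]
lo50, hi50 = bands["hdi_50"]
```

In a plot, lo90/hi90 and lo50/hi50 would feed two nested fill_between calls on the top panel, and bands["band"] would give the horizontal limits of the residual band on the bottom panel.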
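The conditional parameter list in plot_model_trace amounts to "always plot the core parameters, add the optional ones only if the posterior contains them." A minimal sketch of that selection follows; trace_var_names, CORE_VARS, and OPTIONAL_VARS are hypothetical names, and the module's actual logic may differ.

```python
from collections.abc import Container

# Parameters listed above as always traced.
CORE_VARS = ["intercept", "beta_channel", "alpha", "lam"]
# Parameters traced only when the fitted model defines them.
OPTIONAL_VARS = ["likelihood_sigma", "gamma_control", "gamma_fourier"]


def trace_var_names(available: Container[str]) -> list[str]:
    """Core parameters plus whichever optional ones exist in `available`."""
    # Hypothetical helper: optional names keep their OPTIONAL_VARS order.
    return CORE_VARS + [name for name in OPTIONAL_VARS if name in available]
```

With a fitted model, the result could be passed as az.plot_trace(idata, var_names=trace_var_names(idata.posterior.data_vars)). Because the core names are not filtered, a model missing one of them would make ArviZ raise rather than silently drop it.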
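For the small-multiples grid in plot_posterior_distributions, one common way to lay out n panels is a near-square grid with a column cap. The sketch below assumes that approach, with a hypothetical grid_shape helper and a cap of four columns. The source does not say whether vector-valued parameters (for example one beta_channel per channel) get one panel per component, which changes n_panels.

```python
import math


def grid_shape(n_panels: int, max_cols: int = 4) -> tuple[int, int]:
    """Rows and columns for a near-square grid holding n_panels plots."""
    if n_panels < 1:
        raise ValueError("need at least one panel")
    # Aim for a square layout, but never exceed max_cols columns.
    n_cols = min(max_cols, math.ceil(math.sqrt(n_panels)))
    n_rows = math.ceil(n_panels / n_cols)
    return n_rows, n_cols
```

The shape would then go to plt.subplots(n_rows, n_cols), with any unused trailing axes hidden. ArviZ's az.plot_posterior(idata) also arranges its panels in a grid on its own, so a helper like this is only needed when building the figure by hand.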

Usage example

from src.sketch.plot_diagnostics import (
    plot_posterior_predictions,
    plot_model_trace,
    plot_posterior_distributions,
    plot_model_structure,
)

# Posterior predictive diagnostics
plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir=results_dir, show_oos_r2=True)

# Trace and posterior summaries
plot_model_trace(mmm, results_dir)
plot_posterior_distributions(mmm.fit_result, results_dir)

# Model structure (requires Graphviz)
graph = plot_model_structure(mmm)
if graph is not None:
    graph.render("model_structure", directory=results_dir, format="png", cleanup=True)
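The None check above guards the Graphviz dependency. One way such a fallback could work is sketched below. It is a minimal sketch, assuming the object exposes a .model attribute as described for plot_model_structure; model_graph_or_none is a hypothetical name, not part of sketch.plot_diagnostics. It relies on pm.model_to_graphviz, which raises ImportError when the graphviz Python package is missing.

```python
def model_graph_or_none(fitted):
    """Return pm.model_to_graphviz(fitted.model), or None if unavailable."""
    try:
        import pymc as pm
    except ImportError:
        # PyMC itself is not installed.
        return None
    try:
        # Raises ImportError when the `graphviz` Python package is missing.
        return pm.model_to_graphviz(fitted.model)
    except ImportError:
        return None
```

Building the Digraph does not call the Graphviz `dot` executable. Rendering it does, so graph.render can still fail with graphviz.ExecutableNotFound when only the Python package is installed.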

Notes

  • Date handling: dates are read from the column named by config['date_col'] in both the in-sample and OOS frames.

  • Scaling: posterior predictive outputs are transformed back to the original scale using the in-graph scaling parameters or the target transformer, whichever applies.
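The two back-transformation paths in the scaling note can be sketched as follows. This is a minimal sketch under loud assumptions: in-graph scaling is taken to be a single multiplicative factor (target_scale, a hypothetical name), the transformer follows the scikit-learn inverse_transform convention of 2-D input, the in-graph factor is checked first, and StandardTargetTransformer is a hypothetical stand-in for whatever transformer the pipeline actually fits.

```python
import numpy as np


class StandardTargetTransformer:
    """Hypothetical stand-in for a fitted target transformer (standardization)."""

    def __init__(self, mean: float, std: float):
        self.mean, self.std = mean, std

    def inverse_transform(self, values: np.ndarray) -> np.ndarray:
        # Undo (y - mean) / std column-wise, as scikit-learn scalers do.
        return values * self.std + self.mean


def to_original_scale(draws, target_scale=None, transformer=None):
    """Map scaled posterior predictive draws back to the target's units."""
    draws = np.asarray(draws, dtype=float)
    if target_scale is not None:
        # Assumed in-graph scaling: the model divided y by a constant.
        return draws * target_scale
    if transformer is not None:
        # scikit-learn-style transformers expect shape (n_samples, n_features).
        flat = transformer.inverse_transform(draws.reshape(-1, 1))
        return np.asarray(flat).reshape(draws.shape)
    # No scaling information: assume draws are already in original units.
    return draws
```

Flattening to a single column and restoring the shape afterwards lets the same call handle draws of shape (n_draws, n_time) or (chain, draw, n_time).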