src.core.mmm_model_v2

Defines the concrete DelayedSaturatedMMMv2 class with in-graph scaling.

Module Contents

class src.core.mmm_model_v2.DelayedSaturatedMMMv2(*args, **kwargs)

Bases: src.prepro.valid.ValidateControlColumns, src.prepro.prepro_v2.MaxAbsScaleTargetV2, src.prepro.prepro_v2.MaxAbsScaleChannelsV2, src.core.mmm_base_v2.BaseDelayedSaturatedMMMv2, src.core.mixins.mmm_analysis.MMMAnalysisMixin, src.core.mixins.mmm_predict.MMMPredictMixin, src.core.mixins.mmm_calibrate.MMMCalibrateMixin

Media Mix Model v2 with delayed adstock, logistic saturation, and in-graph scaling.

The v2 implementation avoids tensor shape contamination by applying scaling inside the PyMC model graph through pm.Data containers instead of during preprocessing. As a result, tensor shapes stay consistent when switching between datasets with different numbers of channels.

Key differences from v1

  • Scaling parameters are computed during preprocessing but applied in-graph

  • Uses pm.Data containers for dynamic scaling within the model

  • Eliminates PyTensor compilation cache contamination

  • Maintains compatibility with analysis and prediction methods
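To make the first two points concrete, here is a NumPy-only sketch of max-abs scaling split into a preprocessing step (compute one factor per channel) and a model-side step (divide by those factors). It assumes the factors are stored as a plain vector, which in the real model would live in a pm.Data container. The function names and the zero-column guard are hypothetical and are not this class's API.

```python
import numpy as np


def compute_max_abs_scales(channel_data: np.ndarray) -> np.ndarray:
    """Per-channel max-abs factors, computed once during preprocessing (hypothetical helper)."""
    scales = np.abs(channel_data).max(axis=0)
    # Assumed guard: an all-zero channel gets factor 1 so the division stays finite.
    return np.where(scales == 0, 1.0, scales)


def apply_scaling_in_graph(channel_data: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Stand-in for the division the model performs on its data containers."""
    return channel_data / scales


# Two datasets with different channel counts get independent scale vectors,
# so nothing sized for one dataset leaks into the other.
X_a = np.array([[10.0, 0.0, 4.0], [5.0, 0.0, -8.0]])
X_b = np.array([[3.0, 6.0], [1.0, 2.0]])
for X in (X_a, X_b):
    s = compute_max_abs_scales(X)
    print(s.shape, apply_scaling_in_graph(X, s))
```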

Overview

Combines geometric adstock, logistic saturation, optional control variables, and optional Fourier modes for seasonality. Includes data validation and preprocessing mixins (MaxAbs scaling computation for target and channels). Provides methods for analysis, plotting, prediction, and calibration.
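The two media transformations named above can be sketched in NumPy as follows. This assumes the parameterizations commonly used in PyMC-Marketing: geometric adstock with weights alpha**lag for lag = 0, ..., adstock_max_lag - 1, and logistic saturation (1 - exp(-lam * x)) / (1 + exp(-lam * x)). Whether this class normalizes the adstock weights or uses exactly these forms is an assumption; the functions are illustrative, not the model's internals.

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: float, l_max: int, normalize: bool = False) -> np.ndarray:
    """Carry a fraction alpha of each date's spend into the next l_max - 1 dates.

    x has shape (n_dates, n_channels); the carryover runs along the date axis.
    """
    weights = alpha ** np.arange(l_max)  # [1, alpha, alpha**2, ...]
    if normalize:
        weights = weights / weights.sum()
    n_dates = x.shape[0]
    out = np.zeros_like(x, dtype=float)
    for lag, w in enumerate(weights):
        if lag >= n_dates:
            break  # lags beyond the series length add nothing
        # Spend at date t - lag contributes w to date t; earliest dates get no carryover.
        out[lag:] += w * x[: n_dates - lag]
    return out


def logistic_saturation(x: np.ndarray, lam: float) -> np.ndarray:
    """Diminishing returns: 0 at x = 0, approaching 1 as x grows."""
    z = np.exp(-lam * x)
    return (1.0 - z) / (1.0 + z)
```

Adstock is applied first and saturation second, the same order listed for channel_contributions_forward_pass.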

param date_column:

Column name of the date variable.

type date_column:

str

param channel_columns:

Column names of the media channel variables.

type channel_columns:

List[str]

param adstock_max_lag:

Maximum lag for the adstock transformation.

type adstock_max_lag:

int

param model_config:

Configuration for priors and likelihood. Uses defaults if None.

type model_config:

Optional[Dict], optional

param sampler_config:

Configuration for the sampler. Uses defaults if None.

type sampler_config:

Optional[Dict], optional

param validate_data:

Whether to validate input data. Defaults to True.

type validate_data:

bool, optional

param control_columns:

Column names for control variables. Defaults to None.

type control_columns:

Optional[List[str]], optional

param **kwargs:

Additional keyword arguments passed to BaseDelayedSaturatedMMMv2.

Notes

  • Target and channel scaling factors are computed during preprocessing but applied in-graph

  • Control variables are validated but not scaled

  • Supports calibration via custom priors and lift tests

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from src.core.mmm_model_v2 import DelayedSaturatedMMMv2
>>> # Load data...
>>> data = pd.read_csv("your_data.csv", parse_dates=["date_column_name"])
>>> X = data.drop("target_column_name", axis=1)
>>> y = data["target_column_name"]
>>> mmm = DelayedSaturatedMMMv2(
...     date_column="date_column_name",
...     channel_columns=["channel1", "channel2"],
...     control_columns=["control1"],
...     adstock_max_lag=4
... )
>>> idata = mmm.fit(X, y, draws=1000, tune=1000)
>>> # Use methods from mixins:
>>> fig = mmm.plot_channel_contributions_grid(start=0, stop=2, num=11)
>>> pred_contrib = mmm.new_spend_contributions(spend=np.array([100, 200]))

channel_contributions_forward_pass(channel_data: numpy.ndarray) numpy.ndarray

Evaluates channel contributions on the original target scale.

Applies the fitted model’s transformations (adstock, saturation, beta weights) to the input channel data and then inverse-transforms the result back to the original scale of the target variable using the fitted target scaler.

Unlike v1, scaling here is applied inside the model graph rather than during preprocessing, and this method accounts for that when evaluating the contributions.

Parameters:

channel_data (np.ndarray) – Input channel data, potentially preprocessed. Shape should be (n_dates, n_channels).

Returns:

Estimated channel contributions in the original target scale.

Shape corresponds to (chains, draws, n_dates, n_channels).

Return type:

np.ndarray

Raises:

RuntimeError – If the target transformer is not available or has not been fitted.
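A vectorized sketch of what such a forward pass can look like, assuming posterior samples of alpha, lam and beta are available as NumPy arrays shaped (chains, draws, n_channels) and that undoing the max-abs target scaling amounts to multiplying by a single target_scale factor. forward_pass_sketch and its arguments are hypothetical; the real method reads these quantities from the fitted model.

```python
import numpy as np


def forward_pass_sketch(channel_data, alpha, lam, beta, l_max, target_scale):
    """Hypothetical forward pass: adstock -> saturation -> beta -> target scale.

    channel_data: (n_dates, n_channels), assumed already max-abs scaled.
    alpha, lam, beta: posterior samples, each shaped (chains, draws, n_channels).
    target_scale: max-abs factor of the target, multiplied back in at the end.
    Returns an array shaped (chains, draws, n_dates, n_channels).
    """
    channel_data = np.asarray(channel_data, dtype=float)
    n_dates, n_channels = channel_data.shape
    # Geometric weights alpha**lag per sample and channel: (chains, draws, l_max, n_channels).
    weights = alpha[..., None, :] ** np.arange(l_max)[:, None]
    adstocked = np.zeros(alpha.shape[:2] + (n_dates, n_channels))
    for lag in range(min(l_max, n_dates)):
        # Spend from date t - lag contributes weights[lag] to date t.
        adstocked[..., lag:, :] += weights[..., lag : lag + 1, :] * channel_data[: n_dates - lag]
    # Logistic saturation, broadcasting lam over the date axis.
    z = np.exp(-lam[..., None, :] * adstocked)
    saturated = (1.0 - z) / (1.0 + z)
    # Beta weights, then undo the max-abs scaling of the target.
    return beta[..., None, :] * saturated * target_scale
```

Keeping the chain and draw axes until the end lets callers summarize the full posterior of each contribution rather than a single point estimate.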

format_recovered_transformation_parameters(quantile: float = 0.5) Dict[str, Dict[str, Dict[str, float]]]

Formats the recovered transformation parameters for each channel.

This method retrieves the given quantile of each channel's parameters from the posterior distribution and formats them into a nested dictionary. It assumes LogisticSaturation (parameter ‘lam’) and geometric_adstock (parameter ‘alpha’).

Parameters:

quantile (float, optional) – The quantile to retrieve. Defaults to 0.5 (median).

Returns:

Nested dictionary structure, e.g.:

{
    "channel": {
        "saturation_params": {"param_name": value},
        "adstock_params": {"param_name": value}
    }
}

Return type:

Dict[str, Dict[str, Dict[str, float]]]
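A minimal sketch of the reduction described above, assuming the alpha and lam posterior samples are NumPy arrays shaped (chains, draws, n_channels). format_params_sketch is a hypothetical helper; it only illustrates how a quantile over the chain and draw axes yields the nested dictionary.

```python
import numpy as np


def format_params_sketch(channels, alpha_samples, lam_samples, quantile=0.5):
    """Collapse posterior samples to one value per channel and parameter.

    alpha_samples, lam_samples: arrays shaped (chains, draws, n_channels).
    """
    # Quantile over chains and draws together, leaving one value per channel.
    alpha_q = np.quantile(alpha_samples, quantile, axis=(0, 1))
    lam_q = np.quantile(lam_samples, quantile, axis=(0, 1))
    return {
        ch: {
            "saturation_params": {"lam": float(lam_q[i])},
            "adstock_params": {"alpha": float(alpha_q[i])},
        }
        for i, ch in enumerate(channels)
    }
```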

get_channel_contributions() numpy.ndarray

Returns the channel contributions from the fitted model.

Extracts the channel contributions from the posterior samples. These are the contributions of each channel to the target variable after applying adstock and saturation transformations.

Returns:

Channel contributions with shape (chains, draws, dates, channels)

Return type:

np.ndarray

Raises:

RuntimeError – If the model has not been fitted yet
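Because the result keeps every posterior sample, callers usually reduce it before reporting. The sketch below, a hypothetical helper rather than part of this class, averages over chains and draws and computes each channel's share of the total media contribution.

```python
import numpy as np


def summarize_contributions(contribs: np.ndarray, channels):
    """Hypothetical summary of a (chains, draws, dates, channels) contribution array."""
    if contribs.ndim != 4:
        raise ValueError("expected shape (chains, draws, dates, channels)")
    # Posterior mean per date and channel: (dates, channels).
    mean_by_date = contribs.mean(axis=(0, 1))
    # Total contribution per channel over the whole period, then each channel's share.
    totals = mean_by_date.sum(axis=0)
    shares = totals / totals.sum()
    return mean_by_date, dict(zip(channels, shares.tolist()))
```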

get_errors() numpy.ndarray

Returns the model errors (residuals), if available.

Kept for backward compatibility with existing code that calls this method.

Returns:

Model residuals, or an empty array if they are not available.

Return type:

np.ndarray
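If residuals are taken to mean the observed target minus the posterior-predictive mean (an assumption; the docstring does not define them), the computation reduces to the sketch below. residuals_sketch and the (chains, draws, n_dates) shape are hypothetical.

```python
from typing import Optional

import numpy as np


def residuals_sketch(y_observed: np.ndarray, posterior_predictive: Optional[np.ndarray]) -> np.ndarray:
    """Observed minus posterior-predictive mean, or an empty array if unavailable.

    posterior_predictive: assumed shape (chains, draws, n_dates), or None before sampling.
    """
    if posterior_predictive is None:
        return np.array([])
    return np.asarray(y_observed, dtype=float) - posterior_predictive.mean(axis=(0, 1))
```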

property channel_transformer

Provides a transformer-like interface for v2’s in-graph scaling.

Returns a simple object with a transform method that applies the channel scaling using the stored max-abs scaling parameters.
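A minimal sketch of such an object, assuming the stored parameters are one max-abs factor per channel and that transform divides column-wise. The class name MaxAbsChannelTransformer and its shape check are hypothetical; the real property may return a different type.

```python
import numpy as np


class MaxAbsChannelTransformer:
    """Hypothetical stand-in for the object returned by channel_transformer."""

    def __init__(self, scales):
        # One max-abs factor per channel, as stored after preprocessing.
        self.scales = np.asarray(scales, dtype=float)

    def transform(self, X) -> np.ndarray:
        X = np.asarray(X, dtype=float)
        # Assumed check: refuse data whose channel count differs from the stored factors.
        if X.shape[-1] != self.scales.shape[0]:
            raise ValueError("channel count does not match the stored scaling parameters")
        return X / self.scales
```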

property target_transformer

Provides a transformer-like interface for v2’s target scaling.

In v2, target scaling is handled in-graph, so this returns an identity transformer that leaves data unchanged.

save(fname: str) None

Saves the model to a file using dill for better serialization support.

Uses dill instead of pickle to handle complex objects like functools.partial that may be created internally by PyMC/PyTensor during model compilation.

Parameters:

fname – Filename to save the model to

classmethod load(fname: str) DelayedSaturatedMMMv2

Loads a model from a file using dill.

Uses dill for deserialization to match the save method and handle complex objects that standard pickle cannot deserialize.

Parameters:

fname – Filename to load the model from

Returns:

The loaded model instance
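The save/load pair amounts to a dill round trip. The sketch below uses a hypothetical TinyModel class; dill.dump and dill.load mirror pickle's dump and load, and the pickle fallback exists only so the example runs where dill is not installed (it does not offer dill's extended coverage).

```python
import functools
import os
import tempfile

try:
    import dill as serializer  # the serializer the save/load docs describe
except ImportError:  # fallback only so this sketch runs without dill
    import pickle as serializer


class TinyModel:
    """Hypothetical model holding a functools.partial, as PyMC/PyTensor objects may."""

    def __init__(self):
        self.fitted = True
        self.scale = functools.partial(pow, 2)

    def save(self, fname: str) -> None:
        with open(fname, "wb") as f:
            serializer.dump(self, f)

    @classmethod
    def load(cls, fname: str) -> "TinyModel":
        with open(fname, "rb") as f:
            return serializer.load(f)
```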