src.core.mmm_model_v2
Defines the concrete DelayedSaturatedMMMv2 class with in-graph scaling.
Module Contents
- class src.core.mmm_model_v2.DelayedSaturatedMMMv2(*args, **kwargs)
Bases:
src.prepro.valid.ValidateControlColumns, src.prepro.prepro_v2.MaxAbsScaleTargetV2, src.prepro.prepro_v2.MaxAbsScaleChannelsV2, src.core.mmm_base_v2.BaseDelayedSaturatedMMMv2, src.core.mixins.mmm_analysis.MMMAnalysisMixin, src.core.mixins.mmm_predict.MMMPredictMixin, src.core.mixins.mmm_calibrate.MMMCalibrateMixin
Media Mix Model v2 with delayed adstock, logistic saturation, and in-graph scaling.
This is the v2 implementation that resolves tensor shape contamination issues by applying scaling within the PyMC model graph using pm.Data containers, rather than during preprocessing. This ensures clean tensor shape handling when switching between datasets with different numbers of channels.
Key differences from v1
Scaling parameters are computed during preprocessing but applied in-graph
Uses pm.Data containers for dynamic scaling within the model
Eliminates PyTensor compilation cache contamination
Maintains compatibility with analysis and prediction methods
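The split between computing scaling parameters and applying them can be illustrated outside PyMC. The sketch below is a plain NumPy stand-in, not the class's implementation: `compute_max_abs_scales` and `apply_scales` are hypothetical helper names, and in the actual model the division is described as happening on pm.Data containers inside the graph.

```python
import numpy as np


def compute_max_abs_scales(channel_data: np.ndarray) -> np.ndarray:
    """Preprocessing step (hypothetical helper): one max-abs scale per channel."""
    scales = np.max(np.abs(channel_data), axis=0)
    # Guard against all-zero channels so the later division is well defined.
    return np.where(scales == 0, 1.0, scales)


def apply_scales(channel_data: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Application step (hypothetical helper): stands in for the in-graph division."""
    return channel_data / scales


# Two datasets with different numbers of channels; each carries its own scales,
# so nothing about the first dataset's shape leaks into the second.
two_channels = np.array([[10.0, 2.0], [5.0, 4.0]])
three_channels = np.array([[1.0, 0.0, 8.0], [3.0, 0.0, 2.0]])

scaled_two = apply_scales(two_channels, compute_max_abs_scales(two_channels))
scaled_three = apply_scales(three_channels, compute_max_abs_scales(three_channels))
```

Because the scales travel with the data rather than being baked into a preprocessed array, the same application step works for any number of channels.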
Overview
Combines geometric adstock, logistic saturation, optional control variables, and optional Fourier modes for seasonality. Includes data validation and preprocessing mixins (MaxAbs scaling computation for target and channels). Provides methods for analysis, plotting, prediction, and calibration.
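As a rough illustration of the two media transformations named above, the following NumPy sketch applies a geometric adstock and then a logistic saturation to a single spend impulse. The formulas are commonly used forms and are an assumption here; the exact parameterization inside this class (normalization, lag convention) is not confirmed by this page. The parameter names `alpha` and `lam` follow the names used elsewhere in this reference.

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: float, l_max: int) -> np.ndarray:
    """Carry-over effect: y[t] = sum over l in [0, l_max) of alpha**l * x[t - l]."""
    weights = alpha ** np.arange(l_max)
    # Full convolution truncated to the input length keeps the effect causal.
    return np.convolve(x, weights)[: len(x)]


def logistic_saturation(x: np.ndarray, lam: float) -> np.ndarray:
    """Diminishing returns: (1 - exp(-lam * x)) / (1 + exp(-lam * x))."""
    return (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))


# A single unit of spend in the first period decays geometrically afterwards.
spend = np.array([1.0, 0.0, 0.0, 0.0])
carried = geometric_adstock(spend, alpha=0.5, l_max=3)
saturated = logistic_saturation(carried, lam=2.0)
```

With `l_max=3`, the impulse contributes to three periods only, which is the role `adstock_max_lag` plays in the constructor.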
- param date_column:
Column name of the date variable.
- type date_column:
str
- param channel_columns:
Column names of the media channel variables.
- type channel_columns:
List[str]
- param adstock_max_lag:
Maximum lag for the adstock transformation.
- type adstock_max_lag:
int
- param model_config:
Configuration for priors and likelihood. Uses defaults if None.
- type model_config:
Optional[Dict], optional
- param sampler_config:
Configuration for the sampler. Uses defaults if None.
- type sampler_config:
Optional[Dict], optional
- param validate_data:
Whether to validate input data. Defaults to True.
- type validate_data:
bool, optional
- param control_columns:
Column names for control variables. Defaults to None.
- type control_columns:
Optional[List[str]], optional
- param **kwargs:
Additional keyword arguments passed to BaseDelayedSaturatedMMMv2.
Notes
Target and media scaling parameters are computed during preprocessing but applied in-graph
Control variables are validated but not scaled
Supports calibration via custom priors and lift tests
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from src.core.mmm_model_v2 import DelayedSaturatedMMMv2
>>> # Load data...
>>> data = pd.read_csv("your_data.csv", parse_dates=["date_column_name"])
>>> X = data.drop("target_column_name", axis=1)
>>> y = data["target_column_name"]
>>> mmm = DelayedSaturatedMMMv2(
...     date_column="date_column_name",
...     channel_columns=["channel1", "channel2"],
...     control_columns=["control1"],
...     adstock_max_lag=4,
... )
>>> idata = mmm.fit(X, y, draws=1000, tune=1000)
>>> # Use methods from mixins:
>>> fig = mmm.plot_channel_contributions_grid(start=0, stop=2, num=11)
>>> pred_contrib = mmm.new_spend_contributions(spend=np.array([100, 200]))
- channel_contributions_forward_pass(channel_data: numpy.ndarray) -> numpy.ndarray
Evaluates channel contributions on the original target scale.
Applies the fitted model’s transformations (adstock, saturation, beta weights) to the input channel data and then inverse-transforms the result back to the original scale of the target variable using the fitted target scaler.
This v2 implementation handles the in-graph scaling architecture properly.
- Parameters:
channel_data (np.ndarray) – Input channel data, potentially preprocessed. Shape should be (n_dates, n_channels).
- Returns:
Estimated channel contributions on the original target scale, with shape (chains, draws, n_dates, n_channels).
- Return type:
np.ndarray
- Raises:
RuntimeError – If the target transformer is not available/fitted.
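The final inverse-transform step can be pictured with plain arrays. This is a minimal sketch assuming the target was max-abs scaled, so that undoing the scaling is a single multiplication; the array values and the name `target_scale` are invented for illustration and do not come from the library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior contributions on the scaled target, with shape
# (chains, draws, n_dates, n_channels).
scaled_contributions = rng.uniform(0.0, 0.2, size=(2, 50, 6, 3))

# Hypothetical max-abs scale of the target, computed during preprocessing.
target_scale = 1250.0

# Inverting max-abs scaling is a multiplication by the stored scale.
contributions = scaled_contributions * target_scale

# Posterior mean contribution per date and channel, in target units.
mean_contribution = contributions.mean(axis=(0, 1))
```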
- format_recovered_transformation_parameters(quantile: float = 0.5) -> Dict[str, Dict[str, Dict[str, float]]]
Formats the recovered transformation parameters for each channel.
This function retrieves the requested quantile of each channel's parameters from the posterior distribution and formats the results into a dictionary. It assumes LogisticSaturation (param ‘lam’) and geometric_adstock (param ‘alpha’).
- Parameters:
quantile (float, optional) – The quantile to retrieve. Defaults to 0.5 (median).
- Returns:
Nested dictionary structure, e.g.:
{
    "channel": {
        "saturation_params": {"param_name": value},
        "adstock_params": {"param_name": value}
    }
}
- Return type:
Dict[str, Dict[str, Dict[str, float]]]
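To make the nested return structure concrete, here is a small sketch that builds the same shape of dictionary from raw sample arrays. The helper `format_parameters` and the layout of `posterior` (arrays shaped (chains, draws, channels)) are assumptions for illustration; the real method reads from the fitted model's posterior.

```python
from typing import Dict, List

import numpy as np


def format_parameters(
    posterior: Dict[str, np.ndarray], channels: List[str], quantile: float = 0.5
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Hypothetical helper: per-channel quantiles of 'lam' and 'alpha'."""
    # Collapse the chain and draw axes, leaving one value per channel.
    lam_q = np.quantile(posterior["lam"], quantile, axis=(0, 1))
    alpha_q = np.quantile(posterior["alpha"], quantile, axis=(0, 1))
    return {
        channel: {
            "saturation_params": {"lam": float(lam_q[i])},
            "adstock_params": {"alpha": float(alpha_q[i])},
        }
        for i, channel in enumerate(channels)
    }


# One chain, three draws, two channels.
posterior = {
    "lam": np.array([[[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]]),
    "alpha": np.array([[[0.1, 0.7], [0.2, 0.8], [0.3, 0.9]]]),
}
params = format_parameters(posterior, ["channel1", "channel2"])
```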
- get_channel_contributions() -> numpy.ndarray
Returns the channel contributions from the fitted model.
Extracts the channel contributions from the posterior samples. These are the contributions of each channel to the target variable after applying adstock and saturation transformations.
- Returns:
Channel contributions with shape (chains, draws, dates, channels)
- Return type:
np.ndarray
- Raises:
RuntimeError – If the model has not been fitted yet
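A four-dimensional array of this shape is usually reduced before reporting. The sketch below uses a randomly generated stand-in for the returned array (not real model output) and summarizes each channel's total contribution with a median and an illustrative 94% interval.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in for the returned array, shape (chains, draws, dates, channels).
contributions = rng.gamma(2.0, 1.0, size=(4, 100, 12, 2))

# Sum over dates to get each posterior sample's total per channel.
total_per_channel = contributions.sum(axis=2)  # (chains, draws, channels)

# Flatten chains and draws into a single sample axis.
samples = total_per_channel.reshape(-1, total_per_channel.shape[-1])

median_total = np.median(samples, axis=0)
lower, upper = np.quantile(samples, [0.03, 0.97], axis=0)
```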
- get_errors() -> numpy.ndarray
Returns model errors/residuals if available.
This is a compatibility method for existing code that may call it.
- Returns:
Model residuals or empty array if not available
- Return type:
np.ndarray
- property channel_transformer
Provides a transformer-like interface for v2’s in-graph scaling.
Returns a simple object with a transform method that applies the channel scaling using the stored max-abs scaling parameters.
- property target_transformer
Provides a transformer-like interface for v2’s target scaling.
In v2, target scaling is handled in-graph, so this returns an identity transformer that leaves data unchanged.
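The two properties above can be thought of as thin objects exposing a `transform` method. The classes below are hypothetical stand-ins written for illustration; the actual objects returned by the properties may be structured differently.

```python
from typing import Sequence

import numpy as np


class MaxAbsChannelTransformer:
    """Hypothetical transformer-like wrapper around stored max-abs scales."""

    def __init__(self, scales: Sequence[float]) -> None:
        self.scales = np.asarray(scales, dtype=float)

    def transform(self, channel_data: np.ndarray) -> np.ndarray:
        # Divide each channel (column) by its stored scale.
        return np.asarray(channel_data, dtype=float) / self.scales


class IdentityTransformer:
    """Hypothetical identity transformer: data passes through unchanged."""

    def transform(self, data: np.ndarray) -> np.ndarray:
        return np.asarray(data)


channel_transformer = MaxAbsChannelTransformer(scales=[200.0, 50.0])
scaled = channel_transformer.transform(np.array([[100.0, 50.0], [200.0, 25.0]]))
target = IdentityTransformer().transform(np.array([3.0, 7.0]))
```

Keeping a `transform` method on both objects lets code written against a scaler-style interface call either one without knowing where the scaling actually happens.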
- save(fname: str) -> None
Saves the model to a file using dill for better serialization support.
Uses dill instead of pickle to handle complex objects like functools.partial that may be created internally by PyMC/PyTensor during model compilation.
- Parameters:
fname – Filename to save the model to
- classmethod load(fname: str) -> DelayedSaturatedMMMv2
Loads a model from a file using dill.
Uses dill for deserialization to match the save method and handle complex objects that standard pickle cannot deserialize.
- Parameters:
fname – Filename to load the model from
- Returns:
The loaded model instance
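The motivation for dill can be seen with a small stand-in. The sketch below uses a lambda as an example of a callable the standard pickler rejects; it is an illustrative substitute, not one of the objects PyMC/PyTensor actually creates. The dill import is guarded because dill is a third-party package that may not be installed.

```python
import pickle

try:
    import dill  # third-party; optional for this sketch
except ImportError:
    dill = None

# A lambda stands in for a callable that the standard pickler cannot serialize
# (assigned to a name here only so it can be passed to both picklers).
scale_fn = lambda x: x / 10.0

try:
    pickle.dumps(scale_fn)
    pickle_ok = True
except (pickle.PicklingError, AttributeError):
    pickle_ok = False

if dill is not None:
    # dill serializes the function body itself, so the round trip succeeds.
    restored = dill.loads(dill.dumps(scale_fn))
    roundtrip_value = restored(50.0)
else:
    roundtrip_value = None
```

A file written by `save` should be read back with `load` rather than with plain pickle, since the two serializers are not guaranteed to be interchangeable for such objects.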