src.core.mmm_simplified

Simplified marketing mix model (MMM) with a cleaner architecture.

Reduces the complexity of the inheritance hierarchy by favoring composition and focused interfaces.
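The composition idea can be sketched in a few lines. This is a minimal illustration, not the module's code. `Validator`, `Scaler`, and `CompositeModel` are hypothetical stand-ins for `ModelValidator`, `DataProcessor`, and `SimplifiedMMM`, and the max-abs scaling is an assumption.

```python
from typing import List, Optional

import pandas as pd


class Validator:
    """Hypothetical stand-in for ModelValidator: one focused responsibility."""

    @staticmethod
    def validate(X: pd.DataFrame, channels: List[str]) -> None:
        missing = [c for c in channels if c not in X.columns]
        if missing:
            raise ValueError(f"missing channel columns: {missing}")


class Scaler:
    """Hypothetical stand-in for DataProcessor (assumed max-abs scaling)."""

    def fit_transform(self, X: pd.DataFrame, channels: List[str]) -> pd.DataFrame:
        # One scale per channel; all-zero columns get scale 1 to avoid division by zero.
        self.scales_ = X[channels].abs().max().replace(0, 1.0)
        return X[channels] / self.scales_


class CompositeModel:
    """The model *has* a validator and a processor instead of inheriting from them."""

    def __init__(self, channels: List[str],
                 validator: Optional[Validator] = None,
                 processor: Optional[Scaler] = None) -> None:
        self.channels = channels
        self.validator = validator or Validator()
        self.processor = processor or Scaler()

    def prepare(self, X: pd.DataFrame) -> pd.DataFrame:
        # Each step is delegated to a collaborator that can be swapped or tested alone.
        self.validator.validate(X, self.channels)
        return self.processor.fit_transform(X, self.channels)


X = pd.DataFrame({"tv": [10.0, 20.0, 40.0], "radio": [1.0, 2.0, 4.0]})
print(CompositeModel(["tv", "radio"]).prepare(X))
```

Because the collaborators are passed in rather than inherited, a test can substitute a stub validator or a different scaler without subclassing the model.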

Module Contents

class src.core.mmm_simplified.DataProcessor

Handles data preprocessing and validation for MMM models.

fit_transform(X: pandas.DataFrame, y: pandas.Series, channel_columns: List[str]) → Tuple[pandas.DataFrame, pandas.Series]

Fit scalers and transform data.

transform(X: pandas.DataFrame, channel_columns: List[str]) → pandas.DataFrame

Transform new data using fitted scalers.

inverse_transform_target(y_transformed: numpy.ndarray) → numpy.ndarray

Inverse transform target variable back to original scale.
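A sketch of the three DataProcessor methods, assuming max-abs scaling of the channel columns and the target (the scalers the real class fits are not documented here). `MaxAbsDataProcessor` is a hypothetical name. The key point is that `transform` reuses the scales learned in `fit_transform`, so new data is expressed on the training scale.

```python
from typing import List, Tuple

import numpy as np
import pandas as pd


class MaxAbsDataProcessor:
    """Sketch of the DataProcessor interface, ASSUMING max-abs scaling.

    Only channel columns and the target are scaled; the real class may use
    different scalers or also scale control columns.
    """

    def fit_transform(self, X: pd.DataFrame, y: pd.Series,
                      channel_columns: List[str]) -> Tuple[pd.DataFrame, pd.Series]:
        # Learn one scale per channel and one for the target from training data.
        self.channel_scale_ = X[channel_columns].abs().max().replace(0, 1.0)
        target_max = float(np.abs(y).max())
        self.target_scale_ = target_max if target_max > 0 else 1.0
        return self.transform(X, channel_columns), y / self.target_scale_

    def transform(self, X: pd.DataFrame, channel_columns: List[str]) -> pd.DataFrame:
        # Reuse the fitted scales; never refit on new data.
        if not hasattr(self, "channel_scale_"):
            raise RuntimeError("call fit_transform before transform")
        X_out = X.copy()
        for col in channel_columns:
            X_out[col] = X[col] / self.channel_scale_[col]
        return X_out

    def inverse_transform_target(self, y_transformed: np.ndarray) -> np.ndarray:
        # Undo the target scaling, e.g. for predictions made on the model scale.
        return np.asarray(y_transformed) * self.target_scale_


X = pd.DataFrame({"date": pd.date_range("2024-01-01", periods=3, freq="W"),
                  "tv": [100.0, 50.0, 200.0], "search": [10.0, 20.0, 5.0]})
y = pd.Series([1000.0, 800.0, 1500.0])
proc = MaxAbsDataProcessor()
X_s, y_s = proc.fit_transform(X, y, ["tv", "search"])  # tv becomes [0.5, 0.25, 1.0]
X_new = proc.transform(pd.DataFrame({"tv": [400.0], "search": [20.0]}), ["tv", "search"])
```

New spend above the training maximum (here 400 against a maximum of 200) maps to values above 1.0, which is expected when scales are not refit.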

class src.core.mmm_simplified.ModelValidator

Validates model inputs and configuration.

static validate_data(X: pandas.DataFrame, y: pandas.Series, channel_columns: List[str], control_columns: List[str] | None = None, date_column: str = 'date') → None

Validate input data for model fitting.
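The checks `validate_data` performs are not listed here. The function below is a hedged sketch of typical MMM input checks with the same signature; which checks the real method runs, and whether it raises `ValueError`, are assumptions.

```python
from typing import List, Optional

import pandas as pd


def validate_data(X: pd.DataFrame, y: pd.Series, channel_columns: List[str],
                  control_columns: Optional[List[str]] = None,
                  date_column: str = "date") -> None:
    """Raise ValueError if inputs fail basic MMM sanity checks (assumed checks)."""
    control_columns = control_columns or []
    required = [date_column, *channel_columns, *control_columns]
    missing = [c for c in required if c not in X.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    if len(X) != len(y):
        raise ValueError(f"X has {len(X)} rows but y has {len(y)}")
    if X[required].isna().any().any() or y.isna().any():
        raise ValueError("inputs contain missing values")
    # Media spend or impressions are assumed to be non-negative.
    negative = [c for c in channel_columns if (X[c] < 0).any()]
    if negative:
        raise ValueError(f"negative values in channels: {negative}")
    if not pd.api.types.is_datetime64_any_dtype(X[date_column]):
        raise ValueError(f"{date_column!r} must have a datetime dtype")
```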

static validate_model_config(config: Dict[str, Any]) → None

Validate model configuration.

class src.core.mmm_simplified.SimplifiedMMM(date_column: str, channel_columns: List[str], control_columns: List[str] | None = None, adstock_max_lag: int = 4, model_config: Dict[str, Any] | None = None, sampler_config: Dict[str, Any] | None = None, **kwargs)

Bases: src.core.base.BaseMMM

Simplified MMM implementation with cleaner architecture.

Uses composition instead of a complex inheritance hierarchy, which makes the model easier to maintain.
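The constructor's `adstock_max_lag` parameter suggests a truncated carryover (adstock) transform applied to each channel. A minimal sketch, assuming geometric decay with unnormalized weights alpha**0 through alpha**max_lag; the decay form, whether lag 0 counts toward the limit, and the name `geometric_adstock` are assumptions.

```python
import numpy as np


def geometric_adstock(x, alpha: float, max_lag: int = 4) -> np.ndarray:
    """Carry spend forward with geometric decay, truncated after `max_lag` periods.

    ASSUMPTION: weights alpha**0 ... alpha**max_lag, not normalized.
    """
    x = np.asarray(x, dtype=float)
    weights = alpha ** np.arange(max_lag + 1)
    out = np.zeros_like(x)
    for lag, w in enumerate(weights):
        if lag >= len(x):
            break  # series shorter than the lag window
        # Spend from `lag` periods ago contributes w times its value today.
        out[lag:] += w * x[: len(x) - lag]
    return out


print(geometric_adstock([100, 0, 0, 0, 0, 0], alpha=0.5, max_lag=4))
```

A single burst of spend decays by half each period and is cut off after four lags, so the sixth period receives nothing.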

fit(X: pandas.DataFrame, y: pandas.Series, chains: int = 4, draws: int = 1000, tune: int = 1000, target_accept: float = 0.95, random_seed: int | None = None, **kwargs)

Fit the MMM model to data.

Parameters:
  • X – Feature DataFrame with media channels and controls.

  • y – Target variable Series.

  • chains – Number of MCMC chains.

  • draws – Number of samples to draw.

  • tune – Number of tuning steps.

  • target_accept – Target acceptance rate.

  • random_seed – Random seed for reproducibility.

  • **kwargs – Additional sampling arguments.

predict(X: pandas.DataFrame) → numpy.ndarray

Generate predictions for new data.

get_channel_contributions(X: pandas.DataFrame | None = None) → numpy.ndarray

Calculate channel contributions.
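How contributions are computed is not documented here. Assuming a linear media term, each channel's contribution at time t is its coefficient times its transformed media value, averaged over posterior draws. `channel_contributions` and the array shapes below are hypothetical.

```python
import numpy as np


def channel_contributions(media: np.ndarray, beta_draws: np.ndarray) -> np.ndarray:
    """Posterior-mean contribution of each channel at each time step.

    media:      (n_time, n_channels) transformed (e.g. adstocked, scaled) media
    beta_draws: (n_draws, n_channels) posterior samples of channel coefficients
    returns:    (n_time, n_channels)
    """
    # Broadcast to (n_draws, n_time, n_channels), then average over draws.
    per_draw = beta_draws[:, None, :] * media[None, :, :]
    return per_draw.mean(axis=0)


media = np.array([[1.0, 0.0], [0.5, 1.0]])
beta_draws = np.array([[2.0, 1.0], [4.0, 3.0]])
contrib = channel_contributions(media, beta_draws)
```

Because the term is linear, this equals multiplying by the posterior-mean coefficients; with a nonlinear saturation term, the order of averaging would matter.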

save(filepath: str) → None

Save fitted model to file.

classmethod load(filepath: str, **init_kwargs) → SimplifiedMMM

Load saved model from file.
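A round-trip sketch of the `save`/`load` pattern. It assumes the model pickles its constructor arguments together with its fitted state, and that `load` forwards `**init_kwargs` to the constructor as overrides. `TinyModel` is hypothetical and stands in for `SimplifiedMMM`.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, List, Optional


class TinyModel:
    """Hypothetical model used only to illustrate a save/load round trip."""

    def __init__(self, date_column: str, channel_columns: List[str],
                 adstock_max_lag: int = 4) -> None:
        self.date_column = date_column
        self.channel_columns = channel_columns
        self.adstock_max_lag = adstock_max_lag
        self.fitted_params_: Optional[Dict[str, Any]] = None

    def save(self, filepath: str) -> None:
        if self.fitted_params_ is None:
            raise RuntimeError("fit the model before saving")
        state = {
            "init": {"date_column": self.date_column,
                     "channel_columns": self.channel_columns,
                     "adstock_max_lag": self.adstock_max_lag},
            "fitted": self.fitted_params_,
        }
        with open(filepath, "wb") as fh:
            pickle.dump(state, fh)

    @classmethod
    def load(cls, filepath: str, **init_kwargs: Any) -> "TinyModel":
        with open(filepath, "rb") as fh:
            state = pickle.load(fh)
        # ASSUMPTION: explicit keyword arguments override saved constructor arguments.
        model = cls(**{**state["init"], **init_kwargs})
        model.fitted_params_ = state["fitted"]
        return model


m = TinyModel("date", ["tv", "search"])
m.fitted_params_ = {"beta": [0.3, 0.1]}  # stands in for a real fit
path = os.path.join(tempfile.mkdtemp(), "mmm.pkl")
m.save(path)
restored = TinyModel.load(path, adstock_max_lag=8)
```

Python's `pickle` can execute arbitrary code while loading, so only load model files from trusted sources.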