Skip to content

multidimensional MMM initialization validation #1725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 53 additions & 15 deletions pymc_marketing/mmm/multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Literal
from typing import Annotated, Any, Literal

Check warning on line 22 in pymc_marketing/mmm/multidimensional.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/multidimensional.py#L22

Added line #L22 was not covered by tests

import arviz as az
import numpy as np
Expand All @@ -28,6 +28,7 @@
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

Check warning on line 31 in pymc_marketing/mmm/multidimensional.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/multidimensional.py#L31

Added line #L31 was not covered by tests
from pymc.model.fgraph import clone_model as cm
from pymc.util import RandomState
from scipy.optimize import OptimizeResult
Expand Down Expand Up @@ -108,22 +109,59 @@
_model_type: str = "MMMM (Multi-Dimensional Marketing Mix Model)"
version: str = "0.0.1"

@validate_call

Check warning on line 112 in pymc_marketing/mmm/multidimensional.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/multidimensional.py#L112

Added line #L112 was not covered by tests
def __init__(
self,
date_column: str,
channel_columns: list[str],
target_column: str,
adstock: AdstockTransformation,
saturation: SaturationTransformation,
time_varying_intercept: bool = False,
time_varying_media: bool = False,
dims: tuple | None = None,
scaling: Scaling | dict | None = None,
model_config: dict | None = None, # Ensure model_config is a dictionary
sampler_config: dict | None = None,
control_columns: list[str] | None = None,
yearly_seasonality: int | None = None,
adstock_first: bool = True,
date_column: str = Field(..., description="Column name of the date variable."),
channel_columns: list[str] = Field(
min_length=1, description="Column names of the media channel variables."
),
target_column: str = Field(..., description="The name of the target column."),
adstock: InstanceOf[AdstockTransformation] = Field(
..., description="Type of adstock transformation to apply."
),
saturation: InstanceOf[SaturationTransformation] = Field(
...,
description="The saturation transformation to apply to the channel data.",
),
time_varying_intercept: Annotated[
bool,
Field(strict=True, description="Whether to use a time-varying intercept"),
] = False,
time_varying_media: Annotated[
bool,
Field(strict=True, description="Whether to use time-varying media effects"),
] = False,
dims: tuple[str, ...] | None = Field(
None, description="Additional dimensions for the model."
),
scaling: InstanceOf[Scaling] | dict | None = Field(
None, description="Scaling configuration for the model."
),
model_config: dict | None = Field(
None, description="Configuration settings for the model."
),
sampler_config: dict | None = Field(
None, description="Configuration settings for the sampler."
),
control_columns: Annotated[
list[str] | None,
Field(
min_length=1,
description="A list of control variables to include in the model.",
),
] = None,
yearly_seasonality: Annotated[
int | None,
Field(
gt=0,
description="The number of yearly seasonalities to include in the model.",
),
] = None,
adstock_first: Annotated[
bool,
Field(strict=True, description="Apply adstock before saturation?"),
] = True,
) -> None:
"""Define the constructor method."""
# Your existing initialization logic
Expand Down
253 changes: 253 additions & 0 deletions tests/mmm/test_multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pymc as pm
import pytest
import xarray as xr
from pydantic import ValidationError
from pymc.model_graph import fast_eval
from pytensor.tensor.basic import TensorVariable
from scipy.optimize import OptimizeResult
Expand Down Expand Up @@ -792,3 +793,255 @@ def test_multidimensional_budget_optimizer_wrapper(fit_mmm, mock_pymc_sample):
len(channels),
) # Check shape based on dims
assert isinstance(scipy_opt_result, OptimizeResult)


class TestPydanticValidation:
    """Pydantic validation checks for the multidimensional ``MMM`` constructor."""

    @staticmethod
    def _build(*, omit=(), **overrides):
        """Instantiate ``MMM`` from a baseline set of required arguments.

        Parameters
        ----------
        omit
            Names of baseline arguments to leave out of the call entirely.
        **overrides
            Keyword arguments that replace, or are added to, the baseline.

        Baseline values are produced lazily, so a baseline argument that is
        overridden or omitted is never constructed.
        """
        factories = {
            "date_column": lambda: "date",
            "channel_columns": lambda: ["channel_1"],
            "target_column": lambda: "target",
            "adstock": lambda: GeometricAdstock(l_max=8),
            "saturation": lambda: LogisticSaturation(),
        }
        arguments = {
            name: factory()
            for name, factory in factories.items()
            if name not in omit and name not in overrides
        }
        arguments.update(overrides)
        return MMM(**arguments)

    def test_empty_channel_columns_raises_validation_error(self):
        """An empty ``channel_columns`` list is rejected."""
        with pytest.raises(ValidationError) as caught:
            self._build(channel_columns=[])

        # The message should point at the minimum-length constraint.
        message = str(caught.value)
        assert any(hint in message for hint in ("at least 1 item", "min_length"))

    def test_invalid_yearly_seasonality_raises_validation_error(self):
        """``yearly_seasonality=0`` is rejected because the value must be > 0."""
        with pytest.raises(ValidationError) as caught:
            self._build(yearly_seasonality=0)

        message = str(caught.value)
        assert "greater than 0" in message

    def test_negative_yearly_seasonality_raises_validation_error(self):
        """A negative ``yearly_seasonality`` is rejected."""
        with pytest.raises(ValidationError) as caught:
            self._build(yearly_seasonality=-1)

        message = str(caught.value)
        assert "greater than 0" in message

    def test_invalid_adstock_type_raises_validation_error(self):
        """A plain string passed as ``adstock`` is rejected."""
        with pytest.raises(ValidationError) as caught:
            self._build(adstock="not_an_adstock")

        message = str(caught.value)
        assert "AdstockTransformation" in message

    def test_invalid_saturation_type_raises_validation_error(self):
        """A plain string passed as ``saturation`` is rejected."""
        with pytest.raises(ValidationError) as caught:
            self._build(saturation="not_a_saturation")

        message = str(caught.value)
        assert "SaturationTransformation" in message

    def test_empty_control_columns_raises_validation_error(self):
        """An empty ``control_columns`` list is rejected (``None`` is the default)."""
        with pytest.raises(ValidationError) as caught:
            self._build(control_columns=[])

        message = str(caught.value)
        assert any(hint in message for hint in ("at least 1 item", "min_length"))

    def test_invalid_scaling_type_raises_validation_error(self):
        """A string passed as ``scaling`` is rejected."""
        with pytest.raises(ValidationError) as caught:
            self._build(scaling="invalid_scaling")

        message = str(caught.value)
        assert any(hint in message for hint in ("Scaling", "dict"))

    def test_valid_scaling_dict_accepted(self):
        """A scaling dict is accepted and exposed as a ``Scaling`` instance."""
        scaling_spec = {
            variable: {"method": "max", "dims": ()}
            for variable in ("channel", "target")
        }
        model = self._build(scaling=scaling_spec)

        assert isinstance(model.scaling, Scaling)
        assert model.scaling.model_dump() == scaling_spec

    def test_valid_scaling_object_accepted(self):
        """A ``Scaling`` instance is accepted as given."""
        target_scaling = VariableScaling(method="max", dims=())
        channel_scaling = VariableScaling(method="max", dims=())
        scaling = Scaling(target=target_scaling, channel=channel_scaling)

        model = self._build(scaling=scaling)

        assert model.scaling == scaling

    def test_dims_type_validation(self):
        """Tuples of strings are accepted for ``dims`` and stored unchanged."""
        # Cover both a multi-dimension tuple and a single-dimension tuple.
        for requested_dims in (("country", "product"), ("country",)):
            model = self._build(dims=requested_dims)
            assert model.dims == requested_dims

    def test_invalid_boolean_types_raise_validation_error(self):
        """A non-boolean ``time_varying_intercept`` is rejected."""
        with pytest.raises(ValidationError):
            self._build(time_varying_intercept="yes")

    def test_missing_required_fields_raise_validation_error(self):
        """Leaving out a required argument is reported by name."""
        for absent in ("date_column", "channel_columns"):
            with pytest.raises(ValidationError) as caught:
                self._build(omit=(absent,))
            message = str(caught.value)
            assert absent in message

    def test_all_parameters_with_valid_values(self):
        """Every constructor argument can be supplied with a valid value."""
        model = MMM(
            date_column="date",
            channel_columns=["channel_1", "channel_2", "channel_3"],
            target_column="revenue",
            adstock=GeometricAdstock(l_max=10),
            saturation=LogisticSaturation(),
            time_varying_intercept=True,
            time_varying_media=True,
            dims=("country", "product"),
            scaling=Scaling(
                target=VariableScaling(method="mean", dims=("country",)),
                channel=VariableScaling(method="max", dims=("country", "channel")),
            ),
            model_config={"intercept": Prior("Normal", mu=0, sigma=2)},
            sampler_config={"draws": 1000, "chains": 4},
            control_columns=["holiday", "promotion"],
            yearly_seasonality=4,
            adstock_first=False,
        )

        # Attributes compared by equality against the values passed above.
        expected_by_equality = {
            "date_column": "date",
            "channel_columns": ["channel_1", "channel_2", "channel_3"],
            "target_column": "revenue",
            "dims": ("country", "product"),
            "control_columns": ["holiday", "promotion"],
            "yearly_seasonality": 4,
        }
        for attribute, expected in expected_by_equality.items():
            assert getattr(model, attribute) == expected

        # Transformation and scaling objects are checked by type only.
        assert isinstance(model.adstock, GeometricAdstock)
        assert isinstance(model.saturation, LogisticSaturation)
        assert isinstance(model.scaling, Scaling)

        # Boolean flags must be the actual bool singletons.
        assert model.time_varying_intercept is True
        assert model.time_varying_media is True
        assert model.adstock_first is False

    def test_validation_error_provides_helpful_messages(self):
        """The error for a non-list ``channel_columns`` names the field and type."""
        with pytest.raises(ValidationError) as caught:
            self._build(channel_columns="not_a_list")

        # The message should say that channel_columns is expected to be a list.
        message = str(caught.value)
        assert "channel_columns" in message
        assert "list" in message.lower()