Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pymc_marketing/mmm/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from matplotlib.figure import Figure
from pydantic import BaseModel, Field, InstanceOf, model_validator, validate_call
from pymc.distributions.shape_utils import Dims
from pymc_extras.deserialize import register_deserialization
from pymc_extras.prior import Prior, _get_transform, create_dim_handler
from pytensor.tensor import TensorLike
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -1432,3 +1433,31 @@ def create_variable(self, name: str) -> TensorVariable:
# Multiplicative centering to preserve positivity and enforce mean 1
centered_f = f / f_mean
return pm.Deterministic(name, centered_f, dims=self.dims)


# TODO: Replace this with a more robust implementation
def hsgp_from_dict(data: dict | bool):
"""Get an HSGP instance from a dictionary if passed by user."""
if isinstance(data, bool):
return data

HSGP_CLASSES = {
"HSGP": HSGP,
"SoftPlusHSGP": SoftPlusHSGP,
"HSGPPeriodic": HSGPPeriodic,
}

data = data.copy()
cls = HSGP_CLASSES[data.pop("hsgp_class")]

return cls.from_dict(data)


def _is_hsgp(data):
return (
"hsgp_class" in data
and data["hsgp_class"] in ["HSGP", "SoftPlusHSGP", "HSGPPeriodic"]
) or isinstance(data, bool)


# Hook the HSGP helpers into the deserialization registry imported from
# pymc_extras: `_is_hsgp` is passed first and `hsgp_from_dict` second
# (presumably the matcher and the loader, respectively -- confirm against
# the pymc_extras API).
# NOTE(review): `_is_hsgp` also accepts plain booleans, so this registration
# claims bare bool values as well -- confirm that is intended.
register_deserialization(_is_hsgp, hsgp_from_dict)
28 changes: 22 additions & 6 deletions pymc_marketing/mmm/multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
)
from pymc_marketing.mmm.events import EventEffect
from pymc_marketing.mmm.fourier import YearlyFourier
from pymc_marketing.mmm.hsgp import HSGPBase
from pymc_marketing.mmm.hsgp import HSGPBase, hsgp_from_dict
from pymc_marketing.mmm.lift_test import (
add_cost_per_target_potentials,
add_lift_measurements_to_likelihood_from_saturation,
Expand Down Expand Up @@ -527,8 +527,22 @@ def create_idata_attrs(self) -> dict[str, str]:
attrs["control_columns"] = json.dumps(self.control_columns)
attrs["channel_columns"] = json.dumps(self.channel_columns)
attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
attrs["time_varying_media"] = json.dumps(self.time_varying_media)
attrs["time_varying_intercept"] = json.dumps(
self.time_varying_intercept
if not isinstance(self.time_varying_intercept, HSGPBase)
else {
**self.time_varying_intercept.to_dict(),
**{"hsgp_class": self.time_varying_intercept.__class__.__name__},
}
)
attrs["time_varying_media"] = json.dumps(
self.time_varying_media
if not isinstance(self.time_varying_media, HSGPBase)
else {
**self.time_varying_media.to_dict(),
**{"hsgp_class": self.time_varying_media.__class__.__name__},
}
)
attrs["target_column"] = self.target_column
attrs["scaling"] = json.dumps(self.scaling.model_dump(mode="json"))
attrs["dag"] = json.dumps(getattr(self, "dag", None))
Expand Down Expand Up @@ -586,11 +600,13 @@ def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]:
"saturation": saturation_from_dict(json.loads(attrs["saturation"])),
"adstock_first": json.loads(attrs.get("adstock_first", "true")),
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
"time_varying_intercept": json.loads(
attrs.get("time_varying_intercept", "false")
"time_varying_intercept": hsgp_from_dict(
json.loads(attrs.get("time_varying_intercept", "false"))
),
"target_column": attrs["target_column"],
"time_varying_media": json.loads(attrs.get("time_varying_media", "false")),
"time_varying_media": hsgp_from_dict(
json.loads(attrs.get("time_varying_media", "false"))
),
"sampler_config": json.loads(attrs["sampler_config"]),
"dims": tuple(json.loads(attrs.get("dims", "[]"))),
"scaling": json.loads(attrs.get("scaling", "null")),
Expand Down
184 changes: 183 additions & 1 deletion tests/mmm/test_multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Callable

import arviz as az
Expand All @@ -25,7 +26,12 @@
from pytensor.tensor.basic import TensorVariable
from scipy.optimize import OptimizeResult

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation, SoftPlusHSGP
from pymc_marketing.mmm import (
CovFunc,
GeometricAdstock,
LogisticSaturation,
SoftPlusHSGP,
)
from pymc_marketing.mmm.additive_effect import EventAdditiveEffect, LinearTrendEffect
from pymc_marketing.mmm.events import EventEffect, GaussianBasis, HalfGaussianBasis
from pymc_marketing.mmm.lift_test import _swap_columns_and_last_index_level
Expand Down Expand Up @@ -737,6 +743,182 @@ def test_time_varying_intercept_with_custom_hsgp_multi_dim(
assert latent_dims == hsgp_dims


@pytest.mark.parametrize(
    "hsgp_dims",
    [
        pytest.param(["date"], id="hsgp-dims=date"),
        pytest.param(["date", "channel"], id="hsgp-dims=date,channel"),
    ],
)
def test_time_varying_media_with_custom_hsgp_single_dim_save_load(
    single_dim_data, hsgp_dims, tmp_path
):
    """Round-trip an MMM whose ``time_varying_media`` is an HSGP instance.

    The model is fitted, saved and loaded again; the reloaded model's
    ``time_varying_media.to_dict()`` must equal the configuration the HSGP
    was built from (single-dim data).
    """
    X, y = single_dim_data

    data = {
        "m": 72,
        "X_mid": 6.5,
        "dims": hsgp_dims,
        "transform": None,
        "demeaned_basis": False,
        "ls": {
            "dist": "Weibull",
            "kwargs": {"alpha": 0.5, "beta": 90.08328710020781},
            "transform": "reciprocal",
        },
        "eta": {"dist": "Exponential", "kwargs": {"lam": 2.995732273553991}},
        "L": 41.6,
        "centered": False,
        "drop_first": True,
        "cov_func": CovFunc.ExpQuad,
    }

    # .from_dict() modifies its argument, so hand it a copy and keep `data`
    # pristine for the final comparison.
    hsgp = SoftPlusHSGP.from_dict(data.copy())

    mmm = MMM(
        date_column="date",
        target_column="target",
        channel_columns=["channel_1", "channel_2", "channel_3"],
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        time_varying_media=hsgp,
    )

    mmm.fit(X, y)

    # Write into pytest's per-test temporary directory: the artifact is
    # cleaned up even when an assertion fails, and parametrized cases never
    # share (or clobber) the same path in the working directory.
    file = os.fspath(tmp_path / "test_hsgp_media.nc")
    mmm.save(file)
    loaded = MMM.load(file)

    assert loaded.time_varying_media.to_dict() == data


@pytest.mark.parametrize(
    "hsgp_dims",
    [
        pytest.param(["date"], id="hsgp-dims=date"),
    ],
)
def test_time_varying_intercept_with_custom_hsgp_single_dim_save_load(
    single_dim_data, hsgp_dims, tmp_path
):
    """Round-trip an MMM whose ``time_varying_intercept`` is an HSGP instance.

    The model is fitted, saved and loaded again; the reloaded model's
    ``time_varying_intercept.to_dict()`` must equal the configuration the
    HSGP was built from (single-dim data).
    """
    X, y = single_dim_data

    data = {
        "m": 72,
        "X_mid": 6.5,
        "dims": hsgp_dims,
        "transform": None,
        "demeaned_basis": False,
        "ls": {
            "dist": "Weibull",
            "kwargs": {"alpha": 0.5, "beta": 90.08328710020781},
            "transform": "reciprocal",
        },
        "eta": {"dist": "Exponential", "kwargs": {"lam": 2.995732273553991}},
        "L": 41.6,
        "centered": False,
        "drop_first": True,
        "cov_func": CovFunc.ExpQuad,
    }

    # .from_dict() modifies its argument, so hand it a copy and keep `data`
    # pristine for the final comparison.
    hsgp = SoftPlusHSGP.from_dict(data.copy())

    mmm = MMM(
        date_column="date",
        target_column="target",
        channel_columns=["channel_1", "channel_2", "channel_3"],
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        time_varying_intercept=hsgp,
    )

    mmm.fit(X, y)

    # Write into pytest's per-test temporary directory so the artifact is
    # cleaned up even when an assertion fails and the working directory is
    # never polluted.
    file = os.fspath(tmp_path / "test_hsgp_intercept.nc")
    mmm.save(file)
    loaded = MMM.load(file)

    assert loaded.time_varying_intercept.to_dict() == data


@pytest.mark.parametrize(
    "hsgp_dims",
    [
        pytest.param(["date", "country"], id="hsgp-dims=date,country"),
        pytest.param(
            ["date", "country", "channel"], id="hsgp-dims=date,country,channel"
        ),
    ],
)
def test_time_varying_media_with_custom_hsgp_multi_dim_save_load(
    df, target_column, hsgp_dims, tmp_path
):
    """Round-trip a multi-dim MMM whose ``time_varying_media`` is an HSGP.

    The model is fitted, saved and loaded again; the reloaded model's
    ``time_varying_media.to_dict()`` must equal the configuration the HSGP
    was built from.
    """
    X = df.drop(columns=[target_column])
    y = df[target_column]

    data = {
        "m": 28,
        "X_mid": 2.5,
        "dims": hsgp_dims,
        "transform": None,
        "demeaned_basis": False,
        "ls": {
            "dist": "Weibull",
            "kwargs": {"alpha": 0.5, "beta": 90.08328710020781},
            "transform": "reciprocal",
        },
        "eta": {"dist": "Exponential", "kwargs": {"lam": 2.995732273553991}},
        "L": 16.0,
        "centered": False,
        "drop_first": True,
        "cov_func": CovFunc.ExpQuad,
    }

    # .from_dict() modifies its argument, so hand it a copy and keep `data`
    # pristine for the final comparison.
    hsgp = SoftPlusHSGP.from_dict(data.copy())

    mmm = MMM(
        date_column="date",
        channel_columns=["C1", "C2"],
        target_column=target_column,
        dims=("country",),
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        time_varying_media=hsgp,
    )
    mmm.fit(X, y)

    # Write into pytest's per-test temporary directory so the artifact is
    # cleaned up even when an assertion fails. The file is named after the
    # media effect under test (it was previously labelled "intercept").
    file = os.fspath(tmp_path / "test_hsgp_media_multi_dim.nc")
    mmm.save(file)
    loaded = MMM.load(file)

    assert loaded.time_varying_media.to_dict() == data


def test_sample_posterior_predictive_no_overlap_with_include_last_observations(
single_dim_data, mock_pymc_sample
):
Expand Down