Closed
Description
An MMM configured with adstock and saturation transformation objects may not be saved correctly, because the additional keyword arguments of those objects are not stored.
It would be a good idea to make these classes serializable and to use that in the `save` and `load` methods.
For instance, the following model would not be loaded correctly:
```python
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation

mmm = MMM(
    ...,
    adstock=GeometricAdstock(l_max=4, normalize=False, mode="Before"),
    saturation=LogisticSaturation(),
    ...,
)
```

This is because neither `normalize` nor `mode` is saved.
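To make the failure concrete, here is a minimal sketch of the round trip that does not import pymc_marketing. `GeometricAdstockSketch`, `save_attrs`, `load_adstock`, and `ADSTOCK_REGISTRY` are hypothetical stand-ins; the defaults `normalize=True` and `mode="After"` are assumptions meant to mirror the library's. Like the `adstock` and `adstock_max_lag` attributes read by the loading code quoted below, only the name and the max lag are stored.

```python
import json
from dataclasses import dataclass


# Hypothetical stand-in for GeometricAdstock; defaults are assumed, not copied.
@dataclass
class GeometricAdstockSketch:
    l_max: int
    normalize: bool = True
    mode: str = "After"
    lookup_name = "geometric"  # class attribute, not a dataclass field


ADSTOCK_REGISTRY = {"geometric": GeometricAdstockSketch}


def save_attrs(adstock: GeometricAdstockSketch) -> dict[str, str]:
    # Only the name and the max lag are stored; normalize and mode are dropped.
    return {
        "adstock": json.dumps(adstock.lookup_name),
        "adstock_max_lag": json.dumps(adstock.l_max),
    }


def load_adstock(attrs: dict[str, str]) -> GeometricAdstockSketch:
    # Rebuilding from the stored name falls back to the constructor defaults.
    cls = ADSTOCK_REGISTRY[json.loads(attrs["adstock"])]
    return cls(l_max=json.loads(attrs["adstock_max_lag"]))


original = GeometricAdstockSketch(l_max=4, normalize=False, mode="Before")
restored = load_adstock(save_attrs(original))
print(original == restored)  # False: normalize and mode reverted to defaults
```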
Fixing this would involve reconstructing the adstock and saturation objects on load, in `pymc_marketing/mmm/delayed_saturated_mmm.py` (pymc-marketing, lines 637 to 659 at commit 2be2664):
```python
model = cls(
    date_column=json.loads(idata.attrs["date_column"]),
    control_columns=json.loads(idata.attrs["control_columns"]),
    # Media Transformations
    channel_columns=json.loads(idata.attrs["channel_columns"]),
    adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
    adstock=json.loads(idata.attrs.get("adstock", "geometric")),
    saturation=json.loads(idata.attrs.get("saturation", "logistic")),
    adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
    # Seasonality
    yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
    # TVP
    time_varying_intercept=json.loads(
        idata.attrs.get("time_varying_intercept", False)
    ),
    time_varying_media=json.loads(
        idata.attrs.get("time_varying_media", False)
    ),
    # Configurations
    validate_data=json.loads(idata.attrs["validate_data"]),
    model_config=model_config,
    sampler_config=json.loads(idata.attrs["sampler_config"]),
)
```
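One possible shape for the fix, sketched under assumptions: each transformation exposes a `to_dict` that records its lookup name plus every constructor kwarg, and a `from_dict` helper picks the class from a registry and passes the kwargs back. The classes, `REGISTRY`, and both methods below are hypothetical stand-ins rather than the library's API, and the real classes may also carry priors that would need serializing too.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any


# Hypothetical stand-ins for GeometricAdstock and LogisticSaturation.
@dataclass
class GeometricAdstockSketch:
    l_max: int
    normalize: bool = True
    mode: str = "After"
    lookup_name = "geometric"  # class attribute, not a dataclass field

    def to_dict(self) -> dict[str, Any]:
        # Store the lookup name together with every constructor kwarg.
        return {"lookup_name": self.lookup_name, **asdict(self)}


@dataclass
class LogisticSaturationSketch:
    lookup_name = "logistic"

    def to_dict(self) -> dict[str, Any]:
        return {"lookup_name": self.lookup_name, **asdict(self)}


REGISTRY = {
    cls.lookup_name: cls
    for cls in (GeometricAdstockSketch, LogisticSaturationSketch)
}


def from_dict(data: dict[str, Any]) -> Any:
    # Choose the class by name and pass the remaining keys back as kwargs.
    kwargs = dict(data)
    cls = REGISTRY[kwargs.pop("lookup_name")]
    return cls(**kwargs)


# Saving: serialize the full configuration into the attrs mapping.
adstock = GeometricAdstockSketch(l_max=4, normalize=False, mode="Before")
saturation = LogisticSaturationSketch()
attrs = {
    "adstock": json.dumps(adstock.to_dict()),
    "saturation": json.dumps(saturation.to_dict()),
}

# Loading: rebuild both objects before calling the model constructor.
restored_adstock = from_dict(json.loads(attrs["adstock"]))
restored_saturation = from_dict(json.loads(attrs["saturation"]))
print(restored_adstock == adstock)  # True
```

Files saved before such a change store only a bare name string under `adstock` and `saturation`, so a real loader would likely still need to accept that older format.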