Skip to content

Commit c9670ab

Browse files
williambdeantwiecki
authored andcommitted
DelayedSaturatedMMM deprecations and moving files (#965)
* deprecations and moving files * Update UML Diagrams * change the imports in notebooks * push up the code / test changes. need to run * remove _get_\w*_function tests * rerun the tvp notebook * remove stale test * move away from string initialization * change the tvp media example
1 parent a3caa4b commit c9670ab

File tree

16 files changed

+393
-740
lines changed

16 files changed

+393
-740
lines changed

docs/source/notebooks/mmm/mmm_budget_allocation_example.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"import numpy as np\n",
6060
"import pandas as pd\n",
6161
"\n",
62-
"from pymc_marketing.mmm.delayed_saturated_mmm import MMM\n",
62+
"from pymc_marketing.mmm import MMM\n",
6363
"\n",
6464
"warnings.filterwarnings(\"ignore\")\n",
6565
"\n",
@@ -89,7 +89,7 @@
8989
"Once the model has been trained, it is easy to save for later use. An example of the \".save\" method is demonstrated below to store the model at a designated [location](https://github.com/pymc-labs/pymc-marketing/tree/main/data).\n",
9090
"\n",
9191
"## Loading a Pre-Trained Model\n",
92-
"To utilize a saved model, load it into a new instance of the DelayedSaturatedMMM class using the load method below."
92+
"To utilize a saved model, load it into a new instance of the MMM class using the load method below."
9393
]
9494
},
9595
{
@@ -1738,7 +1738,8 @@
17381738
"provenance": []
17391739
},
17401740
"kernelspec": {
1741-
"display_name": "Python 3",
1741+
"display_name": "Python 3 (ipykernel)",
1742+
"language": "python",
17421743
"name": "python3"
17431744
},
17441745
"language_info": {
@@ -1755,5 +1756,5 @@
17551756
}
17561757
},
17571758
"nbformat": 4,
1758-
"nbformat_minor": 0
1759+
"nbformat_minor": 4
17591760
}

docs/source/notebooks/mmm/mmm_example.ipynb

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
"import pymc as pm\n",
8888
"import seaborn as sns\n",
8989
"\n",
90-
"from pymc_marketing.mmm.delayed_saturated_mmm import MMM\n",
90+
"from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation\n",
9191
"from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation\n",
9292
"\n",
9393
"warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
@@ -979,15 +979,15 @@
979979
"cell_type": "markdown",
980980
"metadata": {},
981981
"source": [
982-
"We can specify the model structure using the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class. This class, handles a lot of internal boilerplate code for us such us scaling the data (see details below) and handy diagnostics and reporting plots. One great feature is that we can specify the channel priors distributions ourselves, which fundamental component of the [bayesian workflow](https://arxiv.org/abs/2011.01808) as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a bayesian approach. Let's see how we can do it.\n",
982+
"We can specify the model structure using the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class. This class, handles a lot of internal boilerplate code for us such us scaling the data (see details below) and handy diagnostics and reporting plots. One great feature is that we can specify the channel priors distributions ourselves, which fundamental component of the [bayesian workflow](https://arxiv.org/abs/2011.01808) as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a bayesian approach. Let's see how we can do it.\n",
983983
"\n",
984984
"As we do not know much more about the channels, we start with a simple heuristic: \n",
985985
"\n",
986986
"1. The channel contributions should be positive, so we can for example use a {class}`HalfNormal <pymc.distributions.continuous.HalfNormal>` distribution as prior. We need to set the `sigma` parameter per channel. The higher the `sigma`, the more \"freedom\" it has to fit the data. To specify `sigma` we can use the following point.\n",
987987
"\n",
988988
"2. We expect channels where we spend the most to have more attributed sales , before seeing the data. This is a very reasonable assumption (note that we are not imposing anything at the level of efficiency!).\n",
989989
"\n",
990-
"How to incorporate this heuristic into the model? To begin with, it is important to note that the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class scales the target and input variables through an [`MaxAbsScaler`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html) transformer from [`scikit-learn`](https://scikit-learn.org/stable/), its important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the `sigma` parameter for the `HalfNormal` distribution. We can actually add a scaling factor to take into account the support of the distribution.\n",
990+
"How to incorporate this heuristic into the model? To begin with, it is important to note that the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class scales the target and input variables through an [`MaxAbsScaler`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html) transformer from [`scikit-learn`](https://scikit-learn.org/stable/), its important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the `sigma` parameter for the `HalfNormal` distribution. We can actually add a scaling factor to take into account the support of the distribution.\n",
991991
"\n",
992992
"First, let's compute the share of spend per channel:"
993993
]
@@ -1072,7 +1072,7 @@
10721072
"source": [
10731073
"You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution.\n",
10741074
"\n",
1075-
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` to see the required structure."
1075+
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of {class}`MMM <pymc_marketing.mmm.mmm.MMM>` to see the required structure."
10761076
]
10771077
},
10781078
{
@@ -1101,9 +1101,8 @@
11011101
"dummy_model = MMM(\n",
11021102
" date_column=\"\",\n",
11031103
" channel_columns=[\"\"],\n",
1104-
" adstock=\"geometric\",\n",
1105-
" saturation=\"logistic\",\n",
1106-
" adstock_max_lag=4,\n",
1104+
" adstock=GeometricAdstock(l_max=4),\n",
1105+
" saturation=LogisticSaturation(),\n",
11071106
")\n",
11081107
"dummy_model.default_model_config"
11091108
]
@@ -1150,14 +1149,14 @@
11501149
"cell_type": "markdown",
11511150
"metadata": {},
11521151
"source": [
1153-
"**Remark:** For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class has some default priors that you can use as a starting point."
1152+
"**Remark:** For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class has some default priors that you can use as a starting point."
11541153
]
11551154
},
11561155
{
11571156
"cell_type": "markdown",
11581157
"metadata": {},
11591158
"source": [
1160-
"Model sampler allows specifying set of parameters that will be passed to fit the same way as the `kwargs` are getting passed so far. It doesn't disable the fit kwargs, but rather extend them, to enable customizable and preservable configuration. By default the sampler_config for {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` is empty. But if you'd like to use it, you can define it like showed below: "
1159+
"Model sampler allows specifying set of parameters that will be passed to fit the same way as the `kwargs` are getting passed so far. It doesn't disable the fit kwargs, but rather extend them, to enable customizable and preservable configuration. By default the sampler_config for {class}`MMM <pymc_marketing.mmm.mmm.MMM>` is empty. But if you'd like to use it, you can define it like showed below: "
11611160
]
11621161
},
11631162
{
@@ -1173,7 +1172,7 @@
11731172
"cell_type": "markdown",
11741173
"metadata": {},
11751174
"source": [
1176-
"Now we are ready to use the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class to define the model."
1175+
"Now we are ready to use the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class to define the model."
11771176
]
11781177
},
11791178
{
@@ -1186,15 +1185,14 @@
11861185
" model_config=my_model_config,\n",
11871186
" sampler_config=my_sampler_config,\n",
11881187
" date_column=\"date_week\",\n",
1189-
" adstock=\"geometric\",\n",
1190-
" saturation=\"logistic\",\n",
1188+
" adstock=GeometricAdstock(l_max=8),\n",
1189+
" saturation=LogisticSaturation(),\n",
11911190
" channel_columns=[\"x1\", \"x2\"],\n",
11921191
" control_columns=[\n",
11931192
" \"event_1\",\n",
11941193
" \"event_2\",\n",
11951194
" \"t\",\n",
11961195
" ],\n",
1197-
" adstock_max_lag=8,\n",
11981196
" yearly_seasonality=2,\n",
11991197
")"
12001198
]
@@ -6348,7 +6346,7 @@
63486346
"cell_type": "markdown",
63496347
"metadata": {},
63506348
"source": [
6351-
"The {func}`fit_result <pymc_marketing.mmm.delayed_saturated_mmm.MMM.fit_result>` attribute contains the `pymc` trace object."
6349+
"The {func}`fit_result <pymc_marketing.mmm.mmm.MMM.fit_result>` attribute contains the `pymc` trace object."
63526350
]
63536351
},
63546352
{
@@ -9400,7 +9398,7 @@
94009398
"cell_type": "markdown",
94019399
"metadata": {},
94029400
"source": [
9403-
"The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy is to use the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of `pymc`!"
9401+
"The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy is to use the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of `pymc`!"
94049402
]
94059403
},
94069404
{
@@ -10443,7 +10441,7 @@
1044310441
"metadata": {
1044410442
"hide_input": false,
1044510443
"kernelspec": {
10446-
"display_name": "Python 3",
10444+
"display_name": "Python 3 (ipykernel)",
1044710445
"language": "python",
1044810446
"name": "python3"
1044910447
},

docs/source/notebooks/mmm/mmm_lift_test.ipynb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"import pandas as pd\n",
5151
"import pymc as pm\n",
5252
"\n",
53-
"from pymc_marketing.mmm import MMM\n",
53+
"from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation\n",
5454
"from pymc_marketing.mmm.transformers import logistic_saturation"
5555
]
5656
},
@@ -228,9 +228,8 @@
228228
"mmm = MMM(\n",
229229
" date_column=\"date\",\n",
230230
" channel_columns=[\"channel 1\", \"channel 2\"],\n",
231-
" adstock_max_lag=6,\n",
232-
" adstock=\"geometric\",\n",
233-
" saturation=\"logistic\",\n",
231+
" adstock=GeometricAdstock(l_max=6),\n",
232+
" saturation=LogisticSaturation(),\n",
234233
")"
235234
]
236235
},
@@ -1795,7 +1794,7 @@
17951794
],
17961795
"source": [
17971796
"%load_ext watermark\n",
1798-
"%watermark -n -u -v -iv -w -p pymc_marketing,pytensor"
1797+
"%watermark -n -u -v -iv -w -p pymc_marketing -p pytensor"
17991798
]
18001799
}
18011800
],
@@ -1815,7 +1814,7 @@
18151814
"name": "python",
18161815
"nbconvert_exporter": "python",
18171816
"pygments_lexer": "ipython3",
1818-
"version": "3.12.4"
1817+
"version": "3.10.14"
18191818
}
18201819
},
18211820
"nbformat": 4,

docs/source/notebooks/mmm/mmm_roas.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
"import seaborn as sns\n",
7070
"\n",
7171
"from pymc_marketing.hsgp_kwargs import HSGPKwargs\n",
72-
"from pymc_marketing.mmm.delayed_saturated_mmm import (\n",
72+
"from pymc_marketing.mmm import (\n",
7373
" MMM,\n",
7474
" GeometricAdstock,\n",
7575
" LogisticSaturation,\n",

docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
"import pymc as pm\n",
8080
"import seaborn as sns\n",
8181
"\n",
82-
"from pymc_marketing.mmm import MMM\n",
82+
"from pymc_marketing.mmm import MMM, GeometricAdstock, MichaelisMentenSaturation\n",
8383
"from pymc_marketing.prior import Prior\n",
8484
"\n",
8585
"warnings.filterwarnings(\"ignore\")\n",
@@ -292,10 +292,9 @@
292292
" date_column=\"date_week\",\n",
293293
" channel_columns=[\"x1\", \"x2\"],\n",
294294
" control_columns=[\"event_1\", \"event_2\"],\n",
295-
" adstock_max_lag=adstock_max_lag,\n",
296295
" yearly_seasonality=yearly_seasonality,\n",
297-
" adstock=\"geometric\",\n",
298-
" saturation=\"michaelis_menten\",\n",
296+
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
297+
" saturation=MichaelisMentenSaturation(),\n",
299298
" time_varying_media=True,\n",
300299
")"
301300
]
@@ -4443,10 +4442,9 @@
44434442
" date_column=\"date_week\",\n",
44444443
" channel_columns=[\"x1\", \"x2\"],\n",
44454444
" control_columns=[\"event_1\", \"event_2\"],\n",
4446-
" adstock_max_lag=adstock_max_lag,\n",
44474445
" yearly_seasonality=yearly_seasonality,\n",
4448-
" adstock=\"geometric\",\n",
4449-
" saturation=\"michaelis_menten\",\n",
4446+
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
4447+
" saturation=MichaelisMentenSaturation(),\n",
44504448
")\n",
44514449
"\n",
44524450
"basic_mmm.fit(\n",
@@ -4686,10 +4684,9 @@
46864684
" date_column=\"date_week\",\n",
46874685
" channel_columns=[\"x1\", \"x2\"],\n",
46884686
" control_columns=[\"event_1\", \"event_2\"],\n",
4689-
" adstock_max_lag=adstock_max_lag,\n",
46904687
" yearly_seasonality=yearly_seasonality,\n",
4691-
" adstock=\"geometric\",\n",
4692-
" saturation=\"michaelis_menten\",\n",
4688+
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
4689+
" saturation=MichaelisMentenSaturation(),\n",
46934690
" time_varying_media=True,\n",
46944691
")"
46954692
]
@@ -9385,10 +9382,9 @@
93859382
" date_column=\"date_week\",\n",
93869383
" channel_columns=[\"x1\", \"x2\"],\n",
93879384
" control_columns=[\"event_1\", \"event_2\"],\n",
9388-
" adstock_max_lag=adstock_max_lag,\n",
93899385
" yearly_seasonality=yearly_seasonality,\n",
9390-
" adstock=\"geometric\",\n",
9391-
" saturation=\"michaelis_menten\",\n",
9386+
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
9387+
" saturation=MichaelisMentenSaturation(),\n",
93929388
" time_varying_media=True,\n",
93939389
")"
93949390
]
@@ -9444,7 +9440,7 @@
94449440
"name": "python",
94459441
"nbconvert_exporter": "python",
94469442
"pygments_lexer": "ipython3",
9447-
"version": "3.10.13"
9443+
"version": "3.10.14"
94489444
}
94499445
},
94509446
"nbformat": 4,

docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Lines changed: 342 additions & 425 deletions
Large diffs are not rendered by default.

docs/source/uml/classes_mmm.png

-298 KB
Loading

docs/source/uml/packages_mmm.png

-17.1 KB
Loading

pymc_marketing/mmm/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Marketing Mix Models (MMM)."""
1515

16-
from pymc_marketing.mmm import base, delayed_saturated_mmm, preprocessing, validating
16+
from pymc_marketing.mmm import base, mmm, preprocessing, validating
1717
from pymc_marketing.mmm.base import BaseValidateMMM, MMMModelBuilder
1818
from pymc_marketing.mmm.components.adstock import (
1919
AdstockTransformation,
@@ -37,8 +37,8 @@
3737
register_saturation_transformation,
3838
saturation_from_dict,
3939
)
40-
from pymc_marketing.mmm.delayed_saturated_mmm import MMM, DelayedSaturatedMMM
4140
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
41+
from pymc_marketing.mmm.mmm import MMM
4242
from pymc_marketing.mmm.preprocessing import (
4343
preprocessing_method_X,
4444
preprocessing_method_y,
@@ -49,7 +49,6 @@
4949
"AdstockTransformation",
5050
"BaseValidateMMM",
5151
"DelayedAdstock",
52-
"DelayedSaturatedMMM",
5352
"GeometricAdstock",
5453
"HillSaturation",
5554
"HillSaturationSigmoid",
@@ -71,7 +70,7 @@
7170
"register_adstock_transformation",
7271
"YearlyFourier",
7372
"base",
74-
"delayed_saturated_mmm",
73+
"mmm",
7574
"preprocessing",
7675
"preprocessing_method_X",
7776
"preprocessing_method_y",

pymc_marketing/mmm/components/adstock.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def function(self, x, alpha):
5252
5353
"""
5454

55-
import warnings
56-
5755
import numpy as np
5856
import xarray as xr
5957
from pydantic import Field, InstanceOf, validate_call
@@ -345,40 +343,3 @@ def adstock_from_dict(data: dict) -> AdstockTransformation:
345343
if "priors" in data:
346344
data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()}
347345
return cls(**data)
348-
349-
350-
def _get_adstock_function(
351-
function: str | AdstockTransformation,
352-
**kwargs,
353-
) -> AdstockTransformation:
354-
"""Get an adstock function.
355-
356-
Helper for use in the MMM to get an adstock function from the if registered.
357-
"""
358-
if isinstance(function, AdstockTransformation):
359-
return function
360-
361-
elif isinstance(function, str):
362-
if function not in ADSTOCK_TRANSFORMATIONS:
363-
raise ValueError(
364-
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
365-
)
366-
367-
if kwargs:
368-
msg = (
369-
"The preferred method of initializing a "
370-
"lagging function is to use the class directly. "
371-
"String support will deprecate in 0.9.0."
372-
)
373-
warnings.warn(
374-
msg,
375-
DeprecationWarning,
376-
stacklevel=1,
377-
)
378-
379-
return ADSTOCK_TRANSFORMATIONS[function](**kwargs)
380-
381-
else:
382-
raise ValueError(
383-
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
384-
)

0 commit comments

Comments
 (0)