Skip to content

Commit 1565642

Browse files
committed
minor improvemeents
1 parent 66f32fc commit 1565642

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

pymc_marketing/mmm/delayed_saturated_mmm.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ def output_var(self):
101101
def _generate_and_preprocess_model_data( # type: ignore
102102
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
103103
) -> None:
104-
"""
105-
Applies preprocessing to the data before fitting the model.
106-
if validate is True, it will check if the data is valid for the model.
104+
"""Applies preprocessing to the data before fitting the model.
105+
106+
If validate is True, it will check if the data is valid for the model.
107107
sets self.model_coords based on provided dataset
108108
109109
Parameters
@@ -390,6 +390,7 @@ def build_model(
390390
)
391391

392392
mu_var = intercept + channel_contributions.sum(axis=-1)
393+
393394
if (
394395
self.control_columns is not None
395396
and len(self.control_columns) > 0
@@ -417,6 +418,7 @@ def build_model(
417418
)
418419

419420
mu_var += control_contributions.sum(axis=-1)
421+
420422
if (
421423
hasattr(self, "fourier_columns")
422424
and self.fourier_columns is not None
@@ -494,10 +496,12 @@ def channel_contributions_forward_pass(
494496
self, channel_data: npt.NDArray[np.float_]
495497
) -> npt.NDArray[np.float_]:
496498
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
499+
497500
Parameters
498501
----------
499502
channel_data : array-like
500503
Input channel data. Result of all the preprocessing steps.
504+
501505
Returns
502506
-------
503507
array-like
@@ -753,7 +757,7 @@ class DelayedSaturatedMMM(
753757
from pymc_marketing.mmm import DelayedSaturatedMMM
754758
755759
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv"
756-
data = pd.read_csv(data_url, parse_dates=['date_week'])
760+
data = pd.read_csv(data_url, parse_dates=["date_week"])
757761
758762
mmm = DelayedSaturatedMMM(
759763
date_column="date_week",
@@ -833,6 +837,7 @@ def channel_contributions_forward_pass(
833837
) -> npt.NDArray[np.float_]:
834838
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
835839
We return the contribution in the original scale of the target variable.
840+
836841
Parameters
837842
----------
838843
channel_data : array-like
@@ -855,7 +860,8 @@ def channel_contributions_forward_pass(
855860
def get_channel_contributions_forward_pass_grid(
856861
self, start: float, stop: float, num: int
857862
) -> DataArray:
858-
"""Generate a grid of scaled channel contributions for a given grid of share values.
863+
"""Generate a grid of scaled channel contributions for a given grid of shared values.
864+
859865
Parameters
860866
----------
861867
start : float
@@ -914,6 +920,7 @@ def plot_channel_contributions_grid(
914920
absolute_xrange : bool, optional
915921
If True, the x-axis is in absolute values (input units), otherwise it is in
916922
relative percentage values, by default False.
923+
917924
Returns
918925
-------
919926
plt.Figure

0 commit comments

Comments
 (0)