Skip to content

Commit 64c1e60

Browse files
committed
Fucking gpt
1 parent 7ead617 commit 64c1e60

File tree

3 files changed

+764
-54
lines changed

3 files changed

+764
-54
lines changed

causalpy/pymc_models.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -523,24 +523,30 @@ class BayesianBasisExpansionTimeSeries(PyMCModel):
523523
----------
524524
n_order : int, optional
525525
The number of Fourier components for the yearly seasonality. Defaults to 3.
526+
Only used if seasonality_component is None.
526527
n_changepoints_trend : int, optional
527528
The number of changepoints for the linear trend component. Defaults to 10.
529+
Only used if trend_component is None.
530+
prior_sigma : float, optional
531+
Prior standard deviation for the observation noise. Defaults to 5.
532+
trend_component : Optional[Any], optional
533+
A custom trend component model. If None, the default pymc-marketing LinearTrend component is used.
534+
Must have an `apply(time_data)` method that returns a PyMC tensor.
535+
seasonality_component : Optional[Any], optional
536+
A custom seasonality component model. If None, the default pymc-marketing YearlyFourier component is used.
537+
Must have an `apply(time_data)` method that returns a PyMC tensor.
528538
sample_kwargs : dict, optional
529539
A dictionary of kwargs that get unpacked and passed to the
530540
:func:`pymc.sample` function. Defaults to an empty dictionary.
531-
trend_component : Optional[Any], optional
532-
A custom trend component model. If None, the default pymc-marketing trend component is used.
533-
seasonality_component : Optional[Any], optional
534-
A custom seasonality component model. If None, the default pymc-marketing seasonality `YearlyFourier` component is used.
535541
""" # noqa: W605
536542

537543
def __init__(
538544
self,
539545
n_order: int = 3,
540546
n_changepoints_trend: int = 10,
541547
prior_sigma: float = 5,
542-
# Removed trend_component and seasonality_component for now to simplify
543-
# They can be added back if pymc-marketing is a hard dependency or via other logic
548+
trend_component: Optional[Any] = None,
549+
seasonality_component: Optional[Any] = None,
544550
sample_kwargs: Optional[Dict[str, Any]] = None,
545551
):
546552
super().__init__(sample_kwargs=sample_kwargs)
@@ -552,9 +558,74 @@ def __init__(
552558
self._first_fit_timestamp: Optional[pd.Timestamp] = None
553559
self._exog_var_names: Optional[List[str]] = None
554560

555-
# pymc-marketing components will be initialized in build_model
556-
# self._yearly_fourier = None
557-
# self._linear_trend = None
561+
# Store custom components (fix the bug where they were swapped)
562+
self._custom_trend_component = trend_component
563+
self._custom_seasonality_component = seasonality_component
564+
565+
# Initialize and validate components
566+
self._trend_component = None
567+
self._seasonality_component = None
568+
self._validate_and_initialize_components()
569+
570+
def _validate_and_initialize_components(self):
571+
"""
572+
Validate and initialize trend and seasonality components.
573+
This separates validation from model building for cleaner code.
574+
"""
575+
# Validate pymc-marketing availability if using default components
576+
if (
577+
self._custom_trend_component is None
578+
or self._custom_seasonality_component is None
579+
):
580+
try:
581+
from pymc_marketing.mmm import LinearTrend, YearlyFourier
582+
583+
self._PymcMarketingLinearTrend = LinearTrend
584+
self._PymcMarketingYearlyFourier = YearlyFourier
585+
except ImportError:
586+
raise ImportError(
587+
"pymc-marketing is required when using default trend or seasonality components. "
588+
"Please install it with `pip install pymc-marketing` or provide custom components."
589+
)
590+
591+
# Validate custom components have required methods
592+
if self._custom_trend_component is not None:
593+
if not hasattr(self._custom_trend_component, "apply"):
594+
raise ValueError(
595+
"Custom trend_component must have an 'apply' method that accepts time data "
596+
"and returns a PyMC tensor."
597+
)
598+
599+
if self._custom_seasonality_component is not None:
600+
if not hasattr(self._custom_seasonality_component, "apply"):
601+
raise ValueError(
602+
"Custom seasonality_component must have an 'apply' method that accepts time data "
603+
"and returns a PyMC tensor."
604+
)
605+
606+
def _get_trend_component(self):
607+
"""Get the trend component, creating default if needed."""
608+
if self._custom_trend_component is not None:
609+
return self._custom_trend_component
610+
611+
# Create default trend component
612+
if self._trend_component is None:
613+
self._trend_component = self._PymcMarketingLinearTrend(
614+
n_changepoints=self.n_changepoints_trend
615+
)
616+
return self._trend_component
617+
618+
def _get_seasonality_component(self):
619+
"""Get the seasonality component, creating default if needed."""
620+
if self._custom_seasonality_component is not None:
621+
return self._custom_seasonality_component
622+
623+
# Create default seasonality component
624+
if self._seasonality_component is None:
625+
self._seasonality_component = self._PymcMarketingYearlyFourier(
626+
n_order=self.n_order
627+
)
628+
return self._seasonality_component
558629

559630
def _prepare_time_and_exog_features(
560631
self,
@@ -665,9 +736,6 @@ def build_model(
665736

666737
# Get exog_names from coords["coeffs"] if X_exog_array is present
667738
exog_names_from_coords = coords.get("coeffs")
668-
# This will be further processed into a list by _prepare_time_and_exog_features
669-
# if isinstance(exog_names_from_coords, str): # Handle single coeff name
670-
# exog_names_from_coords = [exog_names_from_coords]
671739

672740
(
673741
time_for_trend,
@@ -738,44 +806,19 @@ def build_model(
738806
"t_season_data", time_for_seasonality, dims="obs_ind", mutable=True
739807
)
740808

741-
# Attempt to import and instantiate pymc_marketing components here
742-
_PymcMarketingLinearTrend = None
743-
_PymcMarketingYearlyFourier = None
744-
pymc_marketing_available = False
745-
try:
746-
from pymc_marketing.mmm import LinearTrend as PymcMLinearTrend
747-
from pymc_marketing.mmm import YearlyFourier as PymcMYearlyFourier
748-
749-
_PymcMarketingLinearTrend = PymcMLinearTrend
750-
_PymcMarketingYearlyFourier = PymcMYearlyFourier
751-
pymc_marketing_available = True
752-
except ImportError:
753-
# pymc-marketing is not available. This is handled conditionally below.
754-
pass
755-
756-
if not pymc_marketing_available:
757-
raise ImportError(
758-
"pymc-marketing is required. "
759-
"Please install it with `pip install pymc-marketing`."
760-
)
761-
762-
# Instantiate components for this specific build_model call
763-
local_yearly_fourier = _PymcMarketingYearlyFourier(n_order=self.n_order)
764-
local_linear_trend = _PymcMarketingLinearTrend(
765-
n_changepoints=self.n_changepoints_trend
766-
)
809+
# Get validated components (no more ugly imports in build_model!)
810+
trend_component_instance = self._get_trend_component()
811+
seasonality_component_instance = self._get_seasonality_component()
767812

768813
# Seasonal component
769814
season_component = pm.Deterministic(
770815
"season_component",
771-
local_yearly_fourier.apply(t_season_data), # Use local instance
816+
seasonality_component_instance.apply(t_season_data),
772817
dims="obs_ind",
773818
)
774819

775820
# Trend component
776-
trend_component_values = local_linear_trend.apply(
777-
t_trend_data
778-
) # Use local instance
821+
trend_component_values = trend_component_instance.apply(t_trend_data)
779822
trend_component = pm.Deterministic(
780823
"trend_component",
781824
trend_component_values,

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)