@@ -523,24 +523,30 @@ class BayesianBasisExpansionTimeSeries(PyMCModel):
523
523
----------
524
524
n_order : int, optional
525
525
The number of Fourier components for the yearly seasonality. Defaults to 3.
526
+ Only used if seasonality_component is None.
526
527
n_changepoints_trend : int, optional
527
528
The number of changepoints for the linear trend component. Defaults to 10.
529
+ Only used if trend_component is None.
530
+ prior_sigma : float, optional
531
+ Prior standard deviation for the observation noise. Defaults to 5.
532
+ trend_component : Optional[Any], optional
533
+ A custom trend component model. If None, the default pymc-marketing LinearTrend component is used.
534
+ Must have an `apply(time_data)` method that returns a PyMC tensor.
535
+ seasonality_component : Optional[Any], optional
536
+ A custom seasonality component model. If None, the default pymc-marketing YearlyFourier component is used.
537
+ Must have an `apply(time_data)` method that returns a PyMC tensor.
528
538
sample_kwargs : dict, optional
529
539
A dictionary of kwargs that get unpacked and passed to the
530
540
:func:`pymc.sample` function. Defaults to an empty dictionary.
531
- trend_component : Optional[Any], optional
532
- A custom trend component model. If None, the default pymc-marketing trend component is used.
533
- seasonality_component : Optional[Any], optional
534
- A custom seasonality component model. If None, the default pymc-marketing seasonality `YearlyFourier` component is used.
535
541
""" # noqa: W605
536
542
537
543
def __init__ (
538
544
self ,
539
545
n_order : int = 3 ,
540
546
n_changepoints_trend : int = 10 ,
541
547
prior_sigma : float = 5 ,
542
- # Removed trend_component and seasonality_component for now to simplify
543
- # They can be added back if pymc-marketing is a hard dependency or via other logic
548
+ trend_component : Optional [ Any ] = None ,
549
+ seasonality_component : Optional [ Any ] = None ,
544
550
sample_kwargs : Optional [Dict [str , Any ]] = None ,
545
551
):
546
552
super ().__init__ (sample_kwargs = sample_kwargs )
@@ -552,9 +558,74 @@ def __init__(
552
558
self ._first_fit_timestamp : Optional [pd .Timestamp ] = None
553
559
self ._exog_var_names : Optional [List [str ]] = None
554
560
555
- # pymc-marketing components will be initialized in build_model
556
- # self._yearly_fourier = None
557
- # self._linear_trend = None
561
+ # Store custom components (fix the bug where they were swapped)
562
+ self ._custom_trend_component = trend_component
563
+ self ._custom_seasonality_component = seasonality_component
564
+
565
+ # Initialize and validate components
566
+ self ._trend_component = None
567
+ self ._seasonality_component = None
568
+ self ._validate_and_initialize_components ()
569
+
570
+ def _validate_and_initialize_components (self ):
571
+ """
572
+ Validate and initialize trend and seasonality components.
573
+ This separates validation from model building for cleaner code.
574
+ """
575
+ # Validate pymc-marketing availability if using default components
576
+ if (
577
+ self ._custom_trend_component is None
578
+ or self ._custom_seasonality_component is None
579
+ ):
580
+ try :
581
+ from pymc_marketing .mmm import LinearTrend , YearlyFourier
582
+
583
+ self ._PymcMarketingLinearTrend = LinearTrend
584
+ self ._PymcMarketingYearlyFourier = YearlyFourier
585
+ except ImportError :
586
+ raise ImportError (
587
+ "pymc-marketing is required when using default trend or seasonality components. "
588
+ "Please install it with `pip install pymc-marketing` or provide custom components."
589
+ )
590
+
591
+ # Validate custom components have required methods
592
+ if self ._custom_trend_component is not None :
593
+ if not hasattr (self ._custom_trend_component , "apply" ):
594
+ raise ValueError (
595
+ "Custom trend_component must have an 'apply' method that accepts time data "
596
+ "and returns a PyMC tensor."
597
+ )
598
+
599
+ if self ._custom_seasonality_component is not None :
600
+ if not hasattr (self ._custom_seasonality_component , "apply" ):
601
+ raise ValueError (
602
+ "Custom seasonality_component must have an 'apply' method that accepts time data "
603
+ "and returns a PyMC tensor."
604
+ )
605
+
606
+ def _get_trend_component (self ):
607
+ """Get the trend component, creating default if needed."""
608
+ if self ._custom_trend_component is not None :
609
+ return self ._custom_trend_component
610
+
611
+ # Create default trend component
612
+ if self ._trend_component is None :
613
+ self ._trend_component = self ._PymcMarketingLinearTrend (
614
+ n_changepoints = self .n_changepoints_trend
615
+ )
616
+ return self ._trend_component
617
+
618
+ def _get_seasonality_component (self ):
619
+ """Get the seasonality component, creating default if needed."""
620
+ if self ._custom_seasonality_component is not None :
621
+ return self ._custom_seasonality_component
622
+
623
+ # Create default seasonality component
624
+ if self ._seasonality_component is None :
625
+ self ._seasonality_component = self ._PymcMarketingYearlyFourier (
626
+ n_order = self .n_order
627
+ )
628
+ return self ._seasonality_component
558
629
559
630
def _prepare_time_and_exog_features (
560
631
self ,
@@ -665,9 +736,6 @@ def build_model(
665
736
666
737
# Get exog_names from coords["coeffs"] if X_exog_array is present
667
738
exog_names_from_coords = coords .get ("coeffs" )
668
- # This will be further processed into a list by _prepare_time_and_exog_features
669
- # if isinstance(exog_names_from_coords, str): # Handle single coeff name
670
- # exog_names_from_coords = [exog_names_from_coords]
671
739
672
740
(
673
741
time_for_trend ,
@@ -738,44 +806,19 @@ def build_model(
738
806
"t_season_data" , time_for_seasonality , dims = "obs_ind" , mutable = True
739
807
)
740
808
741
- # Attempt to import and instantiate pymc_marketing components here
742
- _PymcMarketingLinearTrend = None
743
- _PymcMarketingYearlyFourier = None
744
- pymc_marketing_available = False
745
- try :
746
- from pymc_marketing .mmm import LinearTrend as PymcMLinearTrend
747
- from pymc_marketing .mmm import YearlyFourier as PymcMYearlyFourier
748
-
749
- _PymcMarketingLinearTrend = PymcMLinearTrend
750
- _PymcMarketingYearlyFourier = PymcMYearlyFourier
751
- pymc_marketing_available = True
752
- except ImportError :
753
- # pymc-marketing is not available. This is handled conditionally below.
754
- pass
755
-
756
- if not pymc_marketing_available :
757
- raise ImportError (
758
- "pymc-marketing is required. "
759
- "Please install it with `pip install pymc-marketing`."
760
- )
761
-
762
- # Instantiate components for this specific build_model call
763
- local_yearly_fourier = _PymcMarketingYearlyFourier (n_order = self .n_order )
764
- local_linear_trend = _PymcMarketingLinearTrend (
765
- n_changepoints = self .n_changepoints_trend
766
- )
809
+ # Get validated components (no more ugly imports in build_model!)
810
+ trend_component_instance = self ._get_trend_component ()
811
+ seasonality_component_instance = self ._get_seasonality_component ()
767
812
768
813
# Seasonal component
769
814
season_component = pm .Deterministic (
770
815
"season_component" ,
771
- local_yearly_fourier .apply (t_season_data ), # Use local instance
816
+ seasonality_component_instance .apply (t_season_data ),
772
817
dims = "obs_ind" ,
773
818
)
774
819
775
820
# Trend component
776
- trend_component_values = local_linear_trend .apply (
777
- t_trend_data
778
- ) # Use local instance
821
+ trend_component_values = trend_component_instance .apply (t_trend_data )
779
822
trend_component = pm .Deterministic (
780
823
"trend_component" ,
781
824
trend_component_values ,
0 commit comments