@@ -128,32 +128,51 @@ def ss_mod_no_exog_dt(rng):
128128
129129
130130@pytest .fixture (scope = "session" )
131- def exog_ss_mod (rng ):
132- ll = st .LevelTrendComponent ()
133- reg = st .RegressionComponent (name = "exog" , state_names = ["a" , "b" , "c" ])
134- mod = (ll + reg ).build (verbose = False )
131+ def exog_data (rng ):
132+ # simulate data
133+ df = pd .DataFrame (
134+ {
135+ "date" : pd .date_range (start = "2023-05-01" , end = "2023-05-10" , freq = "D" ),
136+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
137+ "y" : rng .normal (size = (10 ,)),
138+ }
139+ )
135140
136- return mod
141+ df .loc [[1 , 3 , 9 ], ["y" ]] = np .nan
142+ return df .set_index ("date" )
137143
138144
139145@pytest .fixture (scope = "session" )
140- def exog_pymc_mod (exog_ss_mod , rng ):
141- y = rng .normal (size = (100 , 1 )).astype (floatX )
142- X = rng .normal (size = (100 , 3 )).astype (floatX )
146+ def exog_ss_mod (exog_data ):
147+ level_trend = st .LevelTrendComponent (order = 1 , innovations_order = [0 ])
148+ exog = st .RegressionComponent (
149+ name = "exog" , # Name of this exogenous variable component
150+ k_exog = 1 , # Only one exogenous variable now
151+ innovations = False , # Typically fixed effect (no stochastic evolution)
152+ state_names = exog_data [["x1" ]].columns .tolist (),
153+ )
143154
144- with pm .Model (coords = exog_ss_mod .coords ) as m :
145- exog_data = pm .Data ("data_exog" , X )
146- initial_trend = pm .Normal ("initial_trend" , dims = ["trend_state" ])
147- P0_sigma = pm .Exponential ("P0_sigma" , 1 )
148- P0 = pm .Deterministic (
149- "P0" , pt .eye (exog_ss_mod .k_states ) * P0_sigma , dims = ["state" , "state_aux" ]
155+ combined_model = level_trend + exog
156+ return combined_model .build ()
157+
158+
159+ @pytest .fixture (scope = "session" )
160+ def exog_pymc_mod (exog_ss_mod , exog_data ):
161+ # define pymc model
162+ with pm .Model (coords = exog_ss_mod .coords ) as struct_model :
163+ P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 4 , dims = ["state" ])
164+ P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = ["state" , "state_aux" ])
165+
166+ initial_trend = pm .Normal ("initial_trend" , mu = [0 ], sigma = [0.005 ], dims = ["trend_state" ])
167+
168+ data_exog = pm .Data (
169+ "data_exog" , exog_data ["x1" ].values [:, None ], dims = ["time" , "exog_state" ]
150170 )
151- beta_exog = pm .Normal ("beta_exog" , dims = ["exog_state" ])
171+ beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["exog_state" ])
152172
153- sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
154- exog_ss_mod .build_statespace_graph (y , save_kalman_filter_outputs_in_idata = True )
173+ exog_ss_mod .build_statespace_graph (exog_data ["y" ])
155174
156- return m
175+ return struct_model
157176
158177
159178@pytest .fixture (scope = "session" )
@@ -844,10 +863,14 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
844863 assert forecast_idx [0 ] == (t0 + delta )
845864
846865
866+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
867+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
847868@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
848- @pytest .mark .parametrize ("start" , [None , - 1 , 10 ])
869+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
870+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
871+ @pytest .mark .parametrize ("start" , [None , - 1 , 5 ])
849872def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog , start ):
850- scenario = pd .DataFrame (np .zeros ((10 , 3 )), columns = ["a" , "b" , "c " ])
873+ scenario = pd .DataFrame (np .zeros ((10 , 1 )), columns = ["x1 " ])
851874 scenario .iloc [5 , 0 ] = 1e9
852875
853876 forecast_idata = exog_ss_mod .forecast (
@@ -856,14 +879,14 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
856879
857880 components = exog_ss_mod .extract_components_from_idata (forecast_idata )
858881 level = components .forecast_latent .sel (state = "LevelTrend[level]" )
859- betas = components .forecast_latent .sel (state = ["exog[a]" , "exog[b]" , "exog[c ]" ])
882+ betas = components .forecast_latent .sel (state = ["exog[x1 ]" ])
860883
861884 scenario .index .name = "time"
862885 scenario_xr = (
863886 scenario .unstack ()
864887 .to_xarray ()
865888 .rename ({"level_0" : "state" })
866- .assign_coords (state = ["exog[a]" , "exog[b]" , "exog[c ]" ])
889+ .assign_coords (state = ["exog[x1 ]" ])
867890 )
868891
869892 regression_effect = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level
@@ -872,91 +895,25 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
872895 assert_allclose (regression_effect , regression_effect_expected )
873896
874897
875- @pytest .mark .filterwarnings ("ignore:Provided data contains missing values. " )
898+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
876899@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
877- def test_foreacast_valid_index (rng ):
900+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
901+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
902+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
903+ def test_foreacast_valid_index (exog_pymc_mod , exog_ss_mod , exog_data ):
878904 # Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
879-
880- index = pd .date_range (start = "2023-05-01" , end = "2025-01-29" , freq = "D" )
881- T , k = len (index ), 2
882- data = np .zeros ((T , k ))
883- idx = rng .choice (T , size = 10 , replace = False )
884- cols = rng .choice (k , size = 10 , replace = True )
885-
886- data [idx , cols ] = 1
887-
888- df_holidays = pd .DataFrame (data , index = index , columns = ["Holiday 1" , "Holiday 2" ])
889-
890- data = rng .normal (size = (T , 1 ))
891- nan_locs = rng .choice (T , size = 10 , replace = False )
892- data [nan_locs ] = np .nan
893- y = pd .DataFrame (data , index = index , columns = ["sales" ])
894-
895- level_trend = st .LevelTrendComponent (order = 1 , innovations_order = [0 ])
896- weekly_seasonality = st .TimeSeasonality (
897- season_length = 7 ,
898- state_names = ["Sun" , "Mon" , "Tues" , "Wed" , "Thu" , "Fri" , "Sat" ],
899- innovations = True ,
900- remove_first_state = False ,
901- )
902- quarterly_seasonality = st .FrequencySeasonality (season_length = 365 , n = 2 , innovations = True )
903- ar1 = st .AutoregressiveComponent (order = 1 )
904- me = st .MeasurementError ()
905-
906- exog = st .RegressionComponent (
907- name = "exog" , # Name of this exogenous variable component
908- k_exog = 2 , # Only one exogenous variable now
909- innovations = False , # Typically fixed effect (no stochastic evolution)
910- state_names = df_holidays .columns .tolist (),
911- )
912-
913- combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
914- ss_mod = combined_model .build ()
915-
916- with pm .Model (coords = ss_mod .coords ) as struct_model :
917- P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 10 , dims = ["state" ])
918- P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = ["state" , "state_aux" ])
919-
920- initial_trend = pm .Normal ("initial_trend" , mu = [0 ], sigma = [0.005 ], dims = ["trend_state" ])
921- # sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"]) # Applied to the level only
922-
923- Seasonal_coefs = pm .ZeroSumNormal (
924- "Seasonal[s=7]_coefs" , sigma = 0.5 , dims = ["Seasonal[s=7]_state" ]
925- ) # DOW dev. from weekly mean
926- sigma_Seasonal = pm .Gamma (
927- "sigma_Seasonal[s=7]" , alpha = 2 , beta = 1
928- ) # How much this dev. can dev.
929-
930- Frequency_coefs = pm .Normal (
931- "Frequency[s=365, n=2]" , mu = 0 , sigma = 0.5 , dims = ["Frequency[s=365, n=2]_state" ]
932- ) # amplitudes in short-term (weekly noise culprit)
933- sigma_Frequency = pm .Gamma (
934- "sigma_Frequency[s=365, n=2]" , alpha = 2 , beta = 1
935- ) # smoothness & adaptability over time
936-
937- ar_params = pm .Laplace ("ar_params" , mu = 0 , b = 0.2 , dims = ["ar_lag" ])
938- sigma_ar = pm .Gamma ("sigma_ar" , alpha = 2 , beta = 1 )
939-
940- sigma_measurement_error = pm .HalfStudentT ("sigma_MeasurementError" , nu = 3 , sigma = 1 )
941-
942- data_exog = pm .Data ("data_exog" , df_holidays .values , dims = ["time" , "exog_state" ])
943- beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["exog_state" ])
944-
945- ss_mod .build_statespace_graph (y , mode = "JAX" )
946-
905+ with exog_pymc_mod :
947906 idata = pm .sample_prior_predictive ()
948907
949- post = ss_mod .sample_conditional_prior (idata )
950-
951908 # Define start date and forecast period
952- start_date , n_periods = pd .to_datetime ("2024-4-15 " ), 8
909+ start_date , n_periods = pd .to_datetime ("2023-05-05 " ), 5
953910
954911 # Extract exogenous data for the forecast period
955912 scenario = {
956913 "data_exog" : pd .DataFrame (
957- df_holidays .loc [start_date :].iloc [:n_periods ], columns = df_holidays .columns
914+ exog_data [[ "x1" ]] .loc [start_date :].iloc [:n_periods ], columns = exog_data [[ "x1" ]] .columns
958915 )
959916 }
960917
961918 # Generate the forecast
962- forecasts = ss_mod .forecast (idata .prior , scenario = scenario , use_scenario_index = True )
919+ forecasts = exog_ss_mod .forecast (idata .prior , scenario = scenario , use_scenario_index = True )
0 commit comments