1010 ALL_STATE_AUX_DIM ,
1111 ALL_STATE_DIM ,
1212 ETS_SEASONAL_DIM ,
13+ OBS_STATE_AUX_DIM ,
1314 OBS_STATE_DIM ,
1415)
1516
@@ -176,12 +177,15 @@ class BayesianETS(PyMCStateSpace):
176177 def __init__ (
177178 self ,
178179 order : tuple [str , str , str ] | None = None ,
180+ endog_names : str | list [str ] | None = None ,
181+ k_endog : int = 1 ,
179182 trend : bool = True ,
180183 damped_trend : bool = False ,
181184 seasonal : bool = False ,
182185 seasonal_periods : int | None = None ,
183186 measurement_error : bool = False ,
184187 use_transformed_parameterization : bool = False ,
188+ dense_innovation_covariance : bool = False ,
185189 filter_type : str = "standard" ,
186190 verbose : bool = True ,
187191 ):
@@ -214,13 +218,26 @@ def __init__(
214218 if self .seasonal and self .seasonal_periods is None :
215219 raise ValueError ("If seasonal is True, seasonal_periods must be provided." )
216220
221+ if endog_names is not None :
222+ endog_names = list (endog_names )
223+ k_endog = len (endog_names )
224+ else :
225+ endog_names = [f"data_{ i } " for i in range (k_endog )] if k_endog > 1 else ["data" ]
226+
227+ self .endog_names = endog_names
228+
229+ if dense_innovation_covariance and k_endog == 1 :
230+ dense_innovation_covariance = False
231+
232+ self .dense_innovation_covariance = dense_innovation_covariance
233+
217234 k_states = (
218235 2
219236 + int (trend )
220237 + int (seasonal ) * (seasonal_periods if seasonal_periods is not None else 0 )
221- )
222- k_posdef = 1
223- k_endog = 1
238+ ) * k_endog
239+
240+ k_posdef = k_endog
224241
225242 super ().__init__ (
226243 k_endog ,
@@ -243,6 +260,7 @@ def param_names(self):
243260 "gamma" ,
244261 "phi" ,
245262 "sigma_state" ,
263+ "state_cov" ,
246264 "sigma_obs" ,
247265 ]
248266 if not self .trend :
@@ -256,6 +274,11 @@ def param_names(self):
256274 if not self .measurement_error :
257275 names .remove ("sigma_obs" )
258276
277+ if self .dense_innovation_covariance :
278+ names .remove ("sigma_state" )
279+ else :
280+ names .remove ("state_cov" )
281+
259282 return names
260283
261284 @property
@@ -283,27 +306,34 @@ def param_info(self) -> dict[str, dict[str, Any]]:
283306 "constraints" : "Positive" ,
284307 },
285308 "alpha" : {
286- "shape" : None ,
309+ "shape" : None if self . k_endog == 1 else ( self . k_endog ,) ,
287310 "constraints" : "0 < alpha < 1" ,
288311 },
289312 "beta" : {
290- "shape" : None ,
313+ "shape" : None if self . k_endog == 1 else ( self . k_endog ,) ,
291314 "constraints" : "0 < beta < 1"
292315 if not self .use_transformed_parameterization
293316 else "0 < beta < alpha" ,
294317 },
295318 "gamma" : {
296- "shape" : None ,
319+ "shape" : None if self . k_endog == 1 else ( self . k_endog ,) ,
297320 "constraints" : "0 < gamma< 1"
298321 if not self .use_transformed_parameterization
299322 else "0 < gamma < (1 - alpha)" ,
300323 },
301324 "phi" : {
302- "shape" : None ,
325+ "shape" : None if self . k_endog == 1 else ( self . k_endog ,) ,
303326 "constraints" : "0 < phi < 1" ,
304327 },
305328 }
306329
330+ if self .dense_innovation_covariance :
331+ del info ["sigma_state" ]
332+ info ["state_cov" ] = {
333+ "shape" : (self .k_posdef , self .k_posdef ),
334+ "constraints" : "Positive Semi-definite" ,
335+ }
336+
307337 for name in self .param_names :
308338 info [name ]["dims" ] = self .param_dims .get (name , None )
309339
@@ -317,15 +347,22 @@ def state_names(self):
317347 if self .seasonal :
318348 states += [f"L{ i } .season" for i in range (1 , self .seasonal_periods + 1 )]
319349
350+ if self .k_endog > 1 :
351+ states = [f"{ name } _{ state } " for name in self .endog_names for state in states ]
352+
320353 return states
321354
322355 @property
323356 def observed_states (self ):
324- return [ "data" ]
357+ return self . endog_names
325358
326359 @property
327360 def shock_names (self ):
328- return ["innovation" ]
361+ return (
362+ ["innovation" ]
363+ if self .k_endog == 1
364+ else [f"{ name } _innovation" for name in self .endog_names ]
365+ )
329366
330367 @property
331368 def param_dims (self ):
@@ -339,11 +376,23 @@ def param_dims(self):
339376 "seasonal_param" : (ETS_SEASONAL_DIM ,),
340377 }
341378
379+ if self .dense_innovation_covariance :
380+ del coord_map ["sigma_state" ]
381+ coord_map ["state_cov" ] = (OBS_STATE_DIM , OBS_STATE_AUX_DIM )
382+
342383 if self .k_endog == 1 :
343384 coord_map ["sigma_state" ] = None
344385 coord_map ["sigma_obs" ] = None
345386 coord_map ["initial_level" ] = None
346387 coord_map ["initial_trend" ] = None
388+ else :
389+ coord_map ["alpha" ] = (OBS_STATE_DIM ,)
390+ coord_map ["beta" ] = (OBS_STATE_DIM ,)
391+ coord_map ["gamma" ] = (OBS_STATE_DIM ,)
392+ coord_map ["phi" ] = (OBS_STATE_DIM ,)
393+ coord_map ["initial_seasonal" ] = (OBS_STATE_DIM , ETS_SEASONAL_DIM )
394+ coord_map ["seasonal_param" ] = (OBS_STATE_DIM , ETS_SEASONAL_DIM )
395+
347396 if not self .measurement_error :
348397 del coord_map ["sigma_obs" ]
349398 if not self .seasonal :
@@ -360,6 +409,8 @@ def coords(self) -> dict[str, Sequence]:
360409 return coords
361410
362411 def make_symbolic_graph (self ) -> None :
412+ k_states_each = self .k_states // self .k_endog
413+
363414 P0 = self .make_and_register_variable (
364415 "P0" , shape = (self .k_states , self .k_states ), dtype = floatX
365416 )
@@ -368,21 +419,37 @@ def make_symbolic_graph(self) -> None:
368419 initial_level = self .make_and_register_variable (
369420 "initial_level" , shape = (self .k_endog ,) if self .k_endog > 1 else (), dtype = floatX
370421 )
371- self .ssm ["initial_state" , 1 ] = initial_level
422+
423+ initial_states = [pt .zeros (k_states_each ) for _ in range (self .k_endog )]
424+ if self .k_endog == 1 :
425+ initial_states = [pt .set_subtensor (initial_states [0 ][1 ], initial_level )]
426+ else :
427+ initial_states = [
428+ pt .set_subtensor (initial_state [1 ], initial_level [i ])
429+ for i , initial_state in enumerate (initial_states )
430+ ]
372431
373432 # The shape of R can be pre-allocated, then filled with the required parameters
374- R = pt .zeros ((self .k_states , self .k_posdef ))
433+ R = pt .zeros ((self .k_states // self .k_endog , 1 ))
434+
435+ alpha = self .make_and_register_variable (
436+ "alpha" , shape = () if self .k_endog == 1 else (self .k_endog ,), dtype = floatX
437+ )
375438
376- alpha = self .make_and_register_variable ("alpha" , shape = (), dtype = floatX )
377- R = pt .set_subtensor (R [1 , 0 ], alpha ) # and l_t = ... + alpha * e_t
439+ if self .k_endog == 1 :
440+ # The R[0, 0] entry needs to be adjusted for a shift in the time indices. Consider the (A, N, N) model:
441+ # y_t = l_{t-1} + e_t
442+ # l_t = l_{t-1} + alpha * e_t
443+ R_list = [pt .set_subtensor (R [1 , 0 ], alpha )] # and l_t = ... + alpha * e_t
378444
379- # The R[0, 0] entry needs to be adjusted for a shift in the time indices. Consider the (A, N, N) model:
380- # y_t = l_{t-1} + e_t
381- # l_t = l_{t-1} + alpha * e_t
382- # We want the first equation to be in terms of time t on the RHS, because our observation equation is always
383- # y_t = Z @ x_t. Re-arranging equation 2, we get l_{t-1} = l_t - alpha * e_t --> y_t = l_t + e_t - alpha * e_t
384- # --> y_t = l_t + (1 - alpha) * e_t
385- R = pt .set_subtensor (R [0 , :], (1 - alpha ))
445+ # We want the first equation to be in terms of time t on the RHS, because our observation equation is always
446+ # y_t = Z @ x_t. Re-arranging equation 2, we get l_{t-1} = l_t - alpha * e_t --> y_t = l_t + e_t - alpha * e_t
447+ # --> y_t = l_t + (1 - alpha) * e_t
448+ R_list = [pt .set_subtensor (R [0 , :], (1 - alpha )) for R in R_list ]
449+ else :
450+ # If there are multiple endog, clone the basic R matrix and modify the appropriate entries
451+ R_list = [pt .set_subtensor (R [1 , 0 ], alpha [i ]) for i in range (self .k_endog )]
452+ R_list = [pt .set_subtensor (R [0 , :], (1 - alpha [i ])) for i , R in enumerate (R_list )]
386453
387454 # Shock and level component always exists, the base case is e_t = e_t and l_t = l_{t-1}
388455 T_base = pt .as_tensor_variable (np .array ([[0.0 , 0.0 ], [0.0 , 1.0 ]]))
@@ -391,77 +458,134 @@ def make_symbolic_graph(self) -> None:
391458 initial_trend = self .make_and_register_variable (
392459 "initial_trend" , shape = (self .k_endog ,) if self .k_endog > 1 else (), dtype = floatX
393460 )
394- self .ssm ["initial_state" , 2 ] = initial_trend
395461
396- beta = self .make_and_register_variable ("beta" , shape = (), dtype = floatX )
462+ if self .k_endog == 1 :
463+ initial_states = [pt .set_subtensor (initial_states [0 ][2 ], initial_trend )]
464+ else :
465+ initial_states = [
466+ pt .set_subtensor (initial_state [2 ], initial_trend [i ])
467+ for i , initial_state in enumerate (initial_states )
468+ ]
469+ beta = self .make_and_register_variable (
470+ "beta" , shape = () if self .k_endog == 1 else (self .k_endog ,), dtype = floatX
471+ )
397472 if self .use_transformed_parameterization :
398- R = pt .set_subtensor (R [2 , 0 ], beta )
473+ param = beta
474+ else :
475+ param = alpha * beta
476+ if self .k_endog == 1 :
477+ R_list = [pt .set_subtensor (R [2 , 0 ], param ) for R in R_list ]
399478 else :
400- R = pt .set_subtensor (R [2 , 0 ], alpha * beta )
479+ R_list = [ pt .set_subtensor (R [2 , 0 ], param [ i ]) for i , R in enumerate ( R_list )]
401480
402481 # If a trend is requested, we have the following transition equations (omitting the shocks):
403482 # l_t = l_{t-1} + b_{t-1}
404483 # b_t = b_{t-1}
405484 T_base = pt .as_tensor_variable (([0.0 , 0.0 , 0.0 ], [0.0 , 1.0 , 1.0 ], [0.0 , 0.0 , 1.0 ]))
406485
407486 if self .damped_trend :
408- phi = self .make_and_register_variable ("phi" , shape = (), dtype = floatX )
487+ phi = self .make_and_register_variable (
488+ "phi" , shape = () if self .k_endog == 1 else (self .k_endog ,), dtype = floatX
489+ )
409490 # We are always in the case where we have a trend, so we can add the dampening parameter to T_base defined
410491 # in that branch. Transition equations become:
411492 # l_t = l_{t-1} + phi * b_{t-1}
412493 # b_t = phi * b_{t-1}
413- T_base = pt .set_subtensor (T_base [1 :, 2 ], phi )
494+ if self .k_endog > 1 :
495+ T_base = [pt .set_subtensor (T_base [1 :, 2 ], phi [i ]) for i in range (self .k_endog )]
496+ else :
497+ T_base = pt .set_subtensor (T_base [1 :, 2 ], phi )
414498
415- T_components = [T_base ]
499+ T_components = (
500+ [T_base for _ in range (self .k_endog )] if not isinstance (T_base , list ) else T_base
501+ )
416502
417503 if self .seasonal :
418504 initial_seasonal = self .make_and_register_variable (
419- "initial_seasonal" , shape = (self .seasonal_periods ,), dtype = floatX
505+ "initial_seasonal" ,
506+ shape = (self .seasonal_periods ,)
507+ if self .k_endog == 1
508+ else (self .k_endog , self .seasonal_periods ),
509+ dtype = floatX ,
420510 )
421-
422- self .ssm ["initial_state" , 2 + int (self .trend ) :] = initial_seasonal
423-
424- gamma = self .make_and_register_variable ("gamma" , shape = (), dtype = floatX )
425-
426- if self .use_transformed_parameterization :
427- param = gamma
511+ if self .k_endog == 1 :
512+ initial_states = [
513+ pt .set_subtensor (initial_states [0 ][2 + int (self .trend ) :], initial_seasonal )
514+ ]
428515 else :
429- param = (1 - alpha ) * gamma
516+ initial_states = [
517+ pt .set_subtensor (initial_state [2 + int (self .trend ) :], initial_seasonal [i ])
518+ for i , initial_state in enumerate (initial_states )
519+ ]
430520
431- R = pt .set_subtensor (R [2 + int (self .trend ), 0 ], param )
521+ gamma = self .make_and_register_variable (
522+ "gamma" , shape = () if self .k_endog == 1 else (self .k_endog ,), dtype = floatX
523+ )
432524
525+ param = gamma if self .use_transformed_parameterization else (1 - alpha ) * gamma
433526 # Additional adjustment to the R[0, 0] position is required. Start from:
434527 # y_t = l_{t-1} + s_{t-m} + e_t
435528 # l_t = l_{t-1} + alpha * e_t
436529 # s_t = s_{t-m} + gamma * e_t
437530 # Solve for l_{t-1} and s_{t-m} in terms of l_t and s_t, then substitute into the observation equation:
438531 # y_t = l_t + s_t - alpha * e_t - gamma * e_t + e_t --> y_t = l_t + s_t + (1 - alpha - gamma) * e_t
439- R = pt .set_subtensor (R [0 , 0 ], R [0 , 0 ] - param )
532+
533+ if self .k_endog == 1 :
534+ R_list = [pt .set_subtensor (R [2 + int (self .trend ), 0 ], param ) for R in R_list ]
535+ R_list = [pt .set_subtensor (R [0 , 0 ], R [0 , 0 ] - param ) for R in R_list ]
536+
537+ else :
538+ R_list = [
539+ pt .set_subtensor (R [2 + int (self .trend ), 0 ], param [i ])
540+ for i , R in enumerate (R_list )
541+ ]
542+ R_list = [
543+ pt .set_subtensor (R [0 , 0 ], R [0 , 0 ] - param [i ]) for i , R in enumerate (R_list )
544+ ]
440545
441546 # The seasonal component is always going to look like a TimeFrequency structural component, see that
442547 # docstring for more details
443- T_seasonal = pt .eye (self .seasonal_periods , k = - 1 )
444- T_seasonal = pt .set_subtensor (T_seasonal [0 , - 1 ], 1.0 )
445- T_components += [T_seasonal ]
548+ T_seasonals = [pt .eye (self .seasonal_periods , k = - 1 ) for _ in range (self .k_endog )]
549+ T_seasonals = [pt .set_subtensor (T_seasonal [0 , - 1 ], 1.0 ) for T_seasonal in T_seasonals ]
550+
551+ # Organize the components so it goes T1, T_seasonal_1, T2, T_seasonal_2, etc.
552+ T_components = [
553+ matrix [i ] for i in range (self .k_endog ) for matrix in [T_components , T_seasonals ]
554+ ]
446555
447- self .ssm ["selection" ] = R
556+ x0 = pt .concatenate (initial_states , axis = 0 )
557+ R = pt .linalg .block_diag (* R_list )
558+
559+ self .ssm ["initial_state" ] = x0
560+ self .ssm ["selection" ] = pt .specify_shape (R , shape = (self .k_states , self .k_posdef ))
448561
449562 T = pt .linalg .block_diag (* T_components )
450563 self .ssm ["transition" ] = pt .specify_shape (T , (self .k_states , self .k_states ))
451564
452- Z = np .zeros ((self .k_endog , self .k_states ))
453- Z [0 , 0 ] = 1.0 # innovation
454- Z [0 , 1 ] = 1.0 # level
455- if self .seasonal :
456- Z [0 , 2 + int (self .trend )] = 1.0
565+ Zs = [np .zeros ((self .k_endog , self .k_states // self .k_endog )) for _ in range (self .k_endog )]
566+ for i , Z in enumerate (Zs ):
567+ Z [i , 0 ] = 1.0 # innovation
568+ Z [i , 1 ] = 1.0 # level
569+ if self .seasonal :
570+ Z [i , 2 + int (self .trend )] = 1.0
571+
572+ Z = pt .concatenate (Zs , axis = 1 )
573+
457574 self .ssm ["design" ] = Z
458575
459576 # Set up the state covariance matrix
460- state_cov_idx = ("state_cov" , * np .diag_indices (self .k_posdef ))
461- state_cov = self .make_and_register_variable (
462- "sigma_state" , shape = () if self .k_posdef == 1 else (self .k_posdef ,), dtype = floatX
463- )
464- self .ssm [state_cov_idx ] = state_cov ** 2
577+ if self .dense_innovation_covariance :
578+ state_cov = self .make_and_register_variable (
579+ "state_cov" , shape = (self .k_posdef , self .k_posdef ), dtype = floatX
580+ )
581+ self .ssm ["state_cov" ] = state_cov
582+
583+ else :
584+ state_cov_idx = ("state_cov" , * np .diag_indices (self .k_posdef ))
585+ state_cov = self .make_and_register_variable (
586+ "sigma_state" , shape = () if self .k_posdef == 1 else (self .k_posdef ,), dtype = floatX
587+ )
588+ self .ssm [state_cov_idx ] = state_cov ** 2
465589
466590 if self .measurement_error :
467591 obs_cov_idx = ("obs_cov" , * np .diag_indices (self .k_endog ))
0 commit comments