1010 StandardFilter ,
1111 UnivariateFilter ,
1212)
13- from pymc_extras .statespace .filters .kalman_filter import BaseFilter
13+ from pymc_extras .statespace .filters .kalman_filter import BaseFilter , SquareRootFilter
1414from tests .statespace .utilities .shared_fixtures import ( # pylint: disable=unused-import
1515 rng ,
1616)
3030RTOL = 1e-6 if floatX .endswith ("64" ) else 1e-3
3131
3232standard_inout = initialize_filter (StandardFilter ())
33- # cholesky_inout = initialize_filter(CholeskyFilter ())
33+ cholesky_inout = initialize_filter (SquareRootFilter ())
3434univariate_inout = initialize_filter (UnivariateFilter ())
3535
3636f_standard = pytensor .function (* standard_inout , on_unused_input = "ignore" )
37- # f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
37+ f_cholesky = pytensor .function (* cholesky_inout , on_unused_input = "ignore" )
3838f_univariate = pytensor .function (* univariate_inout , on_unused_input = "ignore" )
3939
40- filter_funcs = [f_standard , f_univariate ]
40+ filter_funcs = [f_standard , f_cholesky , f_univariate ]
4141
4242filter_names = [
4343 "StandardFilter" ,
44+ "CholeskyFilter" ,
4445 "UnivariateFilter" ,
4546]
4647
@@ -229,8 +230,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
229230@pytest .mark .skipif (floatX == "float32" , reason = "Tests are too sensitive for float32" )
230231def test_filters_match_statsmodel_output (filter_func , filter_name , n_missing , rng ):
231232 fit_sm_mod , [data , a0 , P0 , c , d , T , Z , R , H , Q ] = nile_test_test_helper (rng , n_missing )
232- # if filter_name == "CholeskyFilter":
233- # P0 = np.linalg.cholesky(P0)
233+ if filter_name == "CholeskyFilter" :
234+ P0 = np .linalg .cholesky (P0 )
234235 inputs = [data , a0 , P0 , c , d , T , Z , R , H , Q ]
235236 outputs = filter_func (* inputs )
236237
@@ -278,8 +279,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
278279 pytest .skip ("Univariate filter not stable at half precision without measurement error" )
279280
280281 fit_sm_mod , [data , a0 , P0 , c , d , T , Z , R , H , Q ] = nile_test_test_helper (rng , n_missing )
281- # if filter_name == "CholeskyFilter":
282- # P0 = np.linalg.cholesky(P0)
282+ if filter_name == "CholeskyFilter" :
283+ P0 = np .linalg .cholesky (P0 )
283284
284285 H *= int (obs_noise )
285286 inputs = [data , a0 , P0 , c , d , T , Z , R , H , Q ]
@@ -301,7 +302,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
301302
302303@pytest .mark .parametrize (
303304 "filter" ,
304- [StandardFilter ],
305+ [StandardFilter , SquareRootFilter ],
305306 ids = ["standard" ],
306307)
307308def test_kalman_filter_jax (filter ):
0 commit comments