diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 77648092fb..68124f88bd 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -13,6 +13,7 @@ ### Maintenance - Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508) - Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559)) +- SMC: stabilize covariance matrix [3573](https://github.com/pymc-devs/pymc3/pull/3573) ## PyMC3 3.7 (May 29 2019) diff --git a/pymc3/step_methods/smc_utils.py b/pymc3/step_methods/smc_utils.py index 08bad0463b..9e5bea1c7b 100644 --- a/pymc3/step_methods/smc_utils.py +++ b/pymc3/step_methods/smc_utils.py @@ -44,9 +44,11 @@ def _calc_covariance(posterior, weights): Calculate trace covariance matrix based on importance weights. """ cov = np.cov(posterior, aweights=weights.ravel(), bias=False, rowvar=0) + cov = np.atleast_2d(cov) + cov += 1e-6 * np.eye(cov.shape[0]) if np.isnan(cov).any() or np.isinf(cov).any(): raise ValueError('Sample covariances not valid! Likely "draws" is too small!') - return np.atleast_2d(cov) + return cov def _tune(acc_rate, proposed, step): diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index d67c0ef1d2..e96a4c0de8 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -247,7 +247,7 @@ class TestStepMethods: # yield test doesn't work subclassing object 0.81375848, 0.81375848, 0.81375848, - 1.91238675 + 1.91238675, ] ), Metropolis: np.array( @@ -460,206 +460,206 @@ class TestStepMethods: # yield test doesn't work subclassing object ), SMC: np.array( [ - 0.85565708, - -0.20703928, - 0.60432641, - 0.82409514, - 0.66956453, - 1.8112792, - 0.50997512, - 0.01190834, - 0.11877327, - 1.04616417, - 0.35542005, - 0.97711504, - 1.08273637, - 0.12254235, - -0.2125738, - 1.90683646, - 0.76584362, - 1.61601695, - 1.26496703, - 0.72605833, - 0.27710235, - 0.59466026, - 1.48847988, - 1.48383337, - 0.85487674, - 0.40339276, - 1.11378016, - -0.01154067, - -0.24933179, - 0.04855045, - 0.44408864, - 1.07009531, - 0.71832419, - -0.02224457, - 0.15732459, - 0.74732395, - -0.55976656, - 1.83476589, - 1.13464886, - 1.04477015, - -0.8829041, - 0.68610315, - -0.51600577, - 1.06577114, - 0.72533608, - 0.26181788, - 0.37045769, - 0.49110905, - 0.95187097, - 0.57052924, - 1.18390833, - -0.28470992, - 0.5143004, - 0.36340091, - 0.26524336, - 0.91352757, - -0.16906895, - 0.02671801, - -0.62018961, - 0.13845522, - 0.69578158, - 0.82213017, - 0.95565383, - 0.57201004, - 0.66751355, - 0.74662892, - -0.18802906, - -0.16424007, - 0.67661192, - 0.986151, - 1.11037246, - 0.53367584, - 0.81646283, - 0.69093199, - 1.30967566, - 0.58455688, - -0.10754191, - -0.66843656, - 0.61473792, - 0.11205418, - 1.50795552, - 1.61304845, - 0.97329021, - 0.80782701, - 1.83144593, - 0.34256428, - 0.49090154, - 1.85297793, - 0.44832949, - 1.35766813, - 0.48916332, - 0.410038, - -0.69870943, - 0.06616812, - -0.17685433, - -0.0487383, - 1.92862324, - 0.47539572, - 1.19401681, - 0.3670901, - 2.11504383, - 1.16863035, - 0.74908135, - 0.90147245, - 0.6291441, - 0.96889664, - 0.93871964, - 0.74575969, - 0.06810336, - 0.45469347, - 0.29787682, - 0.73557892, - -0.3388827, - -0.0991328, - 1.12325585, - 0.87397644, - -1.14737408, - -0.78658091, - 0.67716022, - 0.20961362, - 0.11759984, - 0.72748548, - -0.29959649, - -0.09436443, - 0.42100225, - 0.04656646, - 1.21211555, - 0.04060845, - 1.38031545, - 0.58429818, - 0.33843531, - 0.82207289, - 0.96509587, - 1.00370899, - 1.23734919, - -0.01960951, - 0.7721088, - 0.04627471, - -0.62058523, - 0.21093904, - -0.15935501, - 0.83237845, - 0.10157936, - -0.45885173, - 1.26206955, - 1.07601436, - 1.23736132, - 0.28618097, - -0.14328022, - -0.13158901, - 0.74308368, - 0.26291343, - 0.17504558, - 0.55601578, - 1.46900503, - 0.65131007, - 0.89596352, - 0.32536798, - -0.25504495, - 0.07563569, - 1.48775514, - 0.28519783, - 0.58513482, - -0.63672688, - 1.59324146, - 0.53826815, - 0.41792749, - 0.76583018, - 0.87290581, - 0.89110704, - 0.27282461, - -0.20300455, - 1.01058543, - 0.68072852, - -0.21073928, - 1.19114065, - 0.6372328, - 0.33444015, - 1.05599084, - 0.78372828, - 1.0127235, - -0.19460124, - 1.31807913, - 0.58658129, - -0.34218648, - 0.68725616, - 0.37484537, - 2.48875271, - -0.06424102, - 0.22162396, - -0.21623175, - 0.25998439, - 0.37801803, - -0.51312636, - -0.35024508, - 1.90460979, - 0.02214471, - -0.59132265, - 0.42870423, - 0.88951751, + 0.85565848, + -0.2070422, + 0.60432617, + 0.82409693, + 0.66956559, + 1.81128223, + 0.5099755, + 0.0119065, + 0.11877237, + 1.04616407, + 0.35541975, + 0.97711646, + 1.08273746, + 0.12254112, + -0.21257513, + 1.90683915, + 0.76584417, + 1.61601906, + 1.26496997, + 0.72605814, + 0.27710155, + 0.59465936, + 1.48848202, + 1.48383457, + 0.85487729, + 0.40339297, + 1.11378062, + -0.01154052, + -0.24933346, + 0.04855092, + 0.44408811, + 1.07009768, + 0.71832534, + -0.02224531, + 0.15732427, + 0.7473228, + -0.55976844, + 1.83476852, + 1.13464918, + 1.04477006, + -0.8829072, + 0.68610441, + -0.51600679, + 1.06577287, + 0.72533541, + 0.26181682, + 0.37045784, + 0.49110896, + 0.95187099, + 0.57052884, + 1.18390954, + -0.28471075, + 0.51430074, + 0.36340121, + 0.26524266, + 0.91352896, + -0.16906962, + 0.02671763, + -0.62019011, + 0.13845477, + 0.69578153, + 0.82213032, + 0.95565471, + 0.57200968, + 0.66751333, + 0.74663059, + -0.18802928, + -0.16424154, + 0.67661238, + 0.9861513, + 1.11037445, + 0.53367436, + 0.81646116, + 0.690932, + 1.30967756, + 0.58455721, + -0.10754287, + -0.6684397, + 0.61473599, + 0.11205459, + 1.50795626, + 1.61304945, + 0.97329075, + 0.80782601, + 1.83144756, + 0.34256431, + 0.4909023, + 1.85297991, + 0.44832968, + 1.35766865, + 0.48916414, + 0.41003811, + -0.69870992, + 0.06616797, + -0.17685457, + -0.04873934, + 1.92862499, + 0.47539711, + 1.19401841, + 0.36708951, + 2.11504567, + 1.1686311, + 0.74908099, + 0.90147251, + 0.6291452, + 0.96889866, + 0.93871978, + 0.74575847, + 0.06810142, + 0.45469276, + 0.2978768, + 0.73557954, + -0.33888277, + -0.09913398, + 1.12325616, + 0.87397745, + -1.14737571, + -0.78658184, + 0.67716005, + 0.20961373, + 0.11759896, + 0.72748602, + -0.29959812, + -0.09436507, + 0.42100139, + 0.0465658, + 1.21211627, + 0.0406079, + 1.38031654, + 0.58429982, + 0.33843332, + 0.82207419, + 0.9650973, + 1.00370894, + 1.23735049, + -0.01960991, + 0.77210838, + 0.04627416, + -0.62058637, + 0.21093913, + -0.15935478, + 0.83237714, + 0.10157911, + -0.45885337, + 1.26207038, + 1.07601429, + 1.23736173, + 0.28618205, + -0.143281, + -0.13159008, + 0.74308471, + 0.26291269, + 0.17504574, + 0.55601508, + 1.46900656, + 0.65130981, + 0.89596543, + 0.32536767, + -0.25504632, + 0.07563599, + 1.48775644, + 0.28519708, + 0.58513646, + -0.63673033, + 1.5932429, + 0.53826754, + 0.41792748, + 0.7658319, + 0.87290603, + 0.89110888, + 0.27282434, + -0.20300504, + 1.01058742, + 0.68072965, + -0.21073937, + 1.19114243, + 0.63723316, + 0.3344412, + 1.05599174, + 0.78372725, + 1.01272241, + -0.19460072, + 1.3180811, + 0.58658171, + -0.34218688, + 0.68725498, + 0.37484577, + 2.48875469, + -0.06424035, + 0.22162324, + -0.21623218, + 0.25998442, + 0.37801781, + -0.51312723, + -0.35024653, + 1.90461235, + 0.02214488, + -0.59132457, + 0.42870476, + 0.88951825, ] ), }