From 4f8251a0c299899ca8f9bf95538da12f1878243e Mon Sep 17 00:00:00 2001 From: bjp Date: Fri, 23 Aug 2024 10:54:02 -0700 Subject: [PATCH] NumPy 2.0 related fixes. PiperOrigin-RevId: 666851860 --- tensorflow_probability/python/experimental/mcmc/BUILD | 2 +- .../python/experimental/mcmc/particle_filter_test.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 4056b28d6f..38164e1e11 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -556,7 +556,7 @@ multi_substrate_py_test( size = "large", srcs = ["particle_filter_test.py"], numpy_tags = ["notap"], - shard_count = 3, + shard_count = 5, deps = [ ":particle_filter", ":sequential_monte_carlo_kernel", diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 7b6ae37508..f2aece5917 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -476,19 +476,21 @@ def test_proposal_weights_dont_affect_marginal_likelihood(self): _, lps = self.evaluate( particle_filter.infer_trajectories( observation, - initial_state_prior=normal.Normal(loc=0., scale=1.), + initial_state_prior=normal.Normal(loc=self.dtype(0.), scale=1.), transition_fn=lambda _, x: normal.Normal(loc=x, scale=1.), observation_fn=lambda _, x: normal.Normal(loc=x, scale=1.), - initial_state_proposal=normal.Normal(loc=0., scale=5.), + initial_state_proposal=normal.Normal(loc=self.dtype(0.), scale=5.), proposal_fn=lambda _, x: normal.Normal(loc=x, scale=5.), num_particles=2048, seed=test_util.test_seed())) # Compare marginal likelihood against that # from the true (jointly normal) marginal distribution. - y1_marginal_dist = normal.Normal(loc=0., scale=np.sqrt(1. + 1.)) + y1_marginal_dist = normal.Normal(loc=0., + scale=np.sqrt(1. + 1.).astype(self.dtype)) y2_conditional_dist = ( - lambda y1: normal.Normal(loc=y1 / 2., scale=np.sqrt(5. / 2.))) + lambda y1: normal.Normal( + loc=y1 / self.dtype(2.), scale=np.sqrt(5. / 2.).astype(self.dtype))) true_lps = tf.stack( [y1_marginal_dist.log_prob(observation[0]), y2_conditional_dist(observation[0]).log_prob(observation[1])],