Skip to content

Commit 0b71262

Browse files
axchtensorflower-gardener
authored andcommitted
Tidy up DistributionSlicingTest a little, to prevent spurious failures.
The difficulty is that the test asks to compute log_prob(sample) with validations on, and expects the results to be the same across sliced and non-sliced distributions. This CL removes two sources of error: - The non-sliced distribution, if a TransformedDistribution, will presumably trigger the bijector cache, whereas the sliced one will not. So we force-break the cache to push them into the same code path. - Defer even trying to compute the sliced log_prob until after the Eigen packetization consistency check, because packetization differences could lead to the sliced log_prob failing validation even though the non-sliced version passed. PiperOrigin-RevId: 387830762
1 parent f3b66cf commit 0b71262

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -485,20 +485,23 @@ def _test_slicing(self, data, dist_name, dist):
485485
# slicing the samples from the original.
486486
self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)
487487

488-
# Check that a sliced distribution can compute the log_prob of its own
489-
# samples (up to numerical validation errors).
488+
# Check that the sliced dist's log_prob agrees with slicing the original's
489+
# log_prob.
490+
# First, we make sure that the original sample we have passes the
491+
# original distribution's validations. We break the bijector cache here
492+
# because slicing will break it later too.
490493
with tfp_hps.no_tf_rank_errors():
491494
try:
492-
lp = self.evaluate(dist.log_prob(samples))
495+
lp = self.evaluate(dist.log_prob(
496+
samples + tf.constant(0, dtype=samples.dtype)))
493497
except tf.errors.InvalidArgumentError:
494498
# TODO(b/129271256): d.log_prob(d.sample()) should not fail
495499
# validate_args checks.
496-
# We only tolerate this case for the non-sliced dist.
500+
# `return` here passes the example. If we `hp.assume(False)`
501+
# instead, that would demand from Hypothesis that it find many
502+
# examples where this check (and the next one) passes;
503+
# empirically, it seems to complain that that's too hard.
497504
return
498-
sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))
499-
500-
# Check that the sliced dist's log_prob agrees with slicing the original's
501-
# log_prob.
502505

503506
# This `hp.assume` is suppressing array sizes that cause the sliced and
504507
# non-sliced distribution to follow different Eigen code paths. Those
@@ -518,6 +521,10 @@ def _test_slicing(self, data, dist_name, dist):
518521
hp.note('Non-packetization check {}'.format(all_non_packetized))
519522
hp.assume(all_packetized or all_non_packetized)
520523

524+
# Actually evaluate and test the sliced log_prob
525+
with tfp_hps.no_tf_rank_errors():
526+
sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))
527+
521528
self.assertAllClose(lp[slices], sliced_lp,
522529
atol=SLICING_LOGPROB_ATOL[dist_name],
523530
rtol=SLICING_LOGPROB_RTOL[dist_name])

0 commit comments

Comments
 (0)