|
15 | 15 | """Tests for Sample Stats Ops."""
|
16 | 16 |
|
17 | 17 | # Dependency imports
|
18 |
| -import functools |
| 18 | +import itertools |
| 19 | + |
19 | 20 | import numpy as np
|
20 | 21 | import tensorflow.compat.v1 as tf1
|
21 | 22 | import tensorflow.compat.v2 as tf
|
| 23 | +from absl.testing import parameterized |
| 24 | +from tensorflow.python.framework.errors_impl import InvalidArgumentError |
| 25 | + |
22 | 26 | from tensorflow_probability.python.internal import test_util
|
23 | 27 | from tensorflow_probability.python.stats import sample_stats
|
24 | 28 |
|
@@ -721,7 +725,8 @@ def apply_func(vector, l, h):
|
721 | 725 | out = np.transpose(t_out, axes=dims)
|
722 | 726 | return out
|
723 | 727 |
|
724 |
| - def check_gaussian_windowed(self, shape, indice_shape, axis, |
| 728 | + |
| 729 | + def check_gaussian_windowed_func(self, shape, indice_shape, axis, |
725 | 730 | window_func, np_func):
|
726 | 731 | stat_shape = np.array(shape).astype(np.int32)
|
727 | 732 | stat_shape[axis] = 1
|
@@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
|
753 | 758 | def _make_dynamic_shape(self, x):
|
754 | 759 | return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))
|
755 | 760 |
|
756 |
| - def check_windowed(self, func, numpy_func): |
757 |
| - check_fn = functools.partial(self.check_gaussian_windowed, |
758 |
| - window_func=func, np_func=numpy_func) |
759 |
| - check_fn((64, 4, 8), (128, 1, 1), axis=0) |
760 |
| - check_fn((64, 4, 8), (32, 1, 1), axis=0) |
761 |
| - check_fn((64, 4, 8), (32, 4, 1), axis=0) |
762 |
| - check_fn((64, 4, 8), (32, 4, 8), axis=0) |
763 |
| - check_fn((64, 4, 8), (64, 4, 8), axis=0) |
764 |
| - check_fn((64, 4, 8), (128, 1), axis=0) |
765 |
| - check_fn((64, 4, 8), (32,), axis=0) |
766 |
| - check_fn((64, 4, 8), (32, 4), axis=0) |
767 |
| - |
768 |
| - check_fn((64, 4, 8), (64, 64, 1), axis=1) |
769 |
| - check_fn((64, 4, 8), (1, 64, 1), axis=1) |
770 |
| - check_fn((64, 4, 8), (64, 2, 8), axis=1) |
771 |
| - check_fn((64, 4, 8), (64, 4, 8), axis=1) |
772 |
| - check_fn((64, 4, 8), (16,), axis=1) |
773 |
| - check_fn((64, 4, 8), (1, 64), axis=1) |
774 |
| - |
775 |
| - check_fn((64, 4, 8), (64, 4, 64), axis=2) |
776 |
| - check_fn((64, 4, 8), (1, 1, 64), axis=2) |
777 |
| - check_fn((64, 4, 8), (64, 4, 4), axis=2) |
778 |
| - check_fn((64, 4, 8), (1, 1, 4), axis=2) |
779 |
| - check_fn((64, 4, 8), (64, 4, 8), axis=2) |
780 |
| - check_fn((64, 4, 8), (16,), axis=2) |
781 |
| - check_fn((64, 4, 8), (1, 4), axis=2) |
782 |
| - check_fn((64, 4, 8), (64, 4), axis=2) |
783 |
| - |
784 |
| - with self.assertRaises(Exception): |
785 |
| - # Non broadcastable shapes |
786 |
| - check_fn((64, 4, 8), (4, 1, 4), axis=2) |
787 |
| - |
788 |
| - with self.assertRaises(Exception): |
789 |
| - # Non broadcastable shapes |
790 |
| - check_fn((64, 4, 8), (2, 4), axis=2) |
791 |
| - |
792 |
| - def test_windowed_mean(self): |
793 |
| - self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean) |
794 |
| - |
795 |
| - def test_windowed_mean_graph(self): |
796 |
| - func = tf.function(sample_stats.windowed_mean) |
797 |
| - self.check_windowed(func=func, numpy_func=np.mean) |
798 |
| - |
799 |
| - def test_windowed_variance(self): |
800 |
| - self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) |
| 761 | + @parameterized.named_parameters(*[( |
| 762 | + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, |
| 763 | + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in |
| 764 | + itertools.product([(64, 4, 8), ], |
| 765 | + [((128, 1, 1), 0), |
| 766 | + ((32, 1, 1), 0), |
| 767 | + ((32, 4, 1), 0), |
| 768 | + ((32, 4, 8), 0), |
| 769 | + ((64, 4, 8), 0), |
| 770 | + ((128, 1), 0), |
| 771 | + ((32,), 0), |
| 772 | + ((32, 4), 0), |
| 773 | +
|
| 774 | + ((64, 64, 1), 1), |
| 775 | + ((1, 64, 1), 1), |
| 776 | + ((64, 2, 8), 1), |
| 777 | + ((64, 4, 8), 1), |
| 778 | + ((16,), 1), |
| 779 | + ((1, 64), 1), |
| 780 | +
|
| 781 | + ((64, 4, 64), 2), |
| 782 | + ((1, 1, 64), 2), |
| 783 | + ((64, 4, 4), 2), |
| 784 | + ((1, 1, 4), 2), |
| 785 | + ((64, 4, 8), 2), |
| 786 | + ((16,), 2), |
| 787 | + ((1, 4), 2), |
| 788 | + ((64, 4), 2)], |
| 789 | + [ |
| 790 | + (sample_stats.windowed_mean, np.mean), |
| 791 | + (sample_stats.windowed_variance, np.var) |
| 792 | + ])]) |
| 793 | + def test_windowed(self, shape, indice_shape, axis, window_func, np_func): |
| 794 | + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, |
| 795 | + np_func) |
| 796 | + |
| 797 | + |
| 798 | + @parameterized.named_parameters(*[( |
| 799 | + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, |
| 800 | + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in |
| 801 | + itertools.product([(64, 4, 8), ], |
| 802 | + [((4, 1, 4), 2), ((2, 4), 2)], |
| 803 | + [(sample_stats.windowed_mean, np.mean), |
| 804 | + (sample_stats.windowed_variance, np.var)])]) |
| 805 | + def test_non_broadcastable_shapes(self, shape, indice_shape, axis, |
| 806 | + window_func, np_func): |
| 807 | + with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), |
| 808 | + '^shape mismatch|Incompatible shapes'): |
| 809 | + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, |
| 810 | + np_func) |
801 | 811 |
|
802 | 812 |
|
803 | 813 | @test_util.test_all_tf_execution_regimes
|
|
0 commit comments