Skip to content

Commit c90e961

Browse files
committed
Check for statically known rank
Parametrize tests
1 parent 169f7f5 commit c90e961

File tree

2 files changed

+63
-47
lines changed

2 files changed

+63
-47
lines changed

tensorflow_probability/python/stats/sample_stats.py

+6
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,12 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
915915
low_indices = high_indices // 2
916916
else:
917917
low_indices = tf.convert_to_tensor(low_indices)
918+
919+
indices_rank = tf.get_static_value(ps.rank(low_indices))
920+
x_rank = tf.get_static_value(ps.rank(x))
921+
if indices_rank is None or x_rank is None:
922+
raise ValueError("`indices` and `x` ranks must be statically known.")
923+
918924
# Broadcast indices together.
919925
high_indices = high_indices + tf.zeros_like(low_indices)
920926
low_indices = low_indices + tf.zeros_like(high_indices)

tensorflow_probability/python/stats/sample_stats_test.py

+57-47
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
"""Tests for Sample Stats Ops."""
1616

1717
# Dependency imports
18-
import functools
18+
import itertools
19+
1920
import numpy as np
2021
import tensorflow.compat.v1 as tf1
2122
import tensorflow.compat.v2 as tf
23+
from absl.testing import parameterized
24+
from tensorflow.python.framework.errors_impl import InvalidArgumentError
25+
2226
from tensorflow_probability.python.internal import test_util
2327
from tensorflow_probability.python.stats import sample_stats
2428

@@ -721,7 +725,8 @@ def apply_func(vector, l, h):
721725
out = np.transpose(t_out, axes=dims)
722726
return out
723727

724-
def check_gaussian_windowed(self, shape, indice_shape, axis,
728+
729+
def check_gaussian_windowed_func(self, shape, indice_shape, axis,
725730
window_func, np_func):
726731
stat_shape = np.array(shape).astype(np.int32)
727732
stat_shape[axis] = 1
@@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
753758
def _make_dynamic_shape(self, x):
754759
return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))
755760

756-
def check_windowed(self, func, numpy_func):
757-
check_fn = functools.partial(self.check_gaussian_windowed,
758-
window_func=func, np_func=numpy_func)
759-
check_fn((64, 4, 8), (128, 1, 1), axis=0)
760-
check_fn((64, 4, 8), (32, 1, 1), axis=0)
761-
check_fn((64, 4, 8), (32, 4, 1), axis=0)
762-
check_fn((64, 4, 8), (32, 4, 8), axis=0)
763-
check_fn((64, 4, 8), (64, 4, 8), axis=0)
764-
check_fn((64, 4, 8), (128, 1), axis=0)
765-
check_fn((64, 4, 8), (32,), axis=0)
766-
check_fn((64, 4, 8), (32, 4), axis=0)
767-
768-
check_fn((64, 4, 8), (64, 64, 1), axis=1)
769-
check_fn((64, 4, 8), (1, 64, 1), axis=1)
770-
check_fn((64, 4, 8), (64, 2, 8), axis=1)
771-
check_fn((64, 4, 8), (64, 4, 8), axis=1)
772-
check_fn((64, 4, 8), (16,), axis=1)
773-
check_fn((64, 4, 8), (1, 64), axis=1)
774-
775-
check_fn((64, 4, 8), (64, 4, 64), axis=2)
776-
check_fn((64, 4, 8), (1, 1, 64), axis=2)
777-
check_fn((64, 4, 8), (64, 4, 4), axis=2)
778-
check_fn((64, 4, 8), (1, 1, 4), axis=2)
779-
check_fn((64, 4, 8), (64, 4, 8), axis=2)
780-
check_fn((64, 4, 8), (16,), axis=2)
781-
check_fn((64, 4, 8), (1, 4), axis=2)
782-
check_fn((64, 4, 8), (64, 4), axis=2)
783-
784-
with self.assertRaises(Exception):
785-
# Non broadcastable shapes
786-
check_fn((64, 4, 8), (4, 1, 4), axis=2)
787-
788-
with self.assertRaises(Exception):
789-
# Non broadcastable shapes
790-
check_fn((64, 4, 8), (2, 4), axis=2)
791-
792-
def test_windowed_mean(self):
793-
self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean)
794-
795-
def test_windowed_mean_graph(self):
796-
func = tf.function(sample_stats.windowed_mean)
797-
self.check_windowed(func=func, numpy_func=np.mean)
798-
799-
def test_windowed_variance(self):
800-
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)
761+
@parameterized.named_parameters(*[(
762+
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
763+
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
764+
itertools.product([(64, 4, 8), ],
765+
[((128, 1, 1), 0),
766+
((32, 1, 1), 0),
767+
((32, 4, 1), 0),
768+
((32, 4, 8), 0),
769+
((64, 4, 8), 0),
770+
((128, 1), 0),
771+
((32,), 0),
772+
((32, 4), 0),
773+
774+
((64, 64, 1), 1),
775+
((1, 64, 1), 1),
776+
((64, 2, 8), 1),
777+
((64, 4, 8), 1),
778+
((16,), 1),
779+
((1, 64), 1),
780+
781+
((64, 4, 64), 2),
782+
((1, 1, 64), 2),
783+
((64, 4, 4), 2),
784+
((1, 1, 4), 2),
785+
((64, 4, 8), 2),
786+
((16,), 2),
787+
((1, 4), 2),
788+
((64, 4), 2)],
789+
[
790+
(sample_stats.windowed_mean, np.mean),
791+
(sample_stats.windowed_variance, np.var)
792+
])])
793+
def test_windowed(self, shape, indice_shape, axis, window_func, np_func):
794+
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
795+
np_func)
796+
797+
798+
@parameterized.named_parameters(*[(
799+
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
800+
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
801+
itertools.product([(64, 4, 8), ],
802+
[((4, 1, 4), 2), ((2, 4), 2)],
803+
[(sample_stats.windowed_mean, np.mean),
804+
(sample_stats.windowed_variance, np.var)])])
805+
def test_non_broadcastable_shapes(self, shape, indice_shape, axis,
806+
window_func, np_func):
807+
with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError),
808+
'^shape mismatch|Incompatible shapes'):
809+
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
810+
np_func)
801811

802812

803813
@test_util.test_all_tf_execution_regimes

0 commit comments

Comments
 (0)