Skip to content

Commit 9797ea2

Browse files
authored
Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct. (jax-ml#2206)
1 parent fb7e48f commit 9797ea2

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

jax/api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,30 @@ def __init__(self, shape, dtype):
19551955
self.shape = shape
19561956
self.dtype = dtype
19571957

1958+
size = property(lambda self: onp.prod(self.shape))
1959+
ndim = property(lambda self: len(self.shape))
1960+
1961+
def __len__(self):
1962+
try:
1963+
return self.shape[0]
1964+
except IndexError:
1965+
raise TypeError("len() of unsized object") # same as numpy error
1966+
1967+
def __repr__(self):
1968+
return "{}(shape={}, dtype={})".format(
1969+
type(self).__name__, self.shape, self.dtype.dtype.name)
1970+
1971+
__str__ = __repr__
1972+
1973+
def __eq__(self, other):
1974+
if not isinstance(other, ShapeDtypeStruct):
1975+
return False
1976+
else:
1977+
return (other.shape, other.dtype) == (self.shape, self.dtype)
1978+
1979+
def __hash__(self):
1980+
return hash((self.shape, self.dtype))
1981+
19581982
def eval_shape(fun, *args, **kwargs):
19591983
"""Compute the shape/dtype of ``fun(*args, **kwargs)`` without any FLOPs.
19601984

tests/api_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,26 @@ def f(pt):
891891
g = api.grad(f)(pt)
892892
self.assertIsInstance(pt, ZeroPoint)
893893

894+
@parameterized.parameters(1, 2, 3)
895+
def test_shape_dtype_struct(self, i):
896+
s = api.ShapeDtypeStruct(shape=(i, 2, 3), dtype=np.float32)
897+
self.assertEqual(s.shape, (i, 2, 3))
898+
self.assertEqual(s.dtype, np.float32)
899+
self.assertEqual(s.ndim, 3)
900+
self.assertEqual(s.size, i * 2 * 3)
901+
self.assertLen(s, i)
902+
for f in (str, repr):
903+
self.assertEqual(
904+
f(s), "ShapeDtypeStruct(shape=({}, 2, 3), dtype=float32)".format(i))
905+
906+
def test_shape_dtype_struct_scalar(self):
907+
s = api.ShapeDtypeStruct(shape=(), dtype=np.float32)
908+
self.assertEmpty(s.shape)
909+
self.assertEqual(s.size, 1)
910+
self.assertEqual(s.ndim, 0)
911+
with self.assertRaisesRegex(TypeError, "len[(][)] of unsized object"):
912+
_ = len(s)
913+
894914
def test_eval_shape(self):
895915
def fun(x, y):
896916
return np.tanh(np.dot(x, y) + 3.)

0 commit comments

Comments
 (0)