Skip to content

Commit 46214b5

Browse files
author
tree-math authors
committed
Merge pull request #7 from cgarciae:main
PiperOrigin-RevId: 418998614
2 parents 83f4360 + f192af0 commit 46214b5

File tree

4 files changed

+64
-36
lines changed

4 files changed

+64
-36
lines changed

tree_math/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
wrap,
2121
unwrap,
2222
)
23-
from tree_math._src.vector import Vector
23+
from tree_math._src.vector import Vector, VectorMixin
2424
import tree_math.numpy
2525

2626
__version__ = '0.1.0'

tree_math/_src/numpy_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_where_all_scalars(self):
4949
actual = tnp.where(True, 1, 2)
5050
self.assertTreeEqual(actual, expected, check_dtypes=False)
5151
with self.assertRaisesRegex(
52-
TypeError, "non-tree_math.Vector argument is not a scalar",
52+
TypeError, "non-tree_math.VectorMixin argument is not a scalar",
5353
):
5454
tnp.where(True, jnp.array([1, 2]), 3)
5555

tree_math/_src/vector.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,18 @@ def g(*args3):
5353

5454
def broadcasting_map(func, *args):
5555
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56-
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, Vector)]
56+
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorMixin)]
5757
func2, vector_args = _argnums_partial(func, args, static_argnums)
5858
for arg in args:
59-
if not isinstance(arg, Vector):
59+
if not isinstance(arg, VectorMixin):
6060
shape = jnp.shape(arg)
6161
if shape:
6262
raise TypeError(
63-
f"non-tree_math.Vector argument is not a scalar: {arg!r}"
63+
f"non-tree_math.VectorMixin argument is not a scalar: {arg!r}"
6464
)
6565
if not vector_args:
6666
return func2() # result is a scalar
67-
_flatten_together(*[arg.tree for arg in vector_args]) # check shapes
67+
_flatten_together(*[arg for arg in vector_args]) # check shapes
6868
return tree_util.tree_map(func2, *vector_args)
6969

7070

@@ -112,43 +112,22 @@ def dot(left, right, *, precision="highest"):
112112
Returns:
113113
Resulting dot product (scalar).
114114
"""
115-
if not isinstance(left, Vector) or not isinstance(right, Vector):
116-
raise TypeError("matmul arguments must both be tree_math.Vector objects")
115+
if not isinstance(left, VectorMixin) or not isinstance(right, VectorMixin):
116+
raise TypeError("matmul arguments must both be tree_math.VectorMixin objects")
117117

118118
def _vector_dot(a, b):
119119
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
120120

121-
(left_values, right_values), _ = _flatten_together(left.tree, right.tree)
121+
(left_values, right_values), _ = _flatten_together(left, right)
122122
parts = map(_vector_dot, left_values, right_values)
123123
return functools.reduce(operator.add, parts)
124124

125-
126-
@tree_util.register_pytree_node_class
127-
class Vector:
128-
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
129-
130-
def __init__(self, tree):
131-
self._tree = tree
132-
133-
@property
134-
def tree(self):
135-
return self._tree
136-
137-
# TODO(shoyer): consider casting to a common dtype?
138-
139-
def __repr__(self):
140-
return f"tree_math.Vector({self._tree!r})"
141-
142-
def tree_flatten(self):
143-
return (self.tree,), None
144-
145-
@classmethod
146-
def tree_unflatten(cls, _, args):
147-
return cls(*args)
125+
class VectorMixin:
126+
"""A mixin class that adds a 1D vector-like behaviour to any custom pytree class."""
148127

149128
@property
150129
def size(self):
151-
values = tree_util.tree_leaves(self.tree)
130+
values = tree_util.tree_leaves(self)
152131
return sum(jnp.size(value) for value in values)
153132

154133
def __len__(self):
@@ -164,7 +143,7 @@ def ndim(self):
164143

165144
@property
166145
def dtype(self):
167-
values = tree_util.tree_leaves(self.tree)
146+
values = tree_util.tree_leaves(self)
168147
return jnp.result_type(*values)
169148

170149
# comparison
@@ -225,3 +204,28 @@ def min(self):
225204
def max(self):
226205
parts = map(jnp.max, tree_util.tree_leaves(self))
227206
return jnp.asarray(list(parts)).max()
207+
208+
@tree_util.register_pytree_node_class
209+
class Vector(VectorMixin):
210+
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
211+
212+
def __init__(self, tree):
213+
self._tree = tree
214+
215+
@property
216+
def tree(self):
217+
return self._tree
218+
219+
# TODO(shoyer): consider casting to a common dtype?
220+
221+
def __repr__(self):
222+
return f"tree_math.Vector({self._tree!r})"
223+
224+
def tree_flatten(self):
225+
return (self._tree,), None
226+
227+
@classmethod
228+
def tree_unflatten(cls, _, args):
229+
return cls(*args)
230+
231+

tree_math/_src/vector_test.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_arithmetic_with_scalar(self):
5858
self.assertTreeEqual(vector + 1, expected, check_dtypes=True)
5959
self.assertTreeEqual(1 + vector, expected, check_dtypes=True)
6060
with self.assertRaisesRegex(
61-
TypeError, "non-tree_math.Vector argument is not a scalar",
61+
TypeError, "non-tree_math.VectorMixin argument is not a scalar",
6262
):
6363
vector + jnp.ones((3,)) # pylint: disable=expression-not-assigned
6464

@@ -123,7 +123,7 @@ def test_matmul(self):
123123
self.assertAllClose(actual, expected)
124124

125125
with self.assertRaisesRegex(
126-
TypeError, "matmul arguments must both be tree_math.Vector objects",
126+
TypeError, "matmul arguments must both be tree_math.VectorMixin objects",
127127
):
128128
vector1 @ jnp.ones((7,)) # pylint: disable=expression-not-assigned
129129

@@ -148,6 +148,30 @@ def test_sum_mean_min_max(self):
148148
self.assertTreeEqual(vector.min(), 1, check_dtypes=False)
149149
self.assertTreeEqual(vector.max(), 4, check_dtypes=False)
150150

151+
def test_custom_class(self):
152+
153+
@tree_util.register_pytree_node_class
154+
class CustomVector(tm.VectorMixin):
155+
156+
def __init__(self, a: int, b: float):
157+
self.a = a
158+
self.b = b
159+
160+
def tree_flatten(self):
161+
return (self.a, self.b), None
162+
163+
@classmethod
164+
def tree_unflatten(cls, _, args):
165+
return cls(*args)
166+
167+
v1 = CustomVector(1, 2.0)
168+
v2 = v1 + 3
169+
self.assertTreeEqual(v2, CustomVector(4, 5.0), check_dtypes=True)
170+
171+
v3 = v2 + v1
172+
self.assertTreeEqual(v3, CustomVector(5, 7.0), check_dtypes=True)
173+
174+
151175

152176
if __name__ == "__main__":
153177
absltest.main()

0 commit comments

Comments
 (0)