Skip to content

Commit 541c778

Browse files
pnorgaardtree-math authors
authored and
tree-math authors
committed
Create struct decorator for constructing dataclasses that are also pytrees and implement tree_math.
PiperOrigin-RevId: 446848773
1 parent 46214b5 commit 541c778

File tree

5 files changed

+172
-8
lines changed

5 files changed

+172
-8
lines changed

tree_math/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
wrap,
2121
unwrap,
2222
)
23+
from tree_math._src.structs import struct
2324
from tree_math._src.vector import Vector, VectorMixin
2425
import tree_math.numpy
2526

tree_math/_src/structs.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Helpers for constructing data classes that are JAX and tree-math enabled."""
2+
3+
import dataclasses
4+
import jax
5+
from tree_math._src.vector import VectorMixin
6+
7+
8+
def struct(cls):
9+
"""Class decorator that enables JAX function transforms as well as tree math.
10+
11+
Decorating a class with `@struct` makes it a dataclass that is compatible
12+
with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated
13+
class is also a valid pytree, making it compatible with JAX function
14+
transformations such as `jit` and `grad`.
15+
16+
Example usage:
17+
18+
```
19+
@struct
20+
class Point:
21+
x: float
22+
y: float
23+
24+
a = Point(0., 1.)
25+
b = Point(1., 1.)
26+
27+
a + 3 * b # Point(3., 4.)
28+
29+
def norm_squared(pt):
30+
return pt.x**2 + pt.y**2
31+
32+
jax.jit(jax.grad(norm))(b) # Point(2., 2.)
33+
```
34+
35+
Args:
36+
cls: a class, written with the same syntax as a `dataclass`.
37+
38+
Returns:
39+
A wrapped version of `cls` that implements dataclass, pytree and tree_math
40+
functionality.
41+
"""
42+
@property
43+
def fields(self):
44+
return dataclasses.fields(self)
45+
46+
def asdict(self):
47+
return {field.name: getattr(self, field.name) for field in self.fields}
48+
49+
def astuple(self):
50+
return tuple(getattr(self, field.name) for field in self.fields)
51+
52+
def tree_flatten(self):
53+
return self.astuple(), None
54+
55+
@classmethod
56+
def tree_unflatten(cls, _, children):
57+
return cls(*children)
58+
59+
cls_as_struct = type(cls.__name__,
60+
(VectorMixin, dataclasses.dataclass(cls)),
61+
{'fields': fields,
62+
'asdict': asdict,
63+
'astuple': astuple,
64+
'tree_flatten': tree_flatten,
65+
'tree_unflatten': tree_unflatten})
66+
return jax.tree_util.register_pytree_node_class(cls_as_struct)

tree_math/_src/structs_test.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Tests for global_circulation.structs."""
2+
3+
from typing import Union
4+
5+
from absl.testing import absltest
6+
from absl.testing import parameterized
7+
8+
import jax
9+
import jax.numpy as jnp
10+
import numpy as np
11+
import tree_math
12+
13+
ArrayLike = Union[jnp.ndarray, np.ndarray, float]
14+
15+
16+
@tree_math.struct
17+
class TestStruct:
18+
a: ArrayLike
19+
b: ArrayLike
20+
21+
22+
class StructsTest(parameterized.TestCase):
23+
24+
@parameterized.named_parameters(
25+
dict(testcase_name='Scalars', x=TestStruct(1., 2.)),
26+
dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5])))
27+
)
28+
def testFlattenUnflatten(self, x):
29+
leaves, structure = jax.tree_flatten(x)
30+
y = jax.tree_unflatten(structure, leaves)
31+
np.testing.assert_allclose(x.a, y.a)
32+
np.testing.assert_allclose(x.b, y.b)
33+
34+
@parameterized.named_parameters(
35+
dict(testcase_name='Addition',
36+
x=TestStruct(1., 2.),
37+
y=TestStruct(3., 4.),
38+
operation=lambda x, y: x + y,
39+
expected=TestStruct(4., 6.)),
40+
dict(testcase_name='Subtraction',
41+
x=TestStruct(1., 2.),
42+
y=TestStruct(3., 4.),
43+
operation=lambda x, y: x - y,
44+
expected=TestStruct(-2., -2.)),
45+
dict(testcase_name='Multiplication',
46+
x=TestStruct(1., 2.),
47+
y=TestStruct(3., 4.),
48+
operation=lambda x, y: x * y,
49+
expected=TestStruct(3., 8.)),
50+
dict(testcase_name='Division',
51+
x=TestStruct(1., 2.),
52+
y=TestStruct(3., 4.),
53+
operation=lambda x, y: x / y,
54+
expected=TestStruct(1 / 3, 2 / 4)),
55+
)
56+
def testInfixOperator(self, x, y, operation, expected):
57+
actual = operation(x, y)
58+
np.testing.assert_allclose(expected.a, actual.a)
59+
np.testing.assert_allclose(expected.b, actual.b)
60+
61+
@parameterized.named_parameters(
62+
dict(testcase_name='Product',
63+
x=TestStruct(1., 2.),
64+
operation=lambda x: x.a * x.b,
65+
expected=TestStruct(2., 1.)),
66+
dict(testcase_name='SquaredNorm',
67+
x=TestStruct(1., 2.),
68+
operation=lambda x: x.a**2 + x.b**2,
69+
expected=TestStruct(2., 4.)),
70+
)
71+
def testGrad(self, x, operation, expected):
72+
actual = jax.grad(operation)(x)
73+
np.testing.assert_allclose(expected.a, actual.a)
74+
np.testing.assert_allclose(expected.b, actual.b)
75+
76+
@parameterized.named_parameters(
77+
dict(testcase_name='Sum',
78+
x=TestStruct(1., 2.),
79+
y=TestStruct(3., 4.),
80+
operation=lambda x, y: 3 * x + 2 * y),
81+
dict(testcase_name='Product',
82+
x=TestStruct(1., 2.),
83+
y=TestStruct(3., 4.),
84+
operation=lambda x, y: x * y),
85+
)
86+
def testJit(self, x, y, operation):
87+
jitted = jax.jit(operation)(x, y)
88+
unjitted = operation(x, y)
89+
np.testing.assert_allclose(jitted.a, unjitted.a)
90+
np.testing.assert_allclose(jitted.b, unjitted.b)
91+
92+
93+
if __name__ == '__main__':
94+
absltest.main()

tree_math/_src/vector.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def g(*args3):
5353

5454
def broadcasting_map(func, *args):
5555
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56-
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorMixin)]
56+
static_argnums = [
57+
i for i, x in enumerate(args) if not isinstance(x, VectorMixin)
58+
]
5759
func2, vector_args = _argnums_partial(func, args, static_argnums)
5860
for arg in args:
5961
if not isinstance(arg, VectorMixin):
@@ -113,7 +115,8 @@ def dot(left, right, *, precision="highest"):
113115
Resulting dot product (scalar).
114116
"""
115117
if not isinstance(left, VectorMixin) or not isinstance(right, VectorMixin):
116-
raise TypeError("matmul arguments must both be tree_math.VectorMixin objects")
118+
raise TypeError(
119+
"matmul arguments must both be tree_math.VectorMixin objects")
117120

118121
def _vector_dot(a, b):
119122
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
@@ -122,6 +125,7 @@ def _vector_dot(a, b):
122125
parts = map(_vector_dot, left_values, right_values)
123126
return functools.reduce(operator.add, parts)
124127

128+
125129
class VectorMixin:
126130
"""A mixin class that adds a 1D vector-like behaviour to any custom pytree class."""
127131

@@ -205,6 +209,7 @@ def max(self):
205209
parts = map(jnp.max, tree_util.tree_leaves(self))
206210
return jnp.asarray(list(parts)).max()
207211

212+
208213
@tree_util.register_pytree_node_class
209214
class Vector(VectorMixin):
210215
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
@@ -227,5 +232,3 @@ def tree_flatten(self):
227232
@classmethod
228233
def tree_unflatten(cls, _, args):
229234
return cls(*args)
230-
231-

tree_math/_src/vector_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from absl.testing import parameterized
1919
from jax import tree_util
2020
import jax.numpy as jnp
21+
import numpy as np
2122
import tree_math as tm
2223
from tree_math._src import test_util
23-
import numpy as np
2424

2525
# pylint: disable=g-complex-comprehension
2626

@@ -123,7 +123,8 @@ def test_matmul(self):
123123
self.assertAllClose(actual, expected)
124124

125125
with self.assertRaisesRegex(
126-
TypeError, "matmul arguments must both be tree_math.VectorMixin objects",
126+
TypeError,
127+
"matmul arguments must both be tree_math.VectorMixin objects",
127128
):
128129
vector1 @ jnp.ones((7,)) # pylint: disable=expression-not-assigned
129130

@@ -149,7 +150,7 @@ def test_sum_mean_min_max(self):
149150
self.assertTreeEqual(vector.max(), 4, check_dtypes=False)
150151

151152
def test_custom_class(self):
152-
153+
153154
@tree_util.register_pytree_node_class
154155
class CustomVector(tm.VectorMixin):
155156

@@ -170,7 +171,6 @@ def tree_unflatten(cls, _, args):
170171

171172
v3 = v2 + v1
172173
self.assertTreeEqual(v3, CustomVector(5, 7.0), check_dtypes=True)
173-
174174

175175

176176
if __name__ == "__main__":

0 commit comments

Comments
 (0)