Skip to content

Internal change #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tree_math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
wrap,
unwrap,
)
from tree_math._src.structs import struct
from tree_math._src.vector import Vector, VectorMixin
import tree_math.numpy

Expand Down
66 changes: 66 additions & 0 deletions tree_math/_src/structs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Helpers for constructing data classes that are JAX and tree-math enabled."""

import dataclasses
import jax
from tree_math._src.vector import VectorMixin


def struct(cls):
"""Class decorator that enables JAX function transforms as well as tree math.

Decorating a class with `@struct` makes it a dataclass that is compatible
with arithmetic infix operators like `+`, `-`, `*` and `/`. The decorated
class is also a valid pytree, making it compatible with JAX function
transformations such as `jit` and `grad`.

Example usage:

```
@struct
class Point:
x: float
y: float

a = Point(0., 1.)
b = Point(1., 1.)

a + 3 * b # Point(3., 4.)

def norm_squared(pt):
return pt.x**2 + pt.y**2

jax.jit(jax.grad(norm))(b) # Point(2., 2.)
```

Args:
cls: a class, written with the same syntax as a `dataclass`.

Returns:
A wrapped version of `cls` that implements dataclass, pytree and tree_math
functionality.
"""
@property
def fields(self):
return dataclasses.fields(self)

def asdict(self):
return {field.name: getattr(self, field.name) for field in self.fields}

def astuple(self):
return tuple(getattr(self, field.name) for field in self.fields)

def tree_flatten(self):
return self.astuple(), None

@classmethod
def tree_unflatten(cls, _, children):
return cls(*children)

cls_as_struct = type(cls.__name__,
(VectorMixin, dataclasses.dataclass(cls)),
{'fields': fields,
'asdict': asdict,
'astuple': astuple,
'tree_flatten': tree_flatten,
'tree_unflatten': tree_unflatten})
return jax.tree_util.register_pytree_node_class(cls_as_struct)
94 changes: 94 additions & 0 deletions tree_math/_src/structs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for global_circulation.structs."""

from typing import Union

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp
import numpy as np
import tree_math

ArrayLike = Union[jnp.ndarray, np.ndarray, float]


@tree_math.struct
class TestStruct:
a: ArrayLike
b: ArrayLike


class StructsTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(testcase_name='Scalars', x=TestStruct(1., 2.)),
dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5])))
)
def testFlattenUnflatten(self, x):
leaves, structure = jax.tree_flatten(x)
y = jax.tree_unflatten(structure, leaves)
np.testing.assert_allclose(x.a, y.a)
np.testing.assert_allclose(x.b, y.b)

@parameterized.named_parameters(
dict(testcase_name='Addition',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: x + y,
expected=TestStruct(4., 6.)),
dict(testcase_name='Subtraction',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: x - y,
expected=TestStruct(-2., -2.)),
dict(testcase_name='Multiplication',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: x * y,
expected=TestStruct(3., 8.)),
dict(testcase_name='Division',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: x / y,
expected=TestStruct(1 / 3, 2 / 4)),
)
def testInfixOperator(self, x, y, operation, expected):
actual = operation(x, y)
np.testing.assert_allclose(expected.a, actual.a)
np.testing.assert_allclose(expected.b, actual.b)

@parameterized.named_parameters(
dict(testcase_name='Product',
x=TestStruct(1., 2.),
operation=lambda x: x.a * x.b,
expected=TestStruct(2., 1.)),
dict(testcase_name='SquaredNorm',
x=TestStruct(1., 2.),
operation=lambda x: x.a**2 + x.b**2,
expected=TestStruct(2., 4.)),
)
def testGrad(self, x, operation, expected):
actual = jax.grad(operation)(x)
np.testing.assert_allclose(expected.a, actual.a)
np.testing.assert_allclose(expected.b, actual.b)

@parameterized.named_parameters(
dict(testcase_name='Sum',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: 3 * x + 2 * y),
dict(testcase_name='Product',
x=TestStruct(1., 2.),
y=TestStruct(3., 4.),
operation=lambda x, y: x * y),
)
def testJit(self, x, y, operation):
jitted = jax.jit(operation)(x, y)
unjitted = operation(x, y)
np.testing.assert_allclose(jitted.a, unjitted.a)
np.testing.assert_allclose(jitted.b, unjitted.b)


if __name__ == '__main__':
absltest.main()
11 changes: 7 additions & 4 deletions tree_math/_src/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def g(*args3):

def broadcasting_map(func, *args):
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorMixin)]
static_argnums = [
i for i, x in enumerate(args) if not isinstance(x, VectorMixin)
]
func2, vector_args = _argnums_partial(func, args, static_argnums)
for arg in args:
if not isinstance(arg, VectorMixin):
Expand Down Expand Up @@ -113,7 +115,8 @@ def dot(left, right, *, precision="highest"):
Resulting dot product (scalar).
"""
if not isinstance(left, VectorMixin) or not isinstance(right, VectorMixin):
raise TypeError("matmul arguments must both be tree_math.VectorMixin objects")
raise TypeError(
"matmul arguments must both be tree_math.VectorMixin objects")

def _vector_dot(a, b):
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
Expand All @@ -122,6 +125,7 @@ def _vector_dot(a, b):
parts = map(_vector_dot, left_values, right_values)
return functools.reduce(operator.add, parts)


class VectorMixin:
"""A mixin class that adds a 1D vector-like behaviour to any custom pytree class."""

Expand Down Expand Up @@ -205,6 +209,7 @@ def max(self):
parts = map(jnp.max, tree_util.tree_leaves(self))
return jnp.asarray(list(parts)).max()


@tree_util.register_pytree_node_class
class Vector(VectorMixin):
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
Expand All @@ -227,5 +232,3 @@ def tree_flatten(self):
@classmethod
def tree_unflatten(cls, _, args):
return cls(*args)


8 changes: 4 additions & 4 deletions tree_math/_src/vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from absl.testing import parameterized
from jax import tree_util
import jax.numpy as jnp
import numpy as np
import tree_math as tm
from tree_math._src import test_util
import numpy as np

# pylint: disable=g-complex-comprehension

Expand Down Expand Up @@ -123,7 +123,8 @@ def test_matmul(self):
self.assertAllClose(actual, expected)

with self.assertRaisesRegex(
TypeError, "matmul arguments must both be tree_math.VectorMixin objects",
TypeError,
"matmul arguments must both be tree_math.VectorMixin objects",
):
vector1 @ jnp.ones((7,)) # pylint: disable=expression-not-assigned

Expand All @@ -149,7 +150,7 @@ def test_sum_mean_min_max(self):
self.assertTreeEqual(vector.max(), 4, check_dtypes=False)

def test_custom_class(self):

@tree_util.register_pytree_node_class
class CustomVector(tm.VectorMixin):

Expand All @@ -170,7 +171,6 @@ def tree_unflatten(cls, _, args):

v3 = v2 + v1
self.assertTreeEqual(v3, CustomVector(5, 7.0), check_dtypes=True)



if __name__ == "__main__":
Expand Down