pip install pytreeclassInstall development version
pip install git+https://github.com/ASEM000/pytreeclasspytreeclass is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in numpy, dataclasses, and others.
See documentation and π³ Common recipes to check if this library is a good fit for your work. If you find the package useful consider giving it a π.
import jax
import jax.numpy as jnp
import pytreeclass as tc
@tc.autoinit
class Tree(tc.TreeClass):
a: float = 1.0
b: tuple[float, float] = (2.0, 3.0)
c: jax.Array = jnp.array([4.0, 5.0, 6.0])
def __call__(self, x):
return self.a + self.b[0] + self.c + x
tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
.at["a"].set(100.0)\
.at["b"][0].set(10.0)\
.at[mask].set(100.0)
print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[ 4. 5. 100.])
print(tc.tree_diagram(tree))
# Tree
# βββ .a=100.0
# βββ .b:tuple
# β βββ [0]=10.0
# β βββ [1]=3.0
# βββ .c=f32[3](ΞΌ=36.33, Ο=45.02, β[4.00,100.00])
print(tc.tree_summary(tree))
# βββββββ¬βββββββ¬ββββββ¬βββββββ
# βName βType βCountβSize β
# βββββββΌβββββββΌββββββΌβββββββ€
# β.a βfloat β1 β β
# βββββββΌβββββββΌββββββΌβββββββ€
# β.b[0]βfloat β1 β β
# βββββββΌβββββββΌββββββΌβββββββ€
# β.b[1]βfloat β1 β β
# βββββββΌβββββββΌββββββΌβββββββ€
# β.c βf32[3]β3 β12.00Bβ
# βββββββΌβββββββΌββββββΌβββββββ€
# βΞ£ βTree β6 β12.00Bβ
# βββββββ΄βββββββ΄ββββββ΄βββββββ
# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.
@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
return sum(tree(x))
print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.]) |
Under jax.jit jax requires states to be explicit, this means that for any class instance; variables needs to be separated from the class and be passed explictly. However when using TreeClass no need to separate the instance variables ; instead the whole instance is passed as a state.
Using the following pattern,Updating state functionally can be achieved under jax.jit
import jax
import pytreeclass as tc
class Counter(tc.TreeClass):
def __init__(self, calls: int = 0):
self.calls = calls
def increment(self):
self.calls += 1
counter = Counter() # Counter(calls=0) |
Here, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using .at. To achieve this we can use .at[method_name].__call__(*args,**kwargs), this functional call will return the value of this call and a new model instance with the update state.
@jax.jit
def update(counter):
value, new_counter = counter.at["increment"]()
return new_counter
for i in range(10):
counter = update(counter)
print(counter.calls) # 10 |
