Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.10
rev: v0.14.4
hooks:
- id: ruff # linter
- id: ruff-check # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.404
rev: v1.1.407
hooks:
- id: pyright
additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping", "plum-dispatch"]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ docs = [
]
tests = [
"beartype>=0.20.2",
"diffrax>=0.7.0",
"pytest>=8.3.5",
"pytest-env>=1.1.5",
"jax[cpu]",
Expand All @@ -61,7 +62,7 @@ include = ["quax/*"]
addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"

[tool.pytest_env]
JAX_CHECK_TRACER_LEAKS = 1
JAX_CHECK_TRACER_LEAKS = 0 # TODO: set to 1 once diffrax supports it

[tool.ruff.lint]
select = ["E", "F", "I001", "UP"]
Expand Down
239 changes: 234 additions & 5 deletions quax/_core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import abc
import functools as ft
import importlib.metadata
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Generic, overload, TypeGuard, TypeVar, Union
from typing import (
Any,
cast,
Final,
Generic,
no_type_check,
overload,
TypeGuard,
TypeVar,
Union,
)

import equinox as eqx
import jax
import jax._src.ad_util as ad_util
import jax._src.core as core
import jax.extend.core as jexc
import jax.extend.linear_util as lu
Expand All @@ -21,6 +33,11 @@
T = TypeVar("T")
CT = TypeVar("CT", bound=Callable)

ZERO_TYPES = (SZ, ad_util.Zero)

JAX_VERSION = tuple(map(int, importlib.metadata.version("jax").split(".")))
JAX_VERSION_LT_7: Final = JAX_VERSION < (0, 7, 0)

#
# Rules
#
Expand Down Expand Up @@ -148,6 +165,9 @@ def __init__(self, parent_trace, tag):
def to_value(self, val):
if isinstance(val, _QuaxTracer) and val._trace.tag is self.tag: # type: ignore[attr-defined]
return val.value
# Handle Zero objects from custom VJP (they're not arrays)
if isinstance(val, ZERO_TYPES):
return val
return _DenseArrayValue(val)

# ===========================================
Expand Down Expand Up @@ -183,12 +203,12 @@ def process_primitive(self, primitive, tracers, params):
with core.set_current_trace(self.parent_trace):
rule = _rules.get(primitive)
if rule is None:
out = _default_process(primitive, values, params)
out = _default_process(primitive, values, params) # pyright: ignore
else:
try:
method, _ = rule.resolve_method(values)
except plum.NotFoundLookupError:
out = _default_process(primitive, values, params)
out = _default_process(primitive, values, params) # pyright: ignore
else:
out = method(*values, **params)

Expand All @@ -213,6 +233,103 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero
out_values = jtu.tree_unflatten(out_treedef, out_leaves)
return [_QuaxTracer(self, x) for x in out_values]

if JAX_VERSION_LT_7:

@no_type_check
def process_custom_vjp_call(
self,
primitive: core.Primitive,
fun: lu.WrappedFun,
fwd: lu.WrappedFun,
bwd: lu.WrappedFun,
tracers: Sequence[core.Tracer],
**params,
) -> list[_QuaxTracer]:
"""Process custom VJP calls (JAX < 0.7 calling convention).

In JAX < 0.7, out_trees and symbolic_zeros are passed as **params.
"""
# Extract params (JAX 0.6 passes these as kwargs)
out_trees = params.get("out_trees")
symbolic_zeros = params.get("symbolic_zeros", False)

in_values = [self.to_value(t) for t in tracers]
in_leaves, in_treedef = jtu.tree_flatten(in_values)
fun, out_treedef1 = _custom_vjp_fun_wrap(fun, self.tag, in_treedef)
fwd, out_treedef2 = _custom_vjp_fwd_wrap(fwd, self.tag, in_treedef)
bwd = _custom_vjp_bwd_wrap(bwd, self.tag)
out_leaves = primitive.bind_with_trace(
self.parent_trace,
(fun, fwd, bwd, *in_leaves),
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros),
)
# In JAX < 0.7, the Store-based thunks from transformation_with_aux
# are only populated when the wrapped functions are actually called
# (not during trace construction). When out_treedef1() succeeds, it
# gives us the correct Value structure. When it fails (StoreException),
# we're in a trace construction phase where we simply wrap the leaves.
try:
out_treedef = out_treedef1()
out_values = jtu.tree_unflatten(out_treedef, out_leaves)
except lu.StoreException:
# Store not populated - we're in trace construction, not execution.
# Wrap each leaf as a Value.
if isinstance(out_leaves, list):
out_values = [_wrap_if_array(x) for x in out_leaves]
else:
out_values = [_wrap_if_array(out_leaves)]

# Ensure out_values is always a list for consistency with return type
if not isinstance(out_values, list | tuple):
out_values = [out_values]

return [_QuaxTracer(self, x) for x in out_values]

else:

@no_type_check
def process_custom_vjp_call(
self,
primitive: core.Primitive,
fun: lu.WrappedFun,
fwd: lu.WrappedFun,
bwd: lu.WrappedFun,
tracers: Sequence[core.Tracer],
out_trees: Callable[[], tuple[PyTree, PyTree]],
symbolic_zeros: bool,
) -> list[_QuaxTracer]:
"""Process custom VJP calls (JAX 0.7+ calling convention).

**Arguments:**

- `primitive`: The custom_vjp_call primitive being processed.
- `fun`: The primal function to evaluate (wrapped by JAX's linear_util).
- `fwd`: The forward pass function (computes outputs and residuals).
- `bwd`: The backward pass function (computes input cotangents).
- `tracers`: Input tracers containing Quax Values.
- `out_trees`: Thunk returning (primal_tree, residual_tree) info.
- `symbolic_zeros`: Whether to use symbolic zeros for efficiency.

**Returns:**

A list of `_QuaxTracer` instances containing the results.
"""
in_values = [self.to_value(t) for t in tracers]
# Each `t.value` will be some `Value`, and thus a PyTree. Here we
# flatten the `Value`-ness away.
in_leaves, in_treedef = jtu.tree_flatten(in_values)
fun, out_treedef1 = _custom_vjp_fun_wrap(fun, self.tag, in_treedef)
fwd, out_treedef2 = _custom_vjp_fwd_wrap(fwd, self.tag, in_treedef)
bwd = _custom_vjp_bwd_wrap(bwd, self.tag)
out_leaves = primitive.bind_with_trace(
self.parent_trace,
(fun, fwd, bwd, *in_leaves),
{"out_trees": out_trees, "symbolic_zeros": symbolic_zeros},
)
_, out_treedef = lu.merge_linear_aux(out_treedef1, out_treedef2)
out_values = jtu.tree_unflatten(out_treedef, out_leaves)
return [_QuaxTracer(self, x) for x in out_values]

# TODO: add other process_* rules


Expand Down Expand Up @@ -264,8 +381,8 @@ def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents):
assert len(out_primal_values) == len(out_tangent_values)
for primal, tangent in zip(out_primal_values, out_tangent_values):
if primal.__class__ != tangent.__class__:
primal = primal.materialise()
tangent = tangent.materialise()
primal = primal.materialise() # pyright: ignore
tangent = tangent.materialise() # pyright: ignore
out_primal_values2.append(primal)
out_tangent_values2.append(tangent)
del out_tracers
Expand All @@ -279,6 +396,118 @@ def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents):
yield out_primals + out_tangents, out_primal_treedef


@lu.transformation_with_aux # pyright: ignore
def _custom_vjp_fun_wrap(tag, in_treedef, *in_leaves):
"""Wrapper for the primal function in custom_vjp.

**Arguments:**

- `tag`: Trace tag for identifying the Quax trace.
- `in_treedef`: Tree definition for reconstructing input Values from leaves.
- `*in_leaves`: Flattened array leaves from input Values.

**Yields:**

- `out_leaves`: Flattened array leaves from output Values.
- `out_treedef`: Tree definition for reconstructing output Values.
"""
in_values = jtu.tree_unflatten(in_treedef, in_leaves)
with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
in_tracers = [x if type(x) is SZ else _QuaxTracer(trace, x) for x in in_values]
with core.set_current_trace(trace):
out_tracers = yield in_tracers, {}
out_values = [
trace.to_value(
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
)
for t in out_tracers
]
out_leaves, out_treedef = jtu.tree_flatten(out_values)
yield out_leaves, out_treedef


@lu.transformation_with_aux # pyright: ignore
def _custom_vjp_fwd_wrap(tag, in_treedef, *in_primals_and_nz):
"""Wrapper for the forward pass in custom_vjp.

**Arguments:**

- `tag`: Trace tag for identifying the Quax trace.
- `in_treedef`: Tree definition for reconstructing input Values from primal leaves.
- `*in_primals_and_nz`: Interleaved (primal_leaf, nonzero_flag) pairs.

**Yields:**

- `out_leaves`: Flattened leaves from output Values (primals + residuals).
- `out_treedef`: Tree definition for reconstructing output Values.
"""
# Split interleaved primals and nonzero flags
in_primals = in_primals_and_nz[::2]
in_nz = in_primals_and_nz[1::2]
in_primal_values = jtu.tree_unflatten(in_treedef, in_primals)

with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
in_tracers = [_QuaxTracer(trace, x) for x in in_primal_values]
with core.set_current_trace(trace):
out_tracers = yield list(it.chain(*zip(in_tracers, in_nz))), {}
out_values = [
trace.to_value(
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
)
for t in out_tracers
]

out_leaves, out_treedef = jtu.tree_flatten(out_values)
yield out_leaves, out_treedef


@lu.transformation # pyright: ignore
def _custom_vjp_bwd_wrap(tag, *args):
"""Wrapper for the backward pass in custom_vjp.

**Arguments:**

- `tag`: Trace tag for identifying the Quax trace.
- `*args`: Residuals from forward pass and output cotangents (Values or arrays).

**Yields:**

- `out_leaves`: Flattened array leaves representing input cotangents.
"""
with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
# Wrap Values as tracers, pass through arrays and zeros
in_tracers = [
_QuaxTracer(trace, x)
if isinstance(x, Value) and not isinstance(x, SZ)
else x
for x in args
]
with core.set_current_trace(trace):
out_tracers = yield in_tracers, {}
out_tracers = [
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
for t in out_tracers
]
out_values = [
trace.to_value(t) if not isinstance(t, ZERO_TYPES) else t
for t in out_tracers
]

# Flatten output values for return
out_leaves = []
for val in out_values:
if isinstance(val, Value):
leaves, _ = jtu.tree_flatten(val)
out_leaves.extend(leaves)
else:
out_leaves.append(val)

yield out_leaves


#
# API
#
Expand Down
Loading