Skip to content

Conversation

@nstarman
Copy link
Contributor

I haven't implemented the full logic, like the use of _FLATTEN_SENTINEL.

@nstarman
Copy link
Contributor Author

nstarman commented Oct 13, 2025

@patrick-kidger This is ~80% of the way there. I think it's a tiny bit slower than the other PR but still gets equinox into the same ballpark as straight JAX.

The same test file, with the methods on the equinox module removed, should benchmark this PR.

This PR is missing the _flatten_sentinel and wrapper params logic. I'm AFK for most of the rest of the week, so if you like this approach and want to polish this off, that'd be awesome :).

(A third alternative would be to mypyc Module and let c eliminate the overhead.)

@patrick-kidger
Copy link
Owner

I really like this approach. If there's a need for greater speed then codegen sounds like a great way to get things to be fast-by-default.

I definitely don't have time to pursue this myself but I'll happily merge this once it's ready.

Copy link
Contributor Author

@nstarman nstarman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For @patrick-kidger, a short report on current status:
I've cleaned up the PR code a lot and moved it to a separate module.
With the current PR https://github.com/user-attachments/files/22870041/overhead.ipynb.zip benchmarks great. It's about as fast as raw arrays and optimized pytrees. So that's good.
The part that I haven't been able to solve is the massive performance hit associated with getting and setting the WRAPPER_FIELD_NAMES. I've commented in all the changes necessary to use the wrapper fields.

@patrick-kidger
Copy link
Owner

Awesome! I really like the look of this PR.

The performance hit on the extra magic attributes – I think the common case is that none of these are present on the instance, and that instead they are class attributes. Perhaps this could be exploited, so that e.g. flattening check their presence and sets a boolean flag, and then unflattening skips this codepath if it is not required?

@nstarman
Copy link
Contributor Author

nstarman commented Oct 21, 2025

The last commit needs a lot of cleanup, but if the '__module__' in obj.__dict__ guard is fine for filtering the wrapper attributes, then this PR gets us much closer to JAX speeds.
I'm clocking ~9.3 µs for the Equinox module versus ~7.7 µs for the JAX pytree in the benchmark notebook. It used to be 12-15 µs. Without the wrapper stuff it's ~8 µs, so there's still room for improvement.

@nstarman
Copy link
Contributor Author

nstarman commented Nov 4, 2025

@patrick-kidger is '__module__' in obj.__dict__ sufficient a check?

@nstarman nstarman marked this pull request as ready for review November 5, 2025 01:01
@nstarman
Copy link
Contributor Author

nstarman commented Nov 5, 2025

Ok. Final approximate timings on that performance notebook:

  • Baseline is @partial(jax.tree_util.register_pytree_with_keys_class) ~ 7.7 µs
  • Module before this PR ~ 12.5 µs
  • Module after this PR ~ 10 µs
  • Module after this PR, if I delete all wrapper field related code ~ 9 µs

So this PR is a ~50% improvement ((12.5 - 10) / (12.5 - 7.7)) but there's still another 2.3 µs to go.

@nstarman nstarman changed the title perf: option 2 perf: dynamic (un)flattening generation Nov 5, 2025
@nstarman nstarman changed the title perf: dynamic (un)flattening generation perf: dynamic (un)flattening code generation Nov 5, 2025
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, many nits! But I basically really like this, it looks pretty much like exactly what we want to me.

@nstarman nstarman force-pushed the module-perf-2 branch 2 times, most recently from db187d5 to 60820e2 Compare November 15, 2025 22:48
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
@nstarman
Copy link
Contributor Author

nstarman commented Nov 17, 2025

@patrick-kidger I think the only possibly unresolved comments are #1119 (comment) and #1119 (comment), but IMO they're good to go.

@nstarman
Copy link
Contributor Author

nstarman commented Nov 18, 2025

I tested this on unxt and saw a nice µs speed boost in flattening/unflattening.
I plan to look at quax next for more performance enhancements.

@patrick-kidger
Copy link
Owner

Nice!
Given that this change occurs on the bottom of the tech stack, it would be good to also run the diffrax/optimistix/lineax tests against this too?

Just in case some of the weirder cases (overriding fields with properties etc) catch us out.

@nstarman
Copy link
Contributor Author

uv add --upgrade  git+https://github.com/nstarman/equinox/@module-perf-2 
uv run pytest test -m "not slow" > diffrax_eqx1119_test.log 2>&1

diffrax_eqx1119_test.log

There is an error in diffrax, but when I re-run the tests outside of the custom equinox branch I get the same error, so I don't think it's related.

uv add --upgrade  git+https://github.com/nstarman/equinox/@module-perf-2 
uv run --extra tests pytest tests > ~optimistix_eqx1119_test.log 2>&1

optimistix_eqx1119_test.log

uv add --optional tests pytest beartype 
uv run --extra tests pytest tests > lineax_eqx1119_test.log 2>&1

lineax_eqx1119_test.log

@johannahaffner
Copy link
Contributor

The diffrax errata were fixed in patrick-kidger/diffrax#696, and should not occur on dev, so I can confirm that these are unrelated :)

@patrick-kidger patrick-kidger merged commit 0277a30 into patrick-kidger:main Nov 21, 2025
1 check failed
@patrick-kidger
Copy link
Owner

Awesome! Merged 🎉

@nstarman nstarman deleted the module-perf-2 branch November 21, 2025 14:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants