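In fact, instead of putting …

```python
import jax
from flax import nnx

def vmap_module(module, *vmap_args, **vmap_kwargs):
    # Split eagerly, outside any trace: graph is static, state is a pytree of arrays.
    graph, state = nnx.split(module)
    def inner_apply(state, *args, **kwargs):
        # Rebuild the module from the (batched) state inside the jax.vmap trace.
        module = nnx.merge(graph, state)
        return module(*args, **kwargs)
    # vmap_args / vmap_kwargs (e.g. in_axes) apply to (state, *args).
    vmap_apply = jax.vmap(inner_apply, *vmap_args, **vmap_kwargs)
    def apply(*args, **kwargs):
        return vmap_apply(state, *args, **kwargs)
    return apply

y = jax.jit(vmap_module(weights))(x)  # or nnx.jit, both would work
```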
-
Hello, I recently updated Flax and got the following error:
ValueError: Cannot extract graph node from different trace level, got ...
I have isolated the problem to passing, by closure, an NNX module that uses `nnx.vmap` (either on its own methods or in a submodule). For example, take the example from the guide here:
The code above works fine. However, if one tries to jit `weights`, or to use `weights` in any other function and jit that function, the error above is raised.
or
would both throw an error like
This is quite inconvenient, because every transform (which could sit in the middle of a large function) would have to receive the module as an explicit argument, even though the module is not mutated, which means a significant refactoring. This error only appears since the 0.10.5 update, so for now I'm rolling back to 0.10.4. It may also be related to a previous discussion, #4804.
I'm wondering whether this means that `nnx.vmap` internally mutates the state of the module. If so, could we have a separate "static" vmap that does not have this issue? For those who only need a batch dimension, such a static version should be sufficient, and it could be passed by closure without problems.
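To make the "static" vmap idea concrete, here is a minimal plain-JAX sketch, under the assumption that the batched parameters can live in an ordinary pytree (a dict of stacked arrays) instead of an NNX module. Since the parameters are just arrays, `jax.vmap` never sees mutable state, and the batched function can be closed over inside `jax.jit` freely. `static_vmap`, `affine`, and the parameter names are hypothetical, for illustration only; this is not an existing Flax API.

```python
import jax
import jax.numpy as jnp


def static_vmap(fn, params, in_axes=0, out_axes=0):
    # Hypothetical helper: vmap a pure fn(params, *args) over a stacked
    # params pytree and the args, then bind params so callers never pass them.
    batched = jax.vmap(fn, in_axes=in_axes, out_axes=out_axes)
    return lambda *args: batched(params, *args)


def affine(params, x):
    # Unbatched computation: x is (2,), params["kernel"] is (2, 3).
    return x @ params["kernel"] + params["bias"]


params = {
    "kernel": jax.random.uniform(jax.random.key(0), (10, 2, 3)),
    "bias": jnp.zeros((10, 3)),
}
x = jax.random.normal(jax.random.key(1), (10, 2))

batched_affine = static_vmap(affine, params)


@jax.jit
def loss(x):
    # batched_affine (and its params) are captured by closure.
    return jnp.sum(batched_affine(x) ** 2)


value = loss(x)
```

The trade-off is that nothing here can be mutated through the closure: updating the parameters means building a new `params` pytree (and a new `static_vmap` closure), which is exactly the "batch dimension only" use case described above.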