Hey @lrvdijk! This looks great, let's try to get it working. The first thing to note is that it's currently not a good idea to transform instance methods, e.g.

```python
batched_model = nnx.vmap(model)  # transforms the bound __call__ method; not recommended
```

Here you are implicitly passing `self` (from `def __call__(self, ...)`) as a captured value. This triggers the trace-level error when you try to mutate Modules or Variables, because NNX cannot keep track of those changes. The recommended approach is to create a function that takes the model as an explicit input and transform that function instead:

```python
from flax import nnx

@nnx.vmap(in_axes=(None, 0))  # broadcast `model`, map over axis 0 of `x`
def forward(model, x):
  return model(x)
```

Here we assume you want to broadcast `model` (`in_axes=None`) while mapping over the leading axis of `x`. The same applies to `batch_rope`.

Try to fix this part and we can solve the rest.

Answer selected by lrvdijk
Category: Q&A