Hi all, I am trying to get more experience with implementing (protein) language models in Flax/JAX, and have been working on a re-implementation of Facebook's ESM-2 protein language model. ESM-2 is a fairly standard BERT-style encoder that uses RoPE for its positional embeddings. However, I have so far been unsuccessful in implementing the stack of transformer layers using `nnx.vmap` and `nnx.scan`: when I try to run the model, I get a trace-level error.
I am unsure what I am doing wrong. I'll walk through my (naive?) approach to implementing ESM-2; could anyone spot where I am making a mistake?

### ESM-2 main class

The ESM-2 model comprises an embedding layer, a stack of transformer layers, an additional layer norm, and a final logit output layer. I am using the following implementation:

```python
class ESM2(nnx.Module):
    def __init__(
        self,
        alphabet: Alphabet,
        num_layers: int = 33,
        d_embed: int = 1024,
        num_heads: int = 16,
        *,
        rngs: nnx.Rngs
    ):
        # Token -> initial embedding
        self.embed: nnx.Embed = nnx.Embed(len(alphabet.tokens), d_embed, rngs=rngs)

        # BERT-style transformer layers with rotary positional embeddings.
        # Split the RNG generator into one stream per layer, and use vmap to create each layer.
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(axis_size=num_layers)
        def create_layers(rngs: nnx.Rngs):
            return TransformerEncoder(d_embed, d_embed * 4, num_heads, rngs=rngs)

        self.transformer_layers = create_layers(rngs)

        self.layer_norm_after: nnx.LayerNorm = nnx.LayerNorm(d_embed, rngs=rngs)
        self.logit_head: LogitHead = LogitHead(d_embed, len(alphabet.tokens), rngs=rngs)

    def __call__(self, tokens: jax.Array):
        # Token -> initial embedding
        x = self.embed(tokens)

        # Use nnx.scan to sequentially apply each transformer layer
        @nnx.scan
        def apply_transformer_layer(x: jax.Array, layer: TransformerEncoder):
            return layer(x), None

        x, _ = apply_transformer_layer(x, self.transformer_layers)

        x = self.layer_norm_after(x)
        logits = self.logit_head(x)
        return logits
```
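For reference, here is the create-with-`nnx.vmap` / apply-with-`nnx.scan` pattern used above, isolated into a self-contained sketch with plain `nnx.Linear` blocks. The names and sizes are illustrative only, not from the repository, and the sketch assumes the default `nnx.scan` axes (carry first, stacked module second), mirroring the code above:

```python
# Self-contained sketch of the stack-of-layers pattern, with nnx.Linear blocks
# standing in for TransformerEncoder; names and sizes are illustrative only.
import jax
import jax.numpy as jnp
from flax import nnx

num_layers, d = 4, 16

# Split the RNG stream into one key per layer and vmap the constructor, so
# every parameter in the resulting module gets a leading "layer" axis.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(axis_size=num_layers)
def create_block(rngs: nnx.Rngs):
    return nnx.Linear(d, d, rngs=rngs)

blocks = create_block(nnx.Rngs(0))

# Apply the stacked blocks sequentially: x is the scan carry, and each step
# receives one slice of the stacked module along the leading layer axis.
@nnx.scan
def apply_block(x: jax.Array, block: nnx.Linear):
    return block(x), None

x = jnp.ones((8, d))
y, _ = apply_block(x, blocks)
```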
### Transformer layer with RoPE

Within each `TransformerEncoder` layer I am trying to reuse `nnx.MultiHeadAttention`, with a custom attention function that applies RoPE to the query and key:

```python
# Inside TransformerEncoder.__init__:
self.rope: RoPE = RoPE(d_embed // num_heads, rngs=rngs)

# This is my custom dot_product_attention, which accepts a RoPE module as its first argument
attention_fn = partial(dot_product_attention, self.rope)

self.attention: nnx.MultiHeadAttention = nnx.MultiHeadAttention(
    num_heads=num_heads,
    in_features=d_embed,
    attention_fn=attention_fn,
    decode=False,
    rngs=rngs
)
```
(Full source: https://github.com/lrvdijk/flamino/blob/main/src/flamino/transformer.py)

### Custom `dot_product_attention`

```python
def dot_product_attention(
    rope_module: RoPE,
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    ...  # Other parameters
):
    """Drop-in replacement for Flax's dot_product_attention, but with RoPE applied to the query and key."""
    batch_rope = nnx.vmap(rope_module)
    query = batch_rope(query)
    key = batch_rope(key)

    return nnx_dot_product_attention(
        query,
        key,
        value,
        ...  # Other parameters
    )
```
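For context, `nnx_dot_product_attention` in the snippet above is assumed to be Flax's standard attention function imported under an alias, along the lines of:

```python
# Assumed import alias (not shown in the snippet above): the standard Flax
# attention that the custom wrapper delegates to.
from flax.nnx import dot_product_attention as nnx_dot_product_attention
```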
### RoPE implementation

Inspired by Equinox's RoPE implementation, I cache the computed sin/cos arrays in a global dict:

```python
class RoPE(nnx.Module):
    ...

    def __call__(self, x: jax.Array):
        assert x.ndim == 2
        seq_len, embed_size = x.shape
        assert embed_size == self.d_embed, "Sequence embedding dimension mismatch"

        with jax.ensure_compile_time_eval():
            cache_key = (embed_size, x.dtype)

            # Check the global cache for the given embedding size and dtype
            if cache_key not in rope_sin_cos_table_cache:
                sin_table, cos_table = self._compute_sin_cos_table(seq_len, x.dtype)
                rope_sin_cos_table_cache[cache_key] = (sin_table, cos_table)
            else:
                sin_table, cos_table = rope_sin_cos_table_cache[cache_key]

                # Re-compute the sin/cos tables if the current sequence is
                # longer than the cached tables
                freq_seq_len = sin_table.shape[0]
                if freq_seq_len < seq_len:
                    sin_table, cos_table = self._compute_sin_cos_table(seq_len, x.dtype)
                    rope_sin_cos_table_cache[cache_key] = (sin_table, cos_table)

        return apply_rope(x, sin_table, cos_table)
```

(Full source: https://github.com/lrvdijk/flamino/blob/main/src/flamino/rope.py)
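For readers unfamiliar with `apply_rope`: below is a generic sketch of the half-rotation formulation that ESM-2-style rotary embeddings use. It is not the repository's actual `apply_rope`; the function signature and the table layout (sin/cos values duplicated across both halves of the embedding dimension) are assumptions.

```python
import jax.numpy as jnp

def rotate_half(x):
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)

def apply_rope_sketch(x, sin_table, cos_table):
    # x: (seq_len, d_embed); sin/cos tables: (>= seq_len, d_embed).
    # Rotates each position's embedding by a position-dependent angle.
    seq_len = x.shape[0]
    sin, cos = sin_table[:seq_len], cos_table[:seq_len]
    return (x * cos) + (rotate_half(x) * sin)
```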
### Wrap-up

Given that the exception is raised inside one of the transformed functions: how can I fix this error?
---
Hey @lrvdijk! This looks great, let's try to get it working. The first thing to note is that currently it's not a good idea to transform instance methods, e.g. `batched_model = nnx.vmap(model)`, as here you are passing `self` in `def __call__(self, ...)` as a capture, and this triggers the trace-level error when trying to mutate Modules or Variables, since NNX cannot keep track of these changes. The recommended approach is to create a function that takes the `model` as an explicit input and transform that:

```python
@nnx.vmap(in_axes=(None, 0))
def forward(model, x):
    return model(x)
```

Here we assume you want to broadcast `model`. The same thing would apply to `batch_rope`.

Try to fix this part and we can solve the rest.
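Below is a minimal sketch of what this recommendation could look like when applied to the `batch_rope` call inside the custom `dot_product_attention` above. It only illustrates the explicit-input pattern and is not a confirmed fix; the single leading batch axis is an assumption carried over from the original `batch_rope` usage.

```python
# Sketch only: the RoPE module becomes an explicit, broadcast input instead of
# a captured instance, mirroring the `forward` example above. The assumption
# that query/key carry a single leading batch axis comes from the original
# batch_rope usage and may need adjusting for the real attention shapes.
@nnx.vmap(in_axes=(None, 0))
def batch_rope(rope_module, x):
    return rope_module(x)

# Inside dot_product_attention, replacing the original three batch_rope lines:
query = batch_rope(rope_module, query)
key = batch_rope(rope_module, key)
```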