
Align gemma3n cache sharing to gemma4 #45489

Merged
Cyrilvallez merged 6 commits into main from gemma3n on Apr 22, 2026
Conversation

@Cyrilvallez (Member)

What does this PR do?

As per the title: this brings the changes from #45312 and #45336 over to gemma3n.
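
For context, a rough sketch of the kind of per-layer KV sharing this touches, inferred from the `shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]]` parameter discussed in the review below; the names and control flow here are illustrative assumptions, not the actual implementation:

```python
import torch

def run_layers(hidden_states: torch.Tensor, num_layers: int, first_shared_layer: int):
    # Early layers compute and publish their KV states; later layers reuse a
    # previously computed pair instead of recomputing their own.
    shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
    for layer_idx in range(num_layers):
        if layer_idx < first_shared_layer:
            # Real code would apply k/v projections here; identity keeps it short.
            key_states, value_states = hidden_states, hidden_states
            shared_kv_states[layer_idx] = (key_states, value_states)
        else:
            # Reuse the KV states of a designated earlier layer.
            key_states, value_states = shared_kv_states[first_shared_layer - 1]
        # ... attention with (key_states, value_states) would go here ...
    return shared_kv_states
```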

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu (Contributor) left a comment


Just a few nits, can we inherit from Gemma3n directly then? Looks 1:1 to me



```python
class Gemma3nTextAttention(Gemma3Attention):
    @use_kernelized_func(apply_rotary_pos_emb)
```
Contributor

Looks like it doesn't follow the normal function signature. Have you checked that it actually works when you pass `use_kernels=True`?

Member Author

What do you mean? It's how we do it everywhere, and the modeling did not change.

Contributor

The `apply_rotary_pos_emb` that gets exchanged is not equivalent to the kernels version:

```python
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)
```

This applies RoPE on one tensor only (with the expected chunk split).

However, the kernels version uses

1. `repo_id="kernels-community/rotary"`, `func_name="apply_rotary_transformers"`
2. https://huggingface.co/kernels-community/rotary/blob/main/build/torch210-cxx11-cu128-aarch64-linux/__init__.py#L19-L49

which expects both the q and k tensors at once (and in chunked implementation), so it is not the same as the gemma4/3n version here. This would need to be fixed in the upstream kernels to allow

1. single tensors
2. an interleaved layout

I am not sure which models it affects, but when I saw this, I was pretty sure it was wrong.
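
To make the mismatch concrete, here is a minimal sketch of the eager single-tensor path quoted above (the toy shapes and inputs are illustrative assumptions, not the transformers code):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Chunked split: negate-and-swap the two halves of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    # Single-tensor eager path: called once for q, then once more for k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)

# Toy shapes: [batch, heads, seq, head_dim] for q/k, [batch, seq, head_dim] for cos/sin.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
angles = torch.outer(torch.arange(16).float(), inv_freq)     # [seq, head_dim // 2]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None]  # [1, seq, head_dim]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None]

q_rot = apply_rotary_pos_emb(q, cos, sin)  # one tensor per call here...
k_rot = apply_rotary_pos_emb(k, cos, sin)  # ...whereas the kernel takes q and k together
```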

Member Author

Ohhhh I see - this is not from this PR, but it would indeed need to be fixed ASAP! Opening another PR to fix it!

Comment thread: src/transformers/models/gemma3n/modular_gemma3n.py

```python
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values: Cache | None = None,
**kwargs: Unpack[TransformersKwargs],
```
Contributor

Suggested change:

```diff
-    shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
     past_key_values: Cache | None = None,
+    shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
     **kwargs: Unpack[TransformersKwargs],
```

Just for BC, we cannot change the order.

Member Author

It's just to make crystal clear that it's a required input. Since it's an internal module, this is fine IMO (and maybe even better to break, as forgetting it would result in a silently wrong model).

Contributor

Then let's add at least a 🚨 to the title.

Imo, I don't see why we can't default to None: if it weren't passed when it was needed, we would encounter a runtime error anyway, no?

Member Author

Yes, indeed, we could default it to None independently and let it crash if not passed... Let me change it then, if you think that's best!
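
For context, a toy sketch of the trade-off being discussed (hypothetical simplified signatures, not the actual modeling code):

```python
import torch

KVPair = tuple[torch.Tensor, torch.Tensor]

def forward_required(hidden_states, attention_mask,
                     shared_kv_states: dict[int, KVPair],
                     past_key_values=None):
    # Required argument: forgetting it raises a TypeError at call time, but it
    # now sits before past_key_values, changing the positional order (BC break).
    ...

def forward_defaulted(hidden_states, attention_mask, past_key_values=None,
                      shared_kv_states: dict[int, KVPair] | None = None):
    # None default keeps the old positional order; a caller that forgets it
    # still crashes at first use ('NoneType' is not subscriptable) rather than
    # silently producing a wrong model.
    key_states, value_states = shared_kv_states[0]
    ...
```

Either way the failure is loud; the difference is whether it happens at call time or at first use, and whether the positional order seen by existing callers survives.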

Comment thread: src/transformers/models/gemma3n/modular_gemma3n.py
Comment thread: src/transformers/models/gemma4/modular_gemma4.py
@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma3n, gemma4

Cyrilvallez merged commit 08244b9 into main on Apr 22, 2026. 29 checks passed.
Cyrilvallez deleted the gemma3n branch on April 22, 2026 at 03:35.