Align gemma3n cache sharing to gemma4 #45489
Conversation
vasqu left a comment
Just a few nits. Can we inherit from Gemma3n directly then? It looks 1:1 to me.
class Gemma3nTextAttention(Gemma3Attention):
    @use_kernelized_func(apply_rotary_pos_emb)
It looks like this doesn't follow the normal function signature. Have you checked that it actually works when you pass use_kernels=True?
What do you mean? It's how we do it everywhere, and the modeling did not change.
The exchanged apply_rotary_pos_emb is not equivalent to the kernels version:
transformers/src/transformers/models/gemma4/modeling_gemma4.py, lines 745 to 764 in cd5bcad
This applies RoPE on one tensor only (with the expected chunk split). However, the kernels version expects both q and k tensors at once (and in a chunked implementation), which means it is not the same as the gemma4/3n one here. This would need to be fixed in the upstream kernels to allow:
- Single tensors
- Interleaved layout
I am not sure which models it affects, but when I saw this, I was pretty sure it was wrong.
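For illustration, here is a minimal sketch of the mismatch, assuming the standard half-split rotate_half layout; the function names and bodies are illustrative, not the actual kernel API. The gemma4/3n-style helper rotates a single tensor per call, while a paired variant expects q and k together:

```python
import torch

def rotate_half(x):
    # Half-split layout: rotate the second half of the head dim against the first.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_single(x, cos, sin, unsqueeze_dim=1):
    # gemma4/3n-style helper: applies RoPE to ONE tensor (q or k) per call.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)

def apply_rotary_pos_emb_pair(q, k, cos, sin, unsqueeze_dim=1):
    # Paired signature: expects BOTH q and k in a single call, so it cannot be
    # swapped in for the single-tensor helper without changing the call site.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Because the signatures differ, swapping one for the other at the decorator level cannot work as-is, which is why the upstream kernel would need to support the single-tensor case as well.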
Ohhhh, I see - this is not from this PR but would indeed need to be fixed ASAP! Opening another PR to fix it!
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values: Cache | None = None,
**kwargs: Unpack[TransformersKwargs],
Suggested change:
**kwargs: Unpack[TransformersKwargs],
past_key_values: Cache | None = None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
Just for BC (backwards compatibility), we cannot change the order.
It's just to make crystal clear that it's a required input. Since it's an internal module, this is fine IMO (and maybe better to break, as forgetting it would result in a silently wrong model).
Then let's add at least a 🚨 to the title.
IMO, I don't see why we couldn't default to None: if it weren't passed when it was needed, we would encounter a runtime error anyway, no?
Yes, indeed we could default it to None and let it crash if not passed... Let me change it then, if you think that's best!
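To make the trade-off concrete, here is a minimal sketch of the two options; the surrounding signature is simplified and purely illustrative, not the actual Gemma3n code. A required keyword fails loudly with a TypeError if forgotten, while defaulting to None preserves the existing keyword order and still fails fast on first use:

```python
from typing import Optional

import torch

# Option 1: required keyword - forgetting it raises a TypeError at call time.
# Note the argument order differs from the released signature (BC concern).
def forward_required(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
    past_key_values=None,
    **kwargs,
):
    ...

# Option 2: default to None, keeping the existing keyword order, and fail fast
# on first use so a forgotten argument never produces a silently wrong model.
def forward_defaulted(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    past_key_values=None,
    shared_kv_states: Optional[dict[int, tuple[torch.Tensor, torch.Tensor]]] = None,
    **kwargs,
):
    if shared_kv_states is None:
        raise ValueError("`shared_kv_states` must be provided for KV cache sharing")
    ...
```

Either way, a missing shared_kv_states surfaces as an error rather than a silently wrong model; the difference is only whether the keyword order stays backwards compatible.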
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma3n, gemma4
What does this PR do?
As per the title: brings the changes from #45312 and #45336 to gemma3n.