
Commit da50b41

Gemma capping is a must for big models (#31698)
* softcapping
* soft cap before the mask
* style
* ...
* super nit
1 parent 086c74e commit da50b41


2 files changed: +8 −0 lines


src/transformers/models/gemma2/configuration_gemma2.py

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
@@ -116,6 +117,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
@@ -135,6 +137,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping

         super().__init__(
             pad_token_id=pad_token_id,
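
For context, a minimal usage sketch (assuming the `Gemma2Config` import path in `transformers`; values are illustrative) showing how the new field can be set or disabled:

    from transformers import Gemma2Config

    # attn_logit_softcapping defaults to 50.0 per this diff; passing None skips
    # the tanh capping branch added in modeling_gemma2.py below.
    config = Gemma2Config(attn_logit_softcapping=50.0)
    print(config.attn_logit_softcapping)    # 50.0

    uncapped = Gemma2Config(attn_logit_softcapping=None)
    print(uncapped.attn_logit_softcapping)  # None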

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 5 additions & 0 deletions
@@ -256,6 +256,11 @@ def forward(

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
+
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
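
As a standalone illustration of what the added lines compute, here is a small sketch of tanh softcapping on a toy tensor (the scores and printed values are illustrative; the cap of 50.0 matches the new default):

    import torch

    def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
        # Rescale, squash with tanh, rescale back: output stays in (-cap, cap),
        # roughly linear for |scores| much smaller than cap.
        return cap * torch.tanh(scores / cap)

    scores = torch.tensor([1.0, 30.0, 200.0])
    print(soft_cap(scores, cap=50.0))  # ~[1.00, 26.85, 49.97]: large scores squashed toward the cap

The ordering in the diff, echoed by the commit message's "soft cap before the mask", plausibly matters because applying tanh after adding the large negative mask values would squash masked positions to only about -cap, weakening the masking; capping first and then adding the mask keeps masked positions strongly negative going into softmax.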
