@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
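Both softcapping parameters feed the same transform: the tensor is divided by the cap, passed through `tanh`, and scaled back up, which bounds the values in `(-cap, cap)` while staying near-linear around zero. A minimal sketch of that transform, assuming the usual `cap * tanh(x / cap)` form; the `soft_cap` helper name is illustrative, not part of the modeling code:

```python
import torch


def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash values into (-cap, cap); roughly identity for |x| << cap.
    return cap * torch.tanh(x / cap)


# Illustrative pre-softmax attention scores with large outliers.
scores = 100.0 * torch.randn(1, 8, 16, 16)
capped = soft_cap(scores, 50.0)  # attn_logit_softcapping default
assert capped.abs().max() < 50.0
```

`final_logit_softcapping` plays the same role on the language-model head output, with a default cap of 30.0.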
@@ -116,6 +117,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
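With the keyword now in the signature, the attention-score cap can be overridden at construction time like any other config field. A minimal usage sketch, assuming the class is exported as `transformers.Gemma2Config` and the installed version includes this change:

```python
from transformers import Gemma2Config

# Override the attention-score cap; other fields keep their defaults.
config = Gemma2Config(attn_logit_softcapping=40.0)
print(config.attn_logit_softcapping)   # 40.0
print(config.final_logit_softcapping)  # 30.0
```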
@@ -135,6 +137,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping
 
         super().__init__(
             pad_token_id=pad_token_id,
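Because the value is stored as a plain attribute before `super().__init__()`, it should round-trip through the standard `PretrainedConfig` serialization helpers. A small sanity-check sketch under that assumption:

```python
from transformers import Gemma2Config

config = Gemma2Config()
# to_dict() / from_dict() are the generic PretrainedConfig round-trip path.
restored = Gemma2Config.from_dict(config.to_dict())
assert restored.attn_logit_softcapping == 50.0
```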