Skip to content

Commit 417e437

Browse files
ArthurZucker authored and zucchini-nlp committed
use rms_norm_eps for the L2Norm for Llama4 (huggingface#37418)
use `rms_norm_eps`
1 parent 7d788f9 commit 417e437

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def forward(self, x):
110110

111111

112112
class Llama4TextL2Norm(torch.nn.Module):
113-
def __init__(self, dim: int = None, eps: float = 1e-6):
113+
def __init__(self, eps: float = 1e-6):
114114
super().__init__()
115115
self.eps = eps
116116

@@ -301,7 +301,7 @@ def __init__(self, config: Llama4TextConfig, layer_idx):
301301
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
302302
)
303303
if self.config.use_qk_norm and self.use_rope:
304-
self.qk_norm = Llama4TextL2Norm()
304+
self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
305305

306306
def forward(
307307
self,

0 commit comments

Comments
 (0)