Skip to content

Commit fc5f910

Browse files
authored
[Qwen3 Next] Use numerically stable rsqrt (#40848)
use numerically stable inverse
1 parent 96d3795 commit fc5f910

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/transformers/models/qwen3_next/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def torch_causal_conv1d_update(
435435

436436
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
437437
"""This function is intended to align with the l2norm implementation in the FLA library."""
438-
inv_norm = 1 / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
438+
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
439439
return x * inv_norm
440440

441441

src/transformers/models/qwen3_next/modular_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def torch_causal_conv1d_update(
271271

272272
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
273273
"""This function is intended to align with the l2norm implementation in the FLA library."""
274-
inv_norm = 1 / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
274+
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
275275
return x * inv_norm
276276

277277

0 commit comments

Comments
 (0)