5 changes: 3 additions & 2 deletions README.md
@@ -29,8 +29,8 @@ This repo aims at providing a collection of efficient Triton-based implementatio

## News

- **$\texttt{[2025-09]}$:** 🐻 Thrilled to announce that [GDN](fla/ops/gated_delta_rule) has been integrated into Qwen3-Next.
Check out [the PR](https://github.com/huggingface/transformers/pull/40771) and [their blog post](https://qwenlm.github.io/blog/qwen3_next/) for more infos!
- **$\texttt{[2025-09]}$:** 🌲 Add DeltaFormer implementation to `fla` ([paper](https://arxiv.org/abs/2505.19488v1)).
- **$\texttt{[2025-09]}$:** 🐻 Thrilled to announce that [GDN](fla/ops/gated_delta_rule) has been integrated into Qwen3-Next. Check out their [blog post](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) for more infos!
- **$\texttt{[2025-08]}$:** 🌲 Add Log-Linear Attention implementation to `fla` ([paper](https://arxiv.org/abs/2506.04761)).
- **$\texttt{[2025-08]}$:** 🎓 Add MoM implementation to `fla` ([paper](https://arxiv.org/abs/2502.13685)).
- **$\texttt{[2025-07]}$:** 🐳 Add MLA implementation to `fla` ([paper](https://arxiv.org/abs/2405.04434)).
@@ -86,6 +86,7 @@ Roughly sorted according to the timeline supported in `fla`. The recommended tra
| 2025 | | PaTH | [PaTH Attention: Position Encoding via Accumulating Householder Transformations](https://arxiv.org/abs/2505.16381) | | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/path_attn.py) |
| 2025 | | MoM | [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685) | [official](https://github.com/OpenSparseLLMs/MoM) | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/mom.py) |
| 2025 | | Log-Linear Attention | [Log-Linear Attention](https://arxiv.org/abs/2506.04761) | [official](https://github.com/HanGuo97/log-linear-attention) | [fla](https://github.com/fla-org/flash-linear-attention/tree/main/fla/ops/log_linear_attn) |
| 2025 | | DeltaFormer | [Understanding Transformer from the Perspective of Associative Memory](https://arxiv.org/abs/2505.19488v1) | | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/deltaformer.py) |

## Installation

4 changes: 4 additions & 0 deletions fla/__init__.py
@@ -6,6 +6,7 @@
    BasedLinearAttention,
    BitAttention,
    Comba,
    DeltaFormerAttention,
    DeltaNet,
    GatedDeltaNet,
    GatedDeltaProduct,
@@ -34,6 +35,8 @@
    BitNetModel,
    CombaForCausalLM,
    CombaModel,
    DeltaFormerForCausalLM,
    DeltaFormerModel,
    DeltaNetForCausalLM,
    DeltaNetModel,
    GatedDeltaNetForCausalLM,
@@ -83,6 +86,7 @@
    'BitAttention', 'BitNetForCausalLM', 'BitNetModel',
    'Comba', 'CombaForCausalLM', 'CombaModel',
    'DeltaNet', 'DeltaNetForCausalLM', 'DeltaNetModel',
    'DeltaFormerAttention', 'DeltaFormerForCausalLM', 'DeltaFormerModel',
    'GatedDeltaNet', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
    'GatedDeltaProduct', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
    'GatedLinearAttention', 'GLAForCausalLM', 'GLAModel',
2 changes: 2 additions & 0 deletions fla/layers/__init__.py
@@ -7,6 +7,7 @@
from .bitattn import BitAttention
from .comba import Comba
from .delta_net import DeltaNet
from .deltaformer import DeltaFormerAttention
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
@@ -60,4 +61,5 @@
    'RWKV6Attention',
    'RWKV7Attention',
    'SlidingWindowSharedKeyAttention',
    'DeltaFormerAttention',
]
153 changes: 153 additions & 0 deletions fla/layers/deltaformer.py
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm, RotaryEmbedding
from fla.ops.deltaformer import deltaformer_attn
from fla.ops.utils.index import prepare_lens_from_mask

if TYPE_CHECKING:
    from fla.models.utils import Cache

logger = logging.get_logger(__name__)


class DeltaFormerAttention(nn.Module):

    r"""
    The layer implementation for DeltaFormer,
    [Understanding Transformer from the Perspective of Associative Memory]
    (https://arxiv.org/pdf/2505.19488).

    Notes
        - DeltaFormer attention is implemented with Triton kernels in `fla.ops.deltaformer` and is tuned
          for typical head dimensions (e.g., 64/128). It currently supports fixed-length inputs.
        - For variable-length inputs (padding masks), the deltaformer computation falls back to using the
          fixed-length path, while the second stage (softmax attention over U) uses FlashAttention's
          varlen path when an attention mask is provided.
        - K/V grouping (GQA) is supported natively by FlashAttention via `num_kv_heads`.
        - Uses K-K similarity in deltaformer computation instead of Q-K similarity for better performance.

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        num_heads (int, Optional):
            The number of attention heads. Default: 32.
        num_kv_heads (int, Optional):
            The number of key/value heads for grouped-query attention. If None, equals `num_heads`.
            Default: None.
        qkv_bias (bool, Optional):
            Whether to use bias for Q/K/V projections. Default: False.
        qk_norm (bool, Optional):
            Whether to apply per-head RMSNorm to Q and K before attention. Default: False.
        rope_theta (float, Optional):
            The base frequency for rotary position embedding. Default: 10000.
        max_position_embeddings (int, Optional):
            The maximum position embeddings. Default: None.
        layer_idx (int, Optional):
            The index of the layer (used for cache compatibility). Default: None.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        rope_theta: float = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: int | None = None,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
Comment on lines +72 to +78

⚠️ Potential issue

Validate head divisibility and GQA grouping early.

Prevent silent shape mismatches when hidden_size % num_heads != 0 or num_heads % num_kv_heads != 0.

-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
-        self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.qkv_bias = qkv_bias

        self.qk_norm = qk_norm
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
Comment on lines +101 to +103

🛠️ Refactor suggestion

Remove unused parameters or implement functionality.

The static analyzer correctly identifies unused parameters. Either implement the missing functionality or remove these parameters to avoid misleading the API consumer.

For output_attentions:

-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    attentions = None
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
+    attentions = None
🧰 Tools
🪛 Ruff (0.12.2)

98-98: Unused method argument: output_attentions

(ARG002)


100-100: Unused method argument: kwargs

(ARG002)


    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
Comment on lines +96 to +104


critical

The forward method accepts past_key_values and use_cache arguments, which suggests that it should support incremental decoding for generation. However, these arguments are not used, and past_key_values is returned unmodified. Without implementing the key-value cache, autoregressive generation will be extremely inefficient as all previous tokens would need to be re-processed at every step. This is a critical feature for a model intended for causal language modeling, and its absence is confirmed by DeltaFormerConfig being added to GENERATION_UNSUPPORTED in the test suite. Please implement the caching mechanism to enable efficient generation.

        attentions = None
Comment on lines +96 to +105

⚠️ Potential issue

KV cache not implemented; forward signature is misleading.

past_key_values/use_cache are accepted and returned unchanged. This breaks efficient generation.

Apply a minimal guard now; implement caching next:

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if use_cache:
+            raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
@@
-        return o, attentions, past_key_values
+        return o, attentions, past_key_values

Also consider removing output_attentions until supported.

Also applies to: 214-214

🧰 Tools
🪛 Ruff (0.13.1)

110-110: Unused method argument: output_attentions

(ARG002)


111-111: Unused method argument: use_cache

(ARG002)


        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        beta = self.b_proj(hidden_states)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        cu_seqlens_kw = kwargs.get('cu_seqlens', None)
        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q_len + seqlen_offset

            if attention_mask is not None:
                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
                max_seqlen = q_len + max(seqlen_offset)

Comment on lines +129 to +132

⚠️ Potential issue

Bug: using Python max() on a tensor.

max(seqlen_offset) on a torch.Tensor will error; also ensure scalar for cache sizing.

Apply:

-            if attention_mask is not None:
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q_len + max(seqlen_offset)
+            if attention_mask is not None:
+                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+                max_seqlen = q_len + int(seqlen_offset.max().item())

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)

        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)

        o = deltaformer_attn(
            q=q,
            k=k,
            v=v,
            beta=beta,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens_kw
        )

        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values
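
A minimal usage sketch of the new `DeltaFormerAttention` layer, based only on the constructor defaults and forward signature shown above; the shapes are illustrative, and a CUDA device with bfloat16 weights is assumed since the underlying op is Triton-based:

```python
import torch

from fla.layers import DeltaFormerAttention

# Defaults shown in the diff above: hidden_size=2048, num_heads=32.
attn = DeltaFormerAttention(hidden_size=2048, num_heads=32).to('cuda', torch.bfloat16)

# [batch_size, seq_len, hidden_size] input, fixed-length (no padding mask).
x = torch.randn(2, 512, 2048, device='cuda', dtype=torch.bfloat16)

# forward returns (output, attentions, past_key_values); attentions stays None,
# and the cache is passed through unchanged (see the review comments above).
o, attentions, past_key_values = attn(x)
print(o.shape)  # torch.Size([2, 512, 2048])
```
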
2 changes: 2 additions & 0 deletions fla/models/__init__.py
@@ -4,6 +4,7 @@
from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
from fla.models.comba import CombaConfig, CombaForCausalLM, CombaModel
from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
from fla.models.deltaformer import DeltaFormerConfig, DeltaFormerForCausalLM, DeltaFormerModel
from fla.models.forgetting_transformer import (
    ForgettingTransformerConfig,
    ForgettingTransformerForCausalLM,
@@ -37,6 +38,7 @@
    'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
    'CombaConfig', 'CombaForCausalLM', 'CombaModel',
    'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
    'DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel',
    'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
    'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
    'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
12 changes: 12 additions & 0 deletions fla/models/deltaformer/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig
from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel

AutoConfig.register(DeltaFormerConfig.model_type, DeltaFormerConfig, exist_ok=True)
AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True)
AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True)

__all__ = ['DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel']
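
A short sketch of what the registration above enables: constructing the model through the `transformers` Auto classes. The hyperparameters below are illustrative only.

```python
from transformers import AutoModelForCausalLM

# Importing the package runs the AutoConfig/AutoModel registration above.
from fla.models import DeltaFormerConfig

config = DeltaFormerConfig(hidden_size=512, num_hidden_layers=2, num_heads=8, vocab_size=32000)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # DeltaFormerForCausalLM
```
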
107 changes: 107 additions & 0 deletions fla/models/deltaformer/configuration_deltaformer.py
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class DeltaFormerConfig(PretrainedConfig):
    model_type = 'deltaformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        attn_mode: str = "chunk",
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        rope_theta: float = 10000.,
        rope_max_position_embeddings: Optional[int] = None,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        fuse_linear_cross_entropy: bool = False,
        use_l2warp: bool = False,
        vocab_size: int = 32000,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.attn_mode = attn_mode
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.rope_theta = rope_theta
        self.rope_max_position_embeddings = rope_max_position_embeddings
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_linear_cross_entropy = fuse_linear_cross_entropy
        self.use_l2warp = use_l2warp
        self.vocab_size = vocab_size

        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states

        if fuse_cross_entropy and fuse_linear_cross_entropy:
            raise ValueError(
                "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time."
            )
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting."
            )

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
Comment on lines +89 to +91

⚠️ Potential issue

Bug: isinstance check against typing.Dict will raise or misbehave.

Use Mapping/dict instead of typing.Dict in isinstance. This can otherwise raise TypeError.

-from typing import Dict, Optional
+from typing import Dict, Optional
+from collections.abc import Mapping
@@
-            if not isinstance(attn, Dict):
+            if not isinstance(attn, Mapping):
                 raise ValueError("attn must be a dictionary")
🧰 Tools
🪛 Ruff (0.13.1)

91-91: Avoid specifying long messages outside the exception class

(TRY003)


            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
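
To illustrate the `attn` validation above, a hedged sketch of building a hybrid config; the layer indices and head count are made up, and the remaining keys are filled by the `attn.get(...)` defaults shown:

```python
from fla.models.deltaformer import DeltaFormerConfig

# 'layers' and 'num_heads' are required; other keys fall back to the defaults above.
config = DeltaFormerConfig(
    num_hidden_layers=12,
    attn={'layers': [3, 7, 11], 'num_heads': 8},
)
print(config.attn['num_kv_heads'], config.attn['window_size'])  # 8 None

# Omitting 'layers' triggers the ValueError raised above.
try:
    DeltaFormerConfig(attn={'num_heads': 8})
except ValueError as err:
    print(err)  # Layer indices must be provided to initialize hybrid attention layers
```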