Commit 195b74d

[DeltaFormer] Add Model (#585)
Authored by Nathancgy and yzhangcs

* [GDN] Deal with init on meta device
* Added DeltaFormer model
* Updated comment to clarify confusion
* Minor fix
* Added autotuning to fix OOM
* Added varlen and generation support
* Added RoPE, K-K similarity, GQA without KV-head repetition, and more specific names
* Updated cu_seqlens in the layers file
* Fixed shape issue for varlen
* Added testing ops for DeltaFormer (all passed)
* Supported varlen, RoPE, and K-K similarity for DeltaFormer
* Added `DeltaFormer` to README
* Minor fix
* [deltaformer] Rearranged to BTHD, more specific namings, one attention function integration

Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>

Parent: 62fd58d · Commit: 195b74d

File tree: 15 files changed, +1918 −2 lines

README.md
Lines changed: 3 additions & 2 deletions

@@ -29,8 +29,8 @@ This repo aims at providing a collection of efficient Triton-based implementatio
 
 ## News
 
-- **$\texttt{[2025-09]}$:** 🐻 Thrilled to announce that [GDN](fla/ops/gated_delta_rule) has been integrated into Qwen3-Next.
-  Check out [the PR](https://github.com/huggingface/transformers/pull/40771) and [their blog post](https://qwenlm.github.io/blog/qwen3_next/) for more infos!
+- **$\texttt{[2025-09]}$:** 🌲 Add DeltaFormer implementation to `fla` ([paper](https://arxiv.org/abs/2505.19488v1)).
+- **$\texttt{[2025-09]}$:** 🐻 Thrilled to announce that [GDN](fla/ops/gated_delta_rule) has been integrated into Qwen3-Next. Check out their [blog post](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) for more infos!
 - **$\texttt{[2025-08]}$:** 🌲 Add Log-Linear Attention implementation to `fla` ([paper](https://arxiv.org/abs/2506.04761)).
 - **$\texttt{[2025-08]}$:** 🎓 Add MoM implementation to `fla` ([paper](https://arxiv.org/abs/2502.13685)).
 - **$\texttt{[2025-07]}$:** 🐳 Add MLA implementation to `fla` ([paper](https://arxiv.org/abs/2405.04434)).

@@ -86,6 +86,7 @@ Roughly sorted according to the timeline supported in `fla`. The recommended tra
 | 2025 | | PaTH | [PaTH Attention: Position Encoding via Accumulating Householder Transformations](https://arxiv.org/abs/2505.16381) | | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/path_attn.py) |
 | 2025 | | MoM | [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685) | [official](https://github.com/OpenSparseLLMs/MoM) | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/mom.py) |
 | 2025 | | Log-Linear Attention | [Log-Linear Attention](https://arxiv.org/abs/2506.04761) | [official](https://github.com/HanGuo97/log-linear-attention) | [fla](https://github.com/fla-org/flash-linear-attention/tree/main/fla/ops/log_linear_attn) |
+| 2025 | | DeltaFormer | [Understanding Transformer from the Perspective of Associative Memory](https://arxiv.org/abs/2505.19488v1) | | [fla](https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/deltaformer.py) |
 
 ## Installation

fla/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -6,6 +6,7 @@
     BasedLinearAttention,
     BitAttention,
     Comba,
+    DeltaFormerAttention,
     DeltaNet,
     GatedDeltaNet,
     GatedDeltaProduct,

@@ -34,6 +35,8 @@
     BitNetModel,
     CombaForCausalLM,
     CombaModel,
+    DeltaFormerForCausalLM,
+    DeltaFormerModel,
     DeltaNetForCausalLM,
     DeltaNetModel,
     GatedDeltaNetForCausalLM,

@@ -83,6 +86,7 @@
     'BitAttention', 'BitNetForCausalLM', 'BitNetModel',
     'Comba', 'CombaForCausalLM', 'CombaModel',
     'DeltaNet', 'DeltaNetForCausalLM', 'DeltaNetModel',
+    'DeltaFormerAttention', 'DeltaFormerForCausalLM', 'DeltaFormerModel',
     'GatedDeltaNet', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
     'GatedDeltaProduct', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
     'GatedLinearAttention', 'GLAForCausalLM', 'GLAModel',

fla/layers/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 from .bitattn import BitAttention
 from .comba import Comba
 from .delta_net import DeltaNet
+from .deltaformer import DeltaFormerAttention
 from .forgetting_attn import ForgettingAttention
 from .gated_deltanet import GatedDeltaNet
 from .gated_deltaproduct import GatedDeltaProduct

@@ -60,4 +61,5 @@
     'RWKV6Attention',
     'RWKV7Attention',
     'SlidingWindowSharedKeyAttention',
+    'DeltaFormerAttention',
 ]

fla/layers/deltaformer.py (new file)
Lines changed: 153 additions & 0 deletions

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm, RotaryEmbedding
from fla.ops.deltaformer import deltaformer_attn
from fla.ops.utils.index import prepare_lens_from_mask

if TYPE_CHECKING:
    from fla.models.utils import Cache

logger = logging.get_logger(__name__)


class DeltaFormerAttention(nn.Module):

    r"""
    The layer implementation for DeltaFormer,
    [Understanding Transformer from the Perspective of Associative Memory](https://arxiv.org/pdf/2505.19488).

    Notes
        - DeltaFormer attention is implemented with Triton kernels in `fla.ops.deltaformer` and is tuned
          for typical head dimensions (e.g., 64/128). It currently supports fixed-length inputs.
        - For variable-length inputs (padding masks), the deltaformer computation falls back to using the
          fixed-length path, while the second stage (softmax attention over U) uses FlashAttention's
          varlen path when an attention mask is provided.
        - K/V grouping (GQA) is supported natively by FlashAttention via `num_kv_heads`.
        - Uses K-K similarity in deltaformer computation instead of Q-K similarity for better performance.

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        num_heads (int, Optional):
            The number of attention heads. Default: 32.
        num_kv_heads (int, Optional):
            The number of key/value heads for grouped-query attention. If None, equals `num_heads`.
            Default: None.
        qkv_bias (bool, Optional):
            Whether to use bias for Q/K/V projections. Default: False.
        qk_norm (bool, Optional):
            Whether to apply per-head RMSNorm to Q and K before attention. Default: False.
        rope_theta (float, Optional):
            The base frequency for rotary position embedding. Default: 10000.
        max_position_embeddings (int, Optional):
            The maximum position embeddings. Default: None.
        layer_idx (int, Optional):
            The index of the layer (used for cache compatibility). Default: None.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        rope_theta: float = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: int | None = None,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        attentions = None
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        beta = self.b_proj(hidden_states)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        cu_seqlens_kw = kwargs.get('cu_seqlens', None)
        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q_len + seqlen_offset

            if attention_mask is not None:
                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
                max_seqlen = q_len + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)

        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)

        o = deltaformer_attn(
            q=q,
            k=k,
            v=v,
            beta=beta,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens_kw
        )

        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values
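
For illustration, here is a minimal, hypothetical usage sketch of the layer above. The batch size, sequence length, dtype, and head counts are arbitrary choices, not values from the commit, and a CUDA GPU is assumed since `deltaformer_attn` is implemented in Triton.

# Hypothetical usage sketch of DeltaFormerAttention (illustrative values only).
import torch

from fla.layers import DeltaFormerAttention

layer = DeltaFormerAttention(hidden_size=2048, num_heads=32, num_kv_heads=8)
layer = layer.cuda().to(torch.bfloat16)

# hidden_states: [batch_size, seq_len, hidden_size]
x = torch.randn(2, 512, 2048, device='cuda', dtype=torch.bfloat16)

# Returns (output, attentions, past_key_values); attentions is always None here.
o, attentions, past_key_values = layer(x)
print(o.shape)  # torch.Size([2, 512, 2048])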

fla/models/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
 from fla.models.comba import CombaConfig, CombaForCausalLM, CombaModel
 from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
+from fla.models.deltaformer import DeltaFormerConfig, DeltaFormerForCausalLM, DeltaFormerModel
 from fla.models.forgetting_transformer import (
     ForgettingTransformerConfig,
     ForgettingTransformerForCausalLM,

@@ -37,6 +38,7 @@
     'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
     'CombaConfig', 'CombaForCausalLM', 'CombaModel',
     'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
+    'DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel',
     'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
     'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
     'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
fla/models/deltaformer/__init__.py (new file)
Lines changed: 12 additions & 0 deletions

# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig
from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel

AutoConfig.register(DeltaFormerConfig.model_type, DeltaFormerConfig, exist_ok=True)
AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True)
AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True)

__all__ = ['DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel']
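
Because the config and model classes are registered with the transformers Auto classes above, they can be resolved through the standard Auto API once `fla` (or this subpackage) has been imported. A minimal sketch under that assumption; the tiny config values are placeholders chosen only to keep the example cheap.

# Sketch: resolving the registered DeltaFormer classes via the HF Auto API.
import fla  # noqa: F401  # side effect: runs the register() calls above
from transformers import AutoConfig, AutoModelForCausalLM

# 'deltaformer' is the model_type registered with AutoConfig above.
config = AutoConfig.for_model('deltaformer', num_hidden_layers=2, hidden_size=256, num_heads=4)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # DeltaFormerForCausalLM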
fla/models/deltaformer/configuration_deltaformer.py (new file)
Lines changed: 107 additions & 0 deletions

# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class DeltaFormerConfig(PretrainedConfig):
    model_type = 'deltaformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        attn_mode: str = "chunk",
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        rope_theta: float = 10000.,
        rope_max_position_embeddings: Optional[int] = None,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        fuse_linear_cross_entropy: bool = False,
        use_l2warp: bool = False,
        vocab_size: int = 32000,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.attn_mode = attn_mode
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.rope_theta = rope_theta
        self.rope_max_position_embeddings = rope_max_position_embeddings
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_linear_cross_entropy = fuse_linear_cross_entropy
        self.use_l2warp = use_l2warp
        self.vocab_size = vocab_size

        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states

        if fuse_cross_entropy and fuse_linear_cross_entropy:
            raise ValueError(
                "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time."
            )
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting."
            )

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
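
As an illustration of how this config plugs into the model classes exported above, a small sketch with arbitrary (not recommended) hyperparameters; the chosen head size simply matches the dimensions the Triton kernels are described as being tuned for.

# Sketch: building a tiny DeltaFormer model directly from DeltaFormerConfig.
from fla.models import DeltaFormerConfig, DeltaFormerForCausalLM

config = DeltaFormerConfig(
    hidden_size=512,
    num_hidden_layers=4,
    num_heads=8,        # head_dim = 512 // 8 = 64
    num_kv_heads=2,     # grouped-query attention
    vocab_size=32000,
)
model = DeltaFormerForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy model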
