25 changes: 14 additions & 11 deletions ch04/08_deltanet/README.md
@@ -12,7 +12,7 @@ Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning for every three transfo

## Introduction and Overview

Gated DeltaNet is a linear attention variant inspired by recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.
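
Roughly speaking, "linear attention" here means attention without the softmax over all previous tokens. In a common (simplified) formulation, the outputs can then be computed from a running key-value state that is updated one token at a time,

$$S_t = S_{t-1} + k_t v_t^{\top}, \qquad o_t = S_t^{\top} q_t,$$

so each new token costs a constant amount of work instead of attending over the entire history. DeltaNet builds on this recurrent-state view, and the delta rule and gates discussed below refine how that state is updated.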

Kimi Linear replaces the linear attention mechanism of Qwen3-Next with the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear uses channel-wise gating, i.e., one gate value per feature dimension. According to the authors, this gives finer control over the memory, which in turn improves long-context reasoning.
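
To make the scalar-versus-channel-wise distinction concrete, here is a small, hypothetical shape sketch (it is not code from either model, and the actual gate parametrization is an exponential decay rather than plain random values, as shown later in this README):

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 8, 4, 64

# Scalar decay gate in the style of Qwen3-Next: one value per head (and token)
alpha_per_head = torch.rand(b, num_tokens, num_heads)

# Channel-wise decay gate in the style of Kimi Linear: one value per feature dimension
alpha_per_channel = torch.rand(b, num_tokens, num_heads, head_dim)

print(alpha_per_head.shape)     # torch.Size([2, 8, 4])
print(alpha_per_channel.shape)  # torch.Size([2, 8, 4, 64])
```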

@@ -96,7 +96,7 @@ class GatedMultiHeadAttention(nn.Module):
context = context.reshape(b, num_tokens, self.d_out)

####################################################
### NEW: Add gate
context = context * torch.sigmoid(gate)
####################################################
out = self.out_proj(context)
@@ -115,11 +115,11 @@ As we can see, after computing attention as usual, the model uses a separate gat

Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper as mentioned earlier.

Gated DeltaNet was originally proposed as an improved version of Mamba2; it combines the gated decay mechanism of Mamba2 with a delta rule.

Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.

The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).
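
In a simplified form (ignoring the gates for a moment), each step predicts a value from the current memory state $S_{t-1}$ and the key $k_t$, and then writes the prediction error back into the state:

$$\hat{v}_t = S_{t-1}^{\top} k_t, \qquad S_t = S_{t-1} + k_t \big(v_t - \hat{v}_t\big)^{\top}.$$

(This is only a rough sketch of the idea; the gated version used in the code below additionally scales the old state and the error term.)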

(Side note: Readers familiar with the classic machine learning literature can think of this as similar to Hebbian learning, which was inspired by biology: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)

@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also refers to two gates:
- `α` (decay gate) controls how fast the memory decays or resets over time,
- `β` (update gate) controls how strongly new inputs modify the state (a compact form of the resulting update is sketched right below this list).
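
Putting the two gates together with the delta rule, the per-token state update implemented in the loop further below can be written compactly as

$$S_t = \alpha_t\, S_{t-1} + \beta_t\, k_t \big(v_t - \alpha_t\, S_{t-1}^{\top} k_t\big)^{\top},$$

where $S_t$ is the memory state, $\alpha_t \in (0, 1]$ decays the old state, and $\beta_t \in (0, 1)$ scales how strongly the current token's prediction error is written into memory. (This scalar-gate notation is a simplified sketch, not the exact notation used in the papers.)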

(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
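
As a tiny, self-contained illustration of that relationship (hypothetical values, not part of the model code): the log-space gate is non-positive, so its exponential always lands in the (0, 1] range expected of a decay factor.

```python
import torch

gk = -torch.rand(3)   # stands in for the log-space gate (alpha_log below)
alpha = gk.exp()      # exponentiated decay factor, each value in (0, 1]
print(alpha)
```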

In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):

```python
import torch
@@ -161,7 +163,7 @@ class GatedDeltaNet(nn.Module):
### NEW: Gates for delta rule and output gating
self.W_gate = nn.Linear(d_in, d_out, bias=False)
self.W_beta = nn.Linear(d_in, d_out, bias=False)

# Note: The decay gate alpha corresponds to
# A_log + W_alpha(x) + dt_bias
self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
@@ -172,7 +174,7 @@ class GatedDeltaNet(nn.Module):
# W_alpha = nn.Linear(d_in, num_heads, bias=True)
# but the bias is separate for interpretability and
# to mimic the official implementation

self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
####################################################

@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):
####################################################
### NEW: Compute delta rule gates
beta = torch.sigmoid(self.W_beta(x))
alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
self.W_alpha(x) + self.dt_bias
)
alpha = alpha_log.exp()
gate = self.W_gate(x)
####################################################

@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

S = S * a_t
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -287,7 +290,7 @@ for t in range(num_tokens):
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

S = S * a_t
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -295,7 +298,7 @@ And the gates control how that memory changes:

- α (`alpha`) regulates how much of the old memory to forget (decay).

- β (`beta`) regulates how much the current token at time step *t* updates the memory.

(And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)
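
For readers who want to try the layer end to end, here is a hypothetical usage sketch. It assumes the `GatedDeltaNet` class from the snippet above is available in the same script and that its constructor takes `d_in`, `d_out`, and `num_heads` (with `d_out` divisible by `num_heads`), as the `__init__` excerpt suggests:

```python
import torch

torch.manual_seed(123)

batch_size, num_tokens, d_in = 2, 8, 64
x = torch.randn(batch_size, num_tokens, d_in)

gdn = GatedDeltaNet(d_in=d_in, d_out=64, num_heads=4)

with torch.no_grad():
    out = gdn(x)

# Assuming the layer returns one output vector per token:
print(out.shape)  # torch.Size([2, 8, 64])
```

Because the memory state is updated token by token, the cost grows linearly with sequence length, which is the main appeal of this layer over standard softmax attention for long contexts.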
