
Conversation

@Nathancgy
Contributor

@Nathancgy Nathancgy commented Sep 7, 2025

  • Introduces DeltaFormer layer and model.
  • Adds fused Triton ops for chunked DeltaFormer attention: fla/ops/deltaformer/parallel.py, invcum.py, and naive.py. The pre-attention computes the u vector, a substitute for v that can be passed directly to the FlashAttention function (see the sketch after this list).
  • Kernels 1 and 2 both use softmax in this version; the alpha scaling factor for v is omitted from the computation of u.
  • The current similarity is computed as dot products between k and k (as in the original paper).
  • Includes unit tests and integration updates for the new model.
  • Supports varlen.
  • Preserves backward compatibility. The optional beta defaults to all ones when not provided.
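
For readers new to DeltaFormer, here is a minimal, self-contained sketch of one plausible reading of that pre-attention: u_t = v_t - beta_t * sum_{i<t} softmax(k_t · k_i) u_i, with u then handed to FlashAttention in place of v. This is not the repository's implementation; the shapes, scale factor, and sequential loop are assumptions for illustration only.

import torch

def pre_attn_u_sketch(k, v, beta=None, scale=None):
    # k, v: [B, H, T, D]; beta: [B, H, T] or None (defaults to all ones, as in the PR).
    B, H, T, D = k.shape
    scale = D ** -0.5 if scale is None else scale
    if beta is None:
        beta = k.new_ones(B, H, T)
    u = torch.empty_like(v)
    u[..., 0, :] = v[..., 0, :]
    for t in range(1, T):
        # k-k similarity against strictly earlier positions, softmax-normalized per row.
        scores = torch.einsum('bhd,bhsd->bhs', k[..., t, :], k[..., :t, :]) * scale
        attn = scores.softmax(dim=-1)
        mix = torch.einsum('bhs,bhsd->bhd', attn, u[..., :t, :].clone())
        u[..., t, :] = v[..., t, :] - beta[..., t, None] * mix
    return u  # u stands in for v in the subsequent flash-attention call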

Summary by CodeRabbit

  • New Features

    • Adds the DeltaFormer model family (config, blocks, CausalLM) and a Triton-accelerated DeltaFormer attention operator with packed/variable-length support plus a naive reference and linear-algebra helpers.
  • Documentation

    • Adds/updates README news entry and models table for DeltaFormer; merges blog link into the announcement.
  • Tests

    • Adds unit tests for modeling, generation, and attention (including packed/variable-length cases).
  • Chores

    • Bumps package version to 0.3.2 and updates public exports.

@coderabbitai
Contributor

coderabbitai bot commented Sep 7, 2025

Caution

Review failed

The pull request is closed.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds DeltaFormer: new attention layer, Triton and naive ops, invcum helpers, full model/config and CausalLM wrapper with Transformers registration, package exports, README updates, and tests; package version bumped to 0.3.2.

Changes

  • Docs (README.md): Added 2025-09 DeltaFormer news entry and Models table row with links.
  • Top-level package exports & version (fla/__init__.py): Re-exported DeltaFormerAttention, DeltaFormerModel, DeltaFormerForCausalLM; bumped __version__ 0.3.1 → 0.3.2.
  • Layers: attention API (fla/layers/__init__.py, fla/layers/deltaformer.py): New DeltaFormerAttention layer (Q/K/V projections, optional RMSNorm, rotary embeddings, caching/varlen support) and export.
  • Models: registry and exports (fla/models/__init__.py, fla/models/deltaformer/__init__.py): Added DeltaFormerConfig, DeltaFormerModel, DeltaFormerForCausalLM; registered with Transformers AutoConfig/AutoModel/AutoModelForCausalLM.
  • Models: config (fla/models/deltaformer/configuration_deltaformer.py): New DeltaFormerConfig with many architecture, attention, fusion, and token options plus validation/warnings.
  • Models: implementation (fla/models/deltaformer/modeling_deltaformer.py): Implemented DeltaFormerBlock, DeltaFormerPreTrainedModel, DeltaFormerModel, and DeltaFormerForCausalLM (forward, generation, caching, fused loss hooks).
  • Ops: public API (fla/ops/deltaformer/__init__.py): Exported deltaformer_attn and naive_deltaformer_attn.
  • Ops: math kernels & helpers (fla/ops/deltaformer/parallel.py, fla/ops/deltaformer/invcum.py): Triton-accelerated deltaformer_attn (fixed/varlen, autograd, chunked forward/backward, flash-attn integration) and invcum linear-algebra forward/backward/in-place solves.
  • Ops: reference implementation (fla/ops/deltaformer/naive.py): Added naive_deltaformer_attn and tril_softmax reference implementations and exports.
  • Tests: models (tests/models/test_modeling_deltaformer.py, tests/models/test_modeling_utils.py): Added modeling/generation tests; added DeltaFormerConfig to the generation-unsupported list.
  • Tests: ops (tests/ops/test_deltaformer.py): Added correctness and gradient tests comparing the Triton kernel against the naive reference, including packed varlen paths and gradient checks.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant App as Application
    participant CausalLM as DeltaFormerForCausalLM
    participant Model as DeltaFormerModel
    participant Block as DeltaFormerBlock[n]
    participant Attn as DeltaFormerAttention
    participant Op as deltaformer_attn (Triton / naive)
    participant Inv as invcum
    participant FA as FlashAttention (optional)

    App->>CausalLM: forward(input_ids, attention_mask, ...)
    CausalLM->>Model: forward(..., use_cache)
    Model->>Block: for each layer: forward(hidden_states, mask, past_key_values)
    Block->>Attn: compute Q,K,V (+rotary)
    Attn->>Op: deltaformer_attn(q,k,v,beta, cu_seqlens?)
    alt training (requires grads)
        Op-->>Attn: u + saved tensors
    else inference (fast path)
        Op-->>Attn: u
    end
    opt triangular helpers
        Op->>Inv: invcum forward/backward
        Inv-->>Op: solves/results
    end
    Attn-->>Block: output projection
    Block-->>Model: updated hidden_states, PKV
    Model-->>CausalLM: logits / past_key_values
    CausalLM-->>App: logits / loss
sequenceDiagram
    autonumber
    participant Autograd as Autograd
    participant Op as deltaformer_attn
    participant Kern as Triton kernels
    participant Inv as invcum

    note over Op,Kern: Chunked backward (training)
    Autograd->>Op: backward(dL/du)
    Op->>Kern: backward_u_chunk / backward_qk (per-chunk)
    Kern-->>Op: partial dq, dk, dv, dbeta
    Op->>Inv: backward solves for triangular systems
    Inv-->>Op: du/dw pieces
    Op-->>Autograd: aggregated dq, dk, dv, dbeta

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120+ minutes

Possibly related PRs

Poem

I hop through tensors, kernels sing,
Q and K in spiraled ring.
Triton sparks and invcum threads,
Deltas dance across the spreads.
Tests nibble truth — the rabbit nods and grins. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title "[DeltaFormer] Add Model" is concise, a single short sentence, and directly summarizes the primary change in the PR (adding the DeltaFormer model and its associated layers/ops). It is specific enough for a teammate scanning history and does not use vague terms or file lists.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e6f09c3 and edb3093.

📒 Files selected for processing (1)
  • README.md (2 hunks)



@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @Nathancgy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the DeltaFormer model architecture into the library, providing a novel approach to attention mechanisms. It includes the core model components, optimized pre-attention kernels for performance, and full integration with the existing model ecosystem, enabling its use for causal language modeling tasks.

Highlights

  • New Model Introduction: Introduces the DeltaFormer layer and model, based on the paper "Understanding Transformer from the Perspective of Associative Memory" (https://arxiv.org/pdf/2505.19488).
  • Optimized Pre-Attention: Implements optimized, chunked pre-attention using fused Triton kernels for efficient computation, with dedicated Triton ops in fused_chunk.py, invcum.py, and naive.py.
  • Backward Compatibility: Ensures backward compatibility for the optional beta parameter, which now defaults to all ones when not explicitly provided.
  • Integration and Testing: Includes comprehensive unit tests and integrates the new model into the existing framework, updating relevant __init__.py files and model utility lists.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the DeltaFormer model, a significant addition to the library. The implementation includes the core attention layer, model configuration, and optimized Triton kernels for performance. The code is well-structured, and the inclusion of a naive reference implementation is a great practice for testing and validation.

However, there is a critical issue that needs to be addressed: the model currently does not support key-value caching for autoregressive generation. This is a fundamental feature for causal language models, and its absence severely limits the model's practical utility. I have also identified a couple of areas for code simplification and improved clarity. Please see the detailed comments for suggestions.

Comment on lines +93 to +101
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:


critical

The forward method accepts past_key_values and use_cache arguments, which suggests that it should support incremental decoding for generation. However, these arguments are not used, and past_key_values is returned unmodified. Without implementing the key-value cache, autoregressive generation will be extremely inefficient as all previous tokens would need to be re-processed at every step. This is a critical feature for a model intended for causal language modeling, and its absence is confirmed by DeltaFormerConfig being added to GENERATION_UNSUPPORTED in the test suite. Please implement the caching mechanism to enable efficient generation.
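
As a purely hypothetical illustration of the caching the reviewer is asking for, the snippet below shows how incremental K/V accumulation works with transformers' DynamicCache. It is not DeltaFormer code, the shapes are arbitrary, and DeltaFormer's pre-attention recurrence over u would additionally need its own cached state, which this sketch does not cover.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0
for step in range(3):
    # One new token per step: [batch, num_heads, 1, head_dim]
    k_new = torch.randn(1, 4, 1, 16)
    v_new = torch.randn(1, 4, 1, 16)
    # update() appends this step's K/V and returns the full cached history for the layer.
    k_all, v_all = cache.update(k_new, v_new, layer_idx)
    print(step, k_all.shape)  # cached sequence length grows: 1, 2, 3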

GENERATION_UNSUPPORTED = [
"ABCConfig", "LinearAttentionConfig", "LightNetConfig",
"Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config",
"DeltaFormerConfig",


critical

Adding DeltaFormerConfig to the GENERATION_UNSUPPORTED list means that generation tests are being skipped for this new model. This is likely due to the lack of key-value caching implementation in the attention layer, which I've commented on separately. For a ForCausalLM model, supporting efficient generation is a core requirement. This test should be enabled once the caching mechanism is implemented.

Comment on lines 124 to 131
if attention_mask is not None:
# Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
q_full = q
k_full = k
v_full = v
beta_full = beta
else:
q_full, k_full, v_full, beta_full = q, k, v, beta


high

This if/else block is redundant. Both the if and else branches perform the same assignments. You can simplify the code by removing the conditional block and keeping only the assignments.

        # Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
        q_full, k_full, v_full, beta_full = q, k, v, beta

Comment on lines 821 to 825
qi = k[:, i:i + C, :]
ki = q[:, i + C:, :]
lse = lses[:, i + C:]
beta_single = beta[:, i + C:]
du = backward_u_chunk(qi, ki, lse, grad_v[:, i + C:, :], fa_scale, beta_single)


medium

The variable names qi and ki are misleading. qi is initialized with a slice of k, and ki with a slice of q. This is counter-intuitive and can make the code difficult to understand and maintain. Please consider renaming them to something more descriptive, for example, k_chunk and q_context, to improve readability.

Suggested change
qi = k[:, i:i + C, :]
ki = q[:, i + C:, :]
lse = lses[:, i + C:]
beta_single = beta[:, i + C:]
du = backward_u_chunk(qi, ki, lse, grad_v[:, i + C:, :], fa_scale, beta_single)
k_chunk = k[:, i:i + C, :]
q_context = q[:, i + C:, :]
lse = lses[:, i + C:]
beta_single = beta[:, i + C:]
du = backward_u_chunk(k_chunk, q_context, lse, grad_v[:, i + C:, :], fa_scale, beta_single)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (22)
fla/ops/deltaformer/invcum.py (1)

16-18: Return the tensor and caution on in-place autograd.

forward_inplace mutates u without returning it. Returning u improves ergonomics and mirrors common PyTorch APIs. Also, if u.requires_grad and u is needed for backward, this write may error at runtime.

Apply:

-def forward_inplace(u, w):
-    u.copy_(forward(u, w))
+def forward_inplace(u, w):
+    u.copy_(forward(u, w))
+    return u
fla/ops/deltaformer/naive.py (3)

10-36: Safer masked softmax for all-masked rows.

Row 0 is fully masked when strict=True. Your subsequent masked_fill cleans NaNs, but we can avoid producing them altogether and cut one exp.

Apply:

-    masked = scores.masked_fill(~mask, float('-inf'))
-    max_per_row = masked.max(dim=-1, keepdim=True).values
-    exp = (masked - max_per_row).exp()
-    exp = exp.masked_fill(~mask, 0.0)
+    masked = scores.masked_fill(~mask, float('-inf'))
+    max_per_row = masked.max(dim=-1, keepdim=True).values
+    logits = (masked - max_per_row).masked_fill(~mask, 0.0)
+    exp = logits.exp()

This prevents NaN creation and removes one masked_fill.


59-66: Replace asserts with explicit validation (kept in optimized runs).

assert is stripped under -O. Raise exceptions instead.

Apply:

-    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,H,T,D]"
+    if not (q.dim() == 4 and k.dim() == 4 and v.dim() == 4):
+        raise ValueError("q,k,v must be [B,H,T,D]")
     B, H, T, D = q.shape
-    assert k.shape == (B, H, T, D) and v.shape == (B, H, T, D)
+    if k.shape != (B, H, T, D) or v.shape != (B, H, T, D):
+        raise ValueError(f"Expected k,v shape {(B,H,T,D)}, got {k.shape} and {v.shape}")
     if beta is None:
         beta = q.new_ones((B, H, T))
     else:
-        assert beta.shape == (B, H, T)
+        if beta.shape != (B, H, T):
+            raise ValueError(f"beta shape must be {(B,H,T)}, got {beta.shape}")

67-81: Reference path is clear; add a tiny correctness test to guard regressions.

Add a unit test that compares delta_pre_attn_naive vs fused delta_pre_attn on small tensors for fp32/fp16/bf16 within tolerance.

If you want, I can add tests/ops/test_delta_pre_attn_equiv.py with randomized seeds and tolerances.
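
A sketch of the kind of equivalence test being suggested; the function names follow the exports referenced in this review, while the shapes, tolerances, and CUDA requirement are assumptions.

import pytest
import torch

from fla.ops.deltaformer import delta_pre_attn, delta_pre_attn_naive

@pytest.mark.parametrize("dtype,tol", [(torch.float32, 1e-5), (torch.bfloat16, 2e-2)])
def test_delta_pre_attn_matches_naive(dtype, tol):
    torch.manual_seed(0)
    B, H, T, D = 2, 4, 128, 64
    q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=dtype) for _ in range(3))
    beta = torch.rand(B, H, T, device="cuda", dtype=dtype)
    ref = delta_pre_attn_naive(q, k, v, beta)
    out = delta_pre_attn(q, k, v, beta)
    assert torch.allclose(out.float(), ref.float(), atol=tol, rtol=tol)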

fla/models/deltaformer/configuration_deltaformer.py (4)

10-13: Annotate class vars to please linters.

Mark model_type and keys_to_ignore_at_inference as ClassVar[...].

Apply:

-from typing import Dict, Optional
+from typing import Dict, Optional, ClassVar
@@
-class DeltaFormerConfig(PretrainedConfig):
-    model_type = 'deltaformer'
-    keys_to_ignore_at_inference = ['past_key_values']
+class DeltaFormerConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'deltaformer'
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']

14-46: Add basic shape/hyperparameter validation to fail fast.

Avoid runtime crashes in attention by checking divisibility and GQA consistency.

Apply near the end of init (before super().init):

+        if hidden_size % num_heads != 0:
+            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+        kvh = num_kv_heads if num_kv_heads is not None else num_heads
+        if num_heads % kvh != 0:
+            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({kvh})")
+        if attn_mode not in {"chunk"}:
+            raise ValueError(f"Unsupported attn_mode: {attn_mode}")

25-29: Type of elementwise_affine should be bool.

Optional[bool] suggests None is meaningful, but you never use None. Simplify to bool.


75-89: Exception messages are fine; optional: shorten per linter TRY003.

Purely stylistic; feel free to ignore if you prefer clarity.

fla/ops/deltaformer/__init__.py (1)

3-9: Clean re-exports; consider exposing tril_softmax for testing.

Optional: re-export tril_softmax for test utilities, otherwise this is fine.

fla/__init__.py (1)

37-39: Top-level re-exports OK; consider also exposing DeltaFormerConfig.

Keeps parity with other models’ configs at top-level if that’s a convention you follow.

Apply:

-from fla.models import DeltaFormerForCausalLM, DeltaFormerModel
+from fla.models import DeltaFormerConfig, DeltaFormerForCausalLM, DeltaFormerModel
@@
-    'DeltaFormerForCausalLM', 'DeltaFormerModel',
+    'DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel',
tests/models/test_modeling_deltaformer.py (1)

40-58: Generation test is solid; add a varlen-specific assertion if feasible.

Given pre-attn currently ignores attention masks (fixed-length kernel), consider asserting tolerance is slightly looser when attention_mask is present, or add a small test that runs without mask to ensure exactness.

I can add an extra param case with attention_mask=None to validate the exact path side-by-side.

fla/layers/deltaformer.py (4)

80-82: Tighten the ImportError message (Ruff TRY003).

Shorten the message to satisfy the linter.

-        if flash_attn_func is None:
-            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+        if flash_attn_func is None:
+            raise ImportError("flash-attn is required; install via `pip install flash-attn --no-build-isolation`.")

117-120: KV grouping replication OK; consider safer API.

Optionally use repeat_interleave(dim=2) for clarity; current repeat is fine performance-wise.

-            k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
-            v = repeat(v, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
+            k = k.repeat_interleave(self.num_kv_groups, dim=2)
+            v = v.repeat_interleave(self.num_kv_groups, dim=2)

142-156: Varlen path: rename for clarity and guard availability.

Names are “unpadded”, not “padded”; also guard when varlen kernel is missing.

-        if attention_mask is not None:
-            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
-            cu_seqlens_q, cu_seqlens_k = cu_seqlens
-            max_seqlen_q, max_seqlen_k = max_seq_lens
-            o = flash_attn_varlen_func(
-                q_padded, k_padded, u_padded,
+        if attention_mask is not None:
+            if flash_attn_varlen_func is None:
+                raise ImportError("flash-attn varlen kernel not available; please upgrade flash-attn.")
+            q_unpadded, (k_unpadded, u_unpadded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
+            cu_seqlens_q, cu_seqlens_k = cu_seqlens
+            max_seqlen_q, max_seqlen_k = max_seq_lens
+            o = flash_attn_varlen_func(
+                q_unpadded, k_unpadded, u_unpadded,
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_q,
                 max_seqlen_k=max_seqlen_k,
                 causal=True,
                 window_size=(-1, -1)
             )
             o = pad_input(o, indices_q, batch_size, q_len)

93-101: Unused parameters (output_attentions, use_cache, kwargs).

If the interface must keep them, explicitly mark unused to silence linters.

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        _ = (output_attentions, use_cache, kwargs)
fla/ops/deltaformer/fused_chunk.py (5)

27-27: Consider using explicit error handling instead of assertions.

Assertions can be disabled with the -O flag in production, potentially leading to silent failures. Consider using explicit error checking for better production robustness.

-    assert B == _B and D == _D and B == __B and __C == C
+    if not (B == _B and D == _D and B == __B and __C == C):
+        raise ValueError(f"Shape mismatch: q={q.shape}, k={k.shape}, v={v.shape}, beta={beta.shape}")

232-234: Magic number should be a named constant.

The value -1e6 used for masking should be defined as a constant for better maintainability.

+MASK_VALUE = -1e6
+
 def flash_attn_kernel(
     ...
 ):
     ...
     if kv_i >= T - C:
         mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
-        qk = tl.where(mask, -1e6, qk)
+        qk = tl.where(mask, MASK_VALUE, qk)

271-271: Potential numerical instability when computing normalized attention.

The division acc / rowsum[:, None] could lead to division by zero if rowsum contains zeros. While the forward kernel initializes rowsum to 1 and only adds positive values from exp2, consider adding a small epsilon for numerical stability.

-    acc = acc / rowsum[:, None]
+    acc = acc / (rowsum[:, None] + 1e-12)

819-831: Backward pass could be optimized for memory usage.

The backward implementation creates multiple temporary tensors (du, grad_v) at each chunk iteration. Consider reusing buffers to reduce memory allocations.

Would you like me to propose a memory-optimized version of the backward pass that reuses buffers across chunk iterations?
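
To make the buffer-reuse idea concrete, here is a generic, self-contained sketch of the pattern: allocate the per-chunk scratch once and overwrite it in place each iteration. The shapes and the tanh stand-in are illustrative only and unrelated to the actual kernel math.

import torch

B, T, C, D = 2, 256, 64, 32
grad_v = torch.randn(B, T, D)
grad_k = torch.zeros(B, T, D)
du_buf = torch.empty(B, C, D)  # allocated once, reused for every chunk

for i in range(0, T, C):
    # Stand-in for the real per-chunk computation; `out=` writes into the reused buffer.
    torch.tanh(grad_v[:, i:i + C], out=du_buf)
    grad_k[:, i:i + C] += du_buf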


891-894: Redundant gradient checking for beta parameter.

The condition checks beta is not None and beta.requires_grad, but if beta is None, it gets replaced with ones tensor in _forward_impl. Consider simplifying the gradient check logic.

-    if k.requires_grad or q.requires_grad or v.requires_grad or (beta is not None and beta.requires_grad):
-        return _DeltaPreAttnFunction.apply(q, k, v, beta, C)
-    u, _, _ = _DeltaPreAttnFunction._forward_impl(q, k, v, beta, C, need_aux=False)
-    return u
+    # Check if any input requires gradient
+    requires_grad = k.requires_grad or q.requires_grad or v.requires_grad
+    if beta is not None:
+        requires_grad = requires_grad or beta.requires_grad
+    
+    if requires_grad:
+        return _DeltaPreAttnFunction.apply(q, k, v, beta, C)
+    else:
+        u, _, _ = _DeltaPreAttnFunction._forward_impl(q, k, v, beta, C, need_aux=False)
+        return u
fla/models/deltaformer/modeling_deltaformer.py (2)

106-108: Remove unused method parameters.

The parameters prenorm_residual_strategy and num_residuals_per_layer are not used in the method implementation.

 def _init_weights(
     self,
     module: nn.Module,
-    prenorm_residual_strategy: Optional[str] = None,
-    num_residuals_per_layer: int = 2,
 ):

152-157: Set proper stacklevel for warning.

The warning should include stacklevel=2 to show the correct location in user code.

-        warnings.warn(
-            "`DeltaFormerModel` does not support output attention weights now, "
-            "so `output_attentions` is set to `False`."
-        )
+        warnings.warn(
+            "`DeltaFormerModel` does not support output attention weights now, "
+            "so `output_attentions` is set to `False`.",
+            stacklevel=2
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e6a82e9 and 566b28b.

📒 Files selected for processing (13)
  • fla/__init__.py (2 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/deltaformer.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/deltaformer/__init__.py (1 hunks)
  • fla/models/deltaformer/configuration_deltaformer.py (1 hunks)
  • fla/models/deltaformer/modeling_deltaformer.py (1 hunks)
  • fla/ops/deltaformer/__init__.py (1 hunks)
  • fla/ops/deltaformer/fused_chunk.py (1 hunks)
  • fla/ops/deltaformer/invcum.py (1 hunks)
  • fla/ops/deltaformer/naive.py (1 hunks)
  • tests/models/test_modeling_deltaformer.py (1 hunks)
  • tests/models/test_modeling_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
tests/models/test_modeling_deltaformer.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (10-96)
tests/models/test_modeling_base.py (2)
  • run_test_generation (67-126)
  • run_test_model_forward_backward (27-61)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/fused_chunk.py (1)
  • delta_pre_attn (878-894)
fla/ops/deltaformer/naive.py (1)
  • delta_pre_attn_naive (39-81)
fla/models/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (10-96)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (208-349)
  • DeltaFormerModel (119-205)
fla/ops/deltaformer/invcum.py (1)
fla/ops/deltaformer/fused_chunk.py (3)
  • forward (16-31)
  • forward (784-798)
  • backward (801-834)
fla/__init__.py (1)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (208-349)
  • DeltaFormerModel (119-205)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (2)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/ops/deltaformer/fused_chunk.py (3)
  • delta_pre_attn (878-894)
  • forward (16-31)
  • forward (784-798)
fla/models/deltaformer/modeling_deltaformer.py (3)
  • forward (61-89)
  • forward (140-205)
  • forward (289-349)
fla/layers/__init__.py (1)
fla/layers/deltaformer.py (1)
  • DeltaFormerAttention (28-163)
fla/models/deltaformer/modeling_deltaformer.py (7)
fla/layers/deltaformer.py (2)
  • DeltaFormerAttention (28-163)
  • forward (93-163)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (10-96)
fla/modules/fused_cross_entropy.py (1)
  • FusedCrossEntropyLoss (344-419)
fla/modules/fused_linear_cross_entropy.py (1)
  • FusedLinearCrossEntropyLoss (493-567)
fla/modules/mlp.py (1)
  • GatedMLP (26-69)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/models/modeling_layers.py (1)
  • GradientCheckpointingLayer (11-71)
fla/models/deltaformer/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (10-96)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (208-349)
  • DeltaFormerModel (119-205)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
  • forward (93-163)
fla/ops/deltaformer/invcum.py (4)
  • forward (7-13)
  • backward (29-38)
  • backward_x (20-26)
  • forward_inplace (16-17)
🪛 Ruff (0.12.2)
fla/models/deltaformer/configuration_deltaformer.py

12-12: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


76-76: Avoid specifying long messages outside the exception class

(TRY003)


80-80: Avoid specifying long messages outside the exception class

(TRY003)


82-82: Avoid specifying long messages outside the exception class

(TRY003)


84-84: Avoid specifying long messages outside the exception class

(TRY003)

fla/ops/deltaformer/naive.py

59-59: Use of assert detected

(S101)


61-61: Use of assert detected

(S101)


65-65: Use of assert detected

(S101)

fla/layers/deltaformer.py

81-81: Avoid specifying long messages outside the exception class

(TRY003)


98-98: Unused method argument: output_attentions

(ARG002)


99-99: Unused method argument: use_cache

(ARG002)


100-100: Unused method argument: kwargs

(ARG002)


104-104: Use of assert detected

(S101)

fla/models/deltaformer/modeling_deltaformer.py

97-97: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


106-106: Unused method argument: prenorm_residual_strategy

(ARG002)


107-107: Unused method argument: num_residuals_per_layer

(ARG002)


143-143: Unused blanket noqa directive

Remove unused noqa directive

(RUF100)


153-153: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


164-164: Avoid specifying long messages outside the exception class

(TRY003)


166-166: Avoid specifying long messages outside the exception class

(TRY003)


210-210: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


246-252: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


246-252: Avoid specifying long messages outside the exception class

(TRY003)


254-254: Use raise without specifying exception name

Remove exception name

(TRY201)


265-265: Unused method argument: kwargs

(ARG002)


341-341: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


342-342: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

fla/ops/deltaformer/fused_chunk.py

27-27: Use of assert detected

(S101)


192-192: Unused function argument: B

(ARG001)


352-352: Unused function argument: B

(ARG001)


459-459: Unused function argument: B

(ARG001)


581-581: Unused function argument: B

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (14)
fla/ops/deltaformer/invcum.py (3)

7-13: Numerical intent is clear; please confirm unit-lower-triangular assumption for w.

unitriangular=True ignores the diagonal of w. Ensure upstream constructs w as unit-lower-triangular (or at least that its diagonal is not relied upon). If not guaranteed, results may silently diverge.
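
To make the flagged assumption concrete, a tiny standalone example (values chosen arbitrarily) showing that unitriangular=True makes the solver treat the diagonal as ones and never read it:

import torch

w = torch.tensor([[9., 0.],   # diagonal entries are ignored by the solve below
                  [2., 9.]])
b = torch.tensor([[1.], [3.]])
x = torch.linalg.solve_triangular(w, b, upper=False, unitriangular=True)
# Equivalent to solving [[1, 0], [2, 1]] x = b, so x == [[1.], [1.]]
print(x)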


20-27: Backprop solve uses strictly-upper assumption; verify ws construction.

w.tril(-1).mH with upper=True, unitriangular=True expects a strictly-upper unit-triangular system after conjugate-transpose. Please verify ws[i] is produced accordingly in fused_chunk; otherwise gradients w.r.t. RHS can be wrong.


29-38: Cast dw to the parameter’s dtype and zero its diagonal after the cast
Add .to(w.dtype) immediately after .tril(-1) to ensure mixed-precision doesn’t change the dtype, and confirm (e.g. via an assertion in fused_chunk) that dw.dtype == w.dtype and the diagonal remains zero.

fla/layers/__init__.py (2)

10-10: Public export looks good.


64-65: all update is correct.

fla/models/__init__.py (2)

7-7: DeltaFormer symbols wired correctly.


64-65: all includes DeltaFormer trio.

Good for downstream imports.

tests/models/test_modeling_deltaformer.py (1)

14-24: Param IDs aid debugging; good coverage across head dims.

fla/models/deltaformer/__init__.py (2)

12-12: Exports look good.


8-10: Wrap registrations in try/except for backward compatibility
Older transformers releases don’t support the exist_ok parameter on AutoConfig.register/AutoModel.register/AutoModelForCausalLM.register. Fallback to calling .register(...) without exist_ok inside a TypeError handler to avoid import-time failures.
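
A sketch of the suggested fallback, assuming the standard transformers Auto* registration API and the import layout implied by fla/models/deltaformer/__init__.py:

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig
from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel

try:
    AutoConfig.register('deltaformer', DeltaFormerConfig, exist_ok=True)
    AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True)
    AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True)
except TypeError:
    # Older transformers releases do not accept `exist_ok`.
    AutoConfig.register('deltaformer', DeltaFormerConfig)
    AutoModel.register(DeltaFormerConfig, DeltaFormerModel)
    AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM)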

tests/models/test_modeling_utils.py (2)

13-13: LGTM: gating varlen testing for DeltaFormer.

Consistent with current layer limitations.


27-27: LGTM: skip generation tests for DeltaFormer.

Matches the current generate constraints.

fla/models/deltaformer/modeling_deltaformer.py (2)

80-87: LGTM! Clean implementation of fused norm branching.

The conditional logic for fused vs non-fused normalization is well-structured, maintaining the correct residual connections in both paths.


333-333: Label shape aligns with hidden_states—no mismatch. The shifting logic mirrors other causal‐LM implementations in this repo (e.g., Rodimus, Transformer) and, with the default logits_to_keep = 0 (full sequence), both labels and hidden_states share the same sequence length.
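
For context, the shift being described is the standard causal-LM pattern; a generic, self-contained illustration (vocabulary size and shapes are arbitrary, not taken from this PR):

import torch
import torch.nn.functional as F

vocab = 100
logits = torch.randn(2, 16, vocab)           # [batch, seq, vocab]
labels = torch.randint(0, vocab, (2, 16))    # same sequence length as the hidden states
# Position t predicts token t+1, so drop the last logit and the first label.
shift_logits = logits[:, :-1].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))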

Comment on lines 20 to 24
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
flash_attn_func = None


🛠️ Refactor suggestion

Avoid potential NameError; mirror sentinel for varlen import.

If the import fails, flash_attn_varlen_func is undefined. Set it to None too.

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
-    flash_attn_func = None
+    flash_attn_func = None
+    flash_attn_varlen_func = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
flash_attn_func = None
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
flash_attn_func = None
flash_attn_varlen_func = None
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 20 to 24, the except block only sets
flash_attn_func to None leaving flash_attn_varlen_func undefined on ImportError;
update the except block to also assign flash_attn_varlen_func = None so both
symbols exist as sentinels when the import fails.

Comment on lines +70 to +76
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias

⚠️ Potential issue

Validate head divisibility and GQA grouping early.

Prevent silent shape mismatches when hidden_size % num_heads != 0 or num_heads % num_kv_heads != 0.

-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
-        self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.qkv_bias = qkv_bias
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
if self.hidden_size % self.num_heads != 0:
raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
if self.num_heads % self.num_kv_heads != 0:
raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 70 to 76, currently head sizes and GQA
grouping are computed without validation which can lead to silent shape
mismatches; add explicit checks before computing head_dim and kv_dim: verify
hidden_size % num_heads == 0 and raise a ValueError with a clear message if not,
and verify num_heads % num_kv_heads == 0 (after resolving default num_kv_heads)
and raise a ValueError if that fails; only after these validations compute
head_dim and kv_dim so failures occur early and with descriptive errors.

Comment on lines +246 to +254
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception

🛠️ Refactor suggestion

Improve exception chaining for better debugging.

When re-raising exceptions, use from to preserve the original exception context.

-            if 'past_key_values' in str(exception):
-                raise AttributeError(
-                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
-                    f"which is not supported for {self.__class__.__name__}. "
-                    f"Try another generation strategy instead. "
-                    f"For the available generation strategies, check this doc: "
-                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
-            else:
-                raise exception
+            if 'past_key_values' in str(exception):
+                raise AttributeError(
+                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                    f"which is not supported for {self.__class__.__name__}. "
+                    f"Try another generation strategy instead. "
+                    f"For the available generation strategies, check this doc: "
+                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                ) from exception
+            else:
+                raise
🧰 Tools
🪛 Ruff (0.12.2)

246-252: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


246-252: Avoid specifying long messages outside the exception class

(TRY003)


254-254: Use raise without specifying exception name

Remove exception name

(TRY201)

🤖 Prompt for AI Agents
In fla/models/deltaformer/modeling_deltaformer.py around lines 246 to 254, the
code currently does "else: raise exception" which loses proper exception
chaining; replace that line with a plain "raise" to re-raise the caught
exception preserving its traceback, or if you must raise a new exception, raise
it with "from exception" (e.g., raise NewError(...) from exception) so the
original context is preserved.

Comment on lines +323 to +331
if getattr(self, 'criterion', None) is None:
if self.config.fuse_linear_cross_entropy:
criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
elif self.config.fuse_cross_entropy:
criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
criterion = nn.CrossEntropyLoss()
else:
criterion = self.criterion

🛠️ Refactor suggestion

Criterion initialization should happen in init.

The criterion is lazily initialized in the forward pass, which could lead to unexpected behavior when the model is used in distributed settings or when the configuration changes after model creation.

 def __init__(self, config: DeltaFormerConfig):
     super().__init__(config)
     self.model = DeltaFormerModel(config)
     self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-    self.criterion = None
+    
+    # Initialize criterion based on config
+    if config.fuse_linear_cross_entropy:
+        self.criterion = FusedLinearCrossEntropyLoss(use_l2warp=config.use_l2warp)
+    elif config.fuse_cross_entropy:
+        self.criterion = FusedCrossEntropyLoss(inplace_backward=True)
+    else:
+        self.criterion = nn.CrossEntropyLoss()
     
     self.post_init()

And simplify the forward method:

-            if getattr(self, 'criterion', None) is None:
-                if self.config.fuse_linear_cross_entropy:
-                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
-                elif self.config.fuse_cross_entropy:
-                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
-                else:
-                    criterion = nn.CrossEntropyLoss()
-            else:
-                criterion = self.criterion
+            criterion = self.criterion
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if getattr(self, 'criterion', None) is None:
if self.config.fuse_linear_cross_entropy:
criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
elif self.config.fuse_cross_entropy:
criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
criterion = nn.CrossEntropyLoss()
else:
criterion = self.criterion
class DeltaFormerForCausalLM(nn.Module):
def __init__(self, config: DeltaFormerConfig):
super().__init__(config)
self.model = DeltaFormerModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize criterion based on config
if config.fuse_linear_cross_entropy:
self.criterion = FusedLinearCrossEntropyLoss(use_l2warp=config.use_l2warp)
elif config.fuse_cross_entropy:
self.criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
self.criterion = nn.CrossEntropyLoss()
self.post_init()
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
# ... (other forward logic)
# Use the pre-initialized criterion
criterion = self.criterion
# Compute logits, loss, etc.
outputs = self.lm_head(self.model(...))
loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
# ... (rest of forward)
return loss, outputs

@yzhangcs
Member

yzhangcs commented Sep 7, 2025

@Nathancgy Hi, thanks for the PR.
Some issues before merging:

  1. Try running python -m benchmarks.benchmark_training_throughput --name deltaformer --batch_size 1 --seq_len 4096, it appears that the default block size is too big.
  2. All tests in varlen training and generation for deltaformer are skipped, are you planning to add supports for them in this PR?

@Nathancgy
Contributor Author

@yzhangcs Hi, thanks for helping with the minor fixes!

  1. I have just added autotune for all the kernels, and the issue is now fixed. The test (python -m benchmarks.benchmark_training_throughput --name deltaformer --batch_size 1 --seq_len 4096) on an H100 gives a throughput of 9882.56 tokens/s.
  2. I will try adding them tonight. If they're not in by tomorrow morning, I'll probably have to add them next weekend, as I have an upcoming test week at school. Will try to get it done by EOD.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (6)
fla/layers/deltaformer.py (5)

20-24: Fix incomplete flash_attn import handling.

The missing import sentinel creates a potential NameError for flash_attn_varlen_func when the import fails.

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     flash_attn_func = None
+    flash_attn_varlen_func = None

70-78: Add validation for head dimension compatibility.

Without validation, incompatible head dimensions could cause silent shape errors downstream.

 self.hidden_size = hidden_size
 self.num_heads = num_heads
 self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+if self.hidden_size % self.num_heads != 0:
+    raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})")
+if self.num_heads % self.num_kv_heads != 0:
+    raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})")
 self.num_kv_groups = num_heads // self.num_kv_heads

103-108: Replace assertion with explicit exception for runtime validation.

Using assertions for runtime validation can be stripped with -O optimization.

 if attention_mask is not None:
-    assert len(attention_mask.shape) == 2, (
-        "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-        "for padding purposes (0 indicating padding). "
-        "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-    )
+    if attention_mask.dim() != 2:
+        raise ValueError(
+            "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; 2D pairwise masks are unsupported."
+        )

124-197: Critical: Missing key-value cache implementation for generation.

The forward method accepts caching parameters but doesn't implement proper incremental decoding. This makes autoregressive generation extremely inefficient as all previous tokens need re-processing at each step. The complex caching logic in lines 147-196 shows an attempt at implementation, but without proper cache state management, generation will fail.

# The current implementation attempts caching but lacks proper state management.
# A complete cache implementation would require:
# 1. Proper cache initialization and state tracking
# 2. Efficient key-value updates for incremental decoding
# 3. Handling of attention mask updates during generation

This aligns with DeltaFormerConfig being added to GENERATION_UNSUPPORTED in the test suite.


128-143: Pre-attention doesn't support masked computation - validate mask has no padding.

The pre-attention operation ignores attention_mask, allowing padded tokens to leak into computations and corrupt results.

 if attention_mask is not None:
+    if not torch.all(attention_mask):
+        raise NotImplementedError(
+            "DeltaFormer pre-attention does not support variable-length sequences; "
+            "use sequences without padding or implement masked pre-attention."
+        )
     _, cu_seqlens_k, _ = get_unpad_data(attention_mask)
fla/ops/deltaformer/fused_chunk.py (1)

948-962: Confusing variable names in backward implementation.

The variables qi and ki are misleadingly named - qi contains a slice of k while ki contains a slice of q. This makes the code difficult to understand and maintain.

-                qi = k[:, i:i + cb, :]
-                ki = q[:, i + cb:, :]
+                k_chunk = k[:, i:i + cb, :]
+                q_context = q[:, i + cb:, :]
                 lse_slice = lses[:, i + cb:]
                 beta_single = beta[:, i + cb:]
                 du = backward_u_chunk(
-                    qi,
-                    ki,
+                    k_chunk,
+                    q_context,
                     lse_slice,
                     grad_v[:, i + cb:, :],
                     fa_scale,
                     beta_single,
                     cu_seqlens=cu_flat,
                     row_start=i,
                 )
🧹 Nitpick comments (3)
fla/layers/deltaformer.py (2)

81-81: Consider extracting the error message to a constant.

The static analyzer suggests avoiding long messages inline. While this is a minor style issue, extracting improves maintainability.

+_FLASH_ATTN_INSTALL_MSG = "Please install Flash Attention via `pip install flash-attn --no-build-isolation` first"
+
 if flash_attn_func is None:
-    raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+    raise ImportError(_FLASH_ATTN_INSTALL_MSG)

86-86: Consider renaming b_proj to beta_proj for clarity.

The variable name b_proj is not immediately clear compared to other projection layers. Since the layer produces beta weights, beta_proj would be more descriptive.

-self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
+self.beta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

And update Line 115:

-beta = rearrange(self.b_proj(hidden_states), 'b t h -> b h t')
+beta = rearrange(self.beta_proj(hidden_states), 'b t h -> b h t')
fla/ops/deltaformer/fused_chunk.py (1)

26-33: Add input tensor validation for shape compatibility.

The current shape assertions use raw assert statements. Consider adding more descriptive validation that would help debugging in production.

 def forward(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     u: torch.Tensor,
     qk_scale: float,
     beta: torch.Tensor,
     row_start: int,
     cu_seqlens: Optional[torch.Tensor] = None,
 ):
     B, C, D = q.size()
     _B, T, _D = k.size()
     __B, __C = beta.size()
-    assert B == _B and D == _D and B == __B and __C == C
+    if not (B == _B == __B and D == _D and __C == C):
+        raise ValueError(
+            f"Shape mismatch: q{q.shape}, k{k.shape}, beta{beta.shape} must have compatible batch/head/feature dims"
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c5fceb2 and c17dfec.

📒 Files selected for processing (2)
  • fla/layers/deltaformer.py (1 hunks)
  • fla/ops/deltaformer/fused_chunk.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
  • forward (93-223)
fla/ops/deltaformer/invcum.py (4)
  • forward (7-13)
  • backward (29-38)
  • backward_x (20-26)
  • forward_inplace (16-17)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (3)
  • get_unpad_data (75-98)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/ops/deltaformer/fused_chunk.py (3)
  • delta_pre_attn (1043-1060)
  • forward (16-33)
  • forward (892-908)
fla/models/deltaformer/modeling_deltaformer.py (3)
  • forward (61-89)
  • forward (140-205)
  • forward (289-349)
🪛 Ruff (0.12.2)
fla/ops/deltaformer/fused_chunk.py

29-29: Use of assert detected

(S101)


217-217: Unused function argument: B

(ARG001)


406-406: Unused function argument: B

(ARG001)


536-536: Unused function argument: B

(ARG001)


680-680: Unused function argument: B

(ARG001)

fla/layers/deltaformer.py

81-81: Avoid specifying long messages outside the exception class

(TRY003)


98-98: Unused method argument: output_attentions

(ARG002)


100-100: Unused method argument: kwargs

(ARG002)


104-104: Use of assert detected

(S101)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (3)
fla/ops/deltaformer/fused_chunk.py (3)

1043-1061: Well-structured public API with proper gradient routing.

The main entry point correctly routes between gradient-aware and gradient-free execution paths, and provides clear documentation. The function properly handles optional parameters and maintains backward compatibility.


890-988: Comprehensive autograd implementation with proper gradient flow.

The _DeltaPreAttnFunction class provides a complete PyTorch autograd implementation with:

  • Proper forward pass with auxiliary data saving for backward
  • Chunked backward pass that correctly accumulates gradients
  • Proper handling of optional beta parameter
  • Support for both fixed and variable-length sequences

The implementation correctly manages tensor flattening/reshaping and uses the inverse cumulative solver for gradient computation.
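
For readers unfamiliar with the mechanism being described, here is a minimal, generic skeleton of a custom autograd.Function that saves tensors in forward and returns one gradient per input in backward. It is a placeholder illustrating the pattern only, not the repository's _DeltaPreAttnFunction.

import torch

class PreAttnLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
        u = v.clone()  # placeholder for the real chunked forward producing u
        ctx.save_for_backward(q, k, v, beta, u)
        return u

    @staticmethod
    def backward(ctx, grad_u):
        q, k, v, beta, u = ctx.saved_tensors
        # Placeholder gradients: the real backward walks the sequence in chunks,
        # accumulating dq, dk, dv, dbeta with the triangular solves from invcum.
        return torch.zeros_like(q), torch.zeros_like(k), grad_u.clone(), torch.zeros_like(beta)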


188-226: Proper autotune configuration for kernel optimization.

The autotune decorator with multiple block size configurations will help optimize performance across different input sizes. The heuristics correctly detect variable-length sequences.
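
A generic sketch of the decorator stack being described; the block sizes, the autotune key, and the IS_VARLEN heuristic below are illustrative values, not the kernel's actual configuration.

import triton
import triton.language as tl

@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 64}, num_warps=4),
        triton.Config({'BLOCK_T': 128}, num_warps=8),
    ],
    key=['T', 'D'],
)
@triton.jit
def pre_attn_kernel(
    q, k, v, u, cu_seqlens,
    T, D,
    BLOCK_T: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    pid = tl.program_id(0)
    # Kernel body elided: autotune benchmarks each config per (T, D) key and caches the winner.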

Comment on lines +98 to +100
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,

🛠️ Refactor suggestion

Remove unused parameters or implement functionality.

The static analyzer correctly identifies unused parameters. Either implement the missing functionality or remove these parameters to avoid misleading the API consumer.

For output_attentions:

-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    attentions = None
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
+    attentions = None
🧰 Tools
🪛 Ruff (0.12.2)

98-98: Unused method argument: output_attentions

(ARG002)


100-100: Unused method argument: kwargs

(ARG002)

🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 98 to 100, the parameters
output_attentions and use_cache are declared but not used; remove them from the
function/method signature if the layer does not support these features, or
propagate them to the underlying model calls and adjust return values to include
attentions/cache outputs when requested. Specifically, either (A) delete
output_attentions and use_cache from the signature and update all callers, or
(B) thread output_attentions and use_cache into the internal forward/transformer
call, handle the returned attention tensors and cache (modify return type to
include attentions if output_attentions is True and preserve cache behavior when
use_cache is True), and add tests covering both behaviors.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

🧹 Nitpick comments (27)
fla/layers/deltaformer.py (4)

157-180: Deduplicate identical pre-attn calls; drive a single code path.

All three branches make the same call and differ only in whether cu_seqlens is passed; collapse them into a single call.

Apply:

-        if attention_mask is not None:
-            u = delta_pre_attn(
-                # kk similarity
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(v_full, 'b t h d -> b h t d'),
-                beta_full,
-                cu_seqlens=cu_seqlens,
-            )
-        elif cu_seqlens is not None:
-            u = delta_pre_attn(
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(v_full, 'b t h d -> b h t d'),
-                beta_full,
-                cu_seqlens=cu_seqlens,
-            )
-        else:
-            u = delta_pre_attn(
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(k_full, 'b t h d -> b h t d'),
-                rearrange(v_full, 'b t h d -> b h t d'),
-                beta_full,
-            )
+        u = delta_pre_attn(
+            # kk similarity
+            rearrange(k_full, 'b t h d -> b h t d'),
+            rearrange(k_full, 'b t h d -> b h t d'),
+            rearrange(v_full, 'b t h d -> b h t d'),
+            beta_full,
+            cu_seqlens=cu_seqlens,
+        )

36-44: Docstring is out of date with implementation.

  • Code uses varlen pre-attn (via cu_seqlens) but the note claims fixed-length only.
  • It also states K–K similarity; PR summary mentioned Q–K. Clarify here.

Proposed text:

  • “Pre-attention supports both fixed-length and varlen via cu_seqlens; padded positions are ignored.”
  • “Pre-attention currently uses K–K similarity; switching to Q–K or K–W requires changing inputs.”

148-156: Remove commented/unused variables.

q_full comments and duplicated assignments add noise.


26-26: Unused logger.

Remove it, or use it for warnings (e.g., when a fallback path is taken).

README.md (2)

32-32: Minor: list style nit (markdownlint MD004).

Use consistent bullet markers if you enforce markdownlint; otherwise ignore.


88-89: Link and claim sanity-check.

DeltaFormer is listed as available; generation in tests is currently marked unsupported. Add a note that KV-cache generation is WIP to avoid confusion.

tests/models/test_modeling_deltaformer.py (1)

49-58: This generation test will always skip.

Because DeltaFormerConfig is in GENERATION_UNSUPPORTED, run_test_generation skips. Either remove this test for now or mark with a TODO/xfailed reason tied to cache support to avoid confusion.

fla/ops/deltaformer/invcum.py (2)

7-13: Confirm W semantics (unit-lower triangular A = I + tril(w)).

This relies on unitriangular=True and w carrying strictly-lower entries of A (diag implicitly 1). Please confirm upstream guarantees that the diagonal of w is ignored and its strictly-upper part never contains NaNs/garbage. Consider documenting this contract here.
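A quick self-contained check of that contract with torch.linalg.solve_triangular (shapes and values arbitrary):

import torch

T = 4
strict_lower = torch.randn(1, T, T).tril(-1)
u = torch.randn(1, T, 3)

# explicit A = I + tril(w, -1)
x_ref = torch.linalg.solve_triangular(strict_lower + torch.eye(T), u, upper=False)
# same solve, but hand the solver a w whose diagonal holds garbage
w_dirty = strict_lower + torch.diag(torch.full((T,), 123.0))
x = torch.linalg.solve_triangular(w_dirty, u, upper=False, unitriangular=True)
print(torch.allclose(x_ref, x, atol=1e-6))  # True: the diagonal of w is never read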


29-38: Cast dw to w.dtype for consistency.

du follows do.dtype, and dw currently inherits that dtype as well; if w uses a different dtype, downstream gradients may end up mixed-precision.

-    dw = torch.bmm(-du, x.mH)
-    dw = dw.tril(-1)
+    dw = torch.bmm(-du, x.mH).to(w.dtype)
+    dw = dw.tril(-1)
tests/ops/test_deltaformer.py (3)

11-23: Add a beta=None case.

Core API promises beta defaults to ones; please add a param set exercising beta=None to guard this pathway.

     [
         pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test))
         for test in [
             (2, 128, 2, 64, torch.float32),
             # Test with bfloat16
             (1, 256, 4, 64, torch.bfloat16),
             (2, 512, 4, 64, torch.bfloat16),
             (4, 1024, 4, 128, torch.bfloat16)
         ]
     ]
 )
+@pytest.mark.parametrize("beta_is_none", [False, True])
@@
-def test_delta_pre_attn(
+def test_delta_pre_attn(
     B: int,
     T: int,
     H: int,
     D: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    beta_is_none: bool
 ):
@@
-    beta = torch.randn((B, H, T), dtype=dtype, device=device).sigmoid().requires_grad_(True)
+    beta = None if beta_is_none else torch.randn((B, H, T), dtype=dtype, device=device).sigmoid().requires_grad_(True)

61-66: Tolerances look tight; keep CI-friendly guardrails.

If CI noise appears on different GPUs, consider warning=True or slightly higher ratios for bf16 paths only.


67-75: Varlen pack: verify memory layout explicitly.

Good approach. Add contiguous() on concatenation results to avoid stride surprises under Dynamo/AOTAutograd.

-        q_packed = torch.cat([q[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+        q_packed = torch.cat([q[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-        k_packed = torch.cat([k[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+        k_packed = torch.cat([k[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-        v_packed = torch.cat([v[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+        v_packed = torch.cat([v[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-        beta_packed = torch.cat([beta[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+        beta_packed = torch.cat([beta[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
fla/ops/deltaformer/naive.py (2)

10-36: Numerical stability nit: prefer torch.where(mask, -inf, scores) without materializing ~mask.

Minor, but reduces one extra boolean tensor in tight loops; harmless to keep as-is.


76-87: Naive ref: good reference; add a quick early-exit for T==0.

-    u_list = []
+    if T == 0:
+        return vf.new_zeros(B, H, 0, D).to(orig_dtype)
+    u_list = []
fla/ops/deltaformer/fused_chunk.py (4)

971-1011: Do beta row scaling inside the kernel to save a write and keep w reusable.

Currently w = w * betai.unsqueeze(-1) allocates and writes once per chunk. Consider moving the row scaling into flash_attn_kernel at the point where w is written, or emitting a separate kernel that writes the scaled w directly. Optional perf nit.

Also applies to: 1042-1070


127-160: Forward wrapper: add an assertion for C<=T in debug builds.

Even though callers clamp C, a defensive check helps.

     B, C, D = q.size()
     _B, T, _D = k.size()
@@
-    w = torch.empty(B, C, C, device=q.device, dtype=q.dtype)
+    assert C <= T, "C (chunk) must be <= T (context)."
+    w = torch.empty(B, C, C, device=q.device, dtype=q.dtype)

162-170: Autotune configs: consider adding 16-warp configs for large D on H100.

Optional; may improve the reported throughput further on H100. Keep as a follow-up if benchmarks justify it.

Also applies to: 329-336, 442-449, 559-566


604-604: Ruff hints: silence unused parameters to keep CI green.

Prefix with _ or use them; applies to several kernel args and unpacked temps.

Examples:

  • Line 25: _T
  • Line 198/364/477/604: _B
  • Line 842: _T_max

Also applies to: 477-477, 364-364, 198-198, 25-25, 842-842

fla/models/deltaformer/configuration_deltaformer.py (2)

12-14: Annotate class attributes with ClassVar to satisfy Ruff and intent.

model_type and keys_to_ignore_at_inference are class-level constants. Annotate with ClassVar[...].

Apply this diff in place:

-from transformers.configuration_utils import PretrainedConfig
+from transformers.configuration_utils import PretrainedConfig
+from typing import ClassVar
@@
-class DeltaFormerConfig(PretrainedConfig):
-    model_type = 'deltaformer'
-    keys_to_ignore_at_inference = ['past_key_values']
+class DeltaFormerConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'deltaformer'
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']

78-87: Add stacklevel to warnings and fix grammar.

Use stacklevel=2 so the warning points to the caller; fix “can improves” → “can improve”.

-        if fuse_linear_cross_entropy:
-            warnings.warn(
-                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
-                "at the potential cost of reduced precision. "
-                "If you observe issues like loss divergence, consider disabling this setting."
-            )
+        if fuse_linear_cross_entropy:
+            warnings.warn(
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
+                "at the potential cost of reduced precision. "
+                "If you observe issues like loss divergence, consider disabling this setting.",
+                stacklevel=2,
+            )
fla/models/deltaformer/modeling_deltaformer.py (7)

71-91: Fix return type annotation of DeltaFormerBlock.forward.

Function returns (hidden_states, attentions, past_key_values) but the annotation declares only two items.

-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Union[Cache, List[torch.FloatTensor]]],
+    ]:

96-101: Annotate class-level mutables with ClassVar.

Satisfy Ruff RUF012 for _no_split_modules and _supports_cache_class.

-from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.deprecation import deprecate_kwarg
+from typing import ClassVar
@@
-    _no_split_modules = ['DeltaFormerBlock']
-    _supports_cache_class = True
+    _no_split_modules: ClassVar[list[str]] = ['DeltaFormerBlock']
+    _supports_cache_class: ClassVar[bool] = True

105-120: Unused params in _init_weights.

prenorm_residual_strategy and num_residuals_per_layer aren’t used.

-    def _init_weights(
-        self,
-        module: nn.Module,
-        prenorm_residual_strategy: Optional[str] = None,
-        num_residuals_per_layer: int = 2,
-    ):
+    def _init_weights(self, module: nn.Module, *_: object) -> None:

142-161: Add stacklevel to warning; drop unused noqa.

Set stacklevel=2 for actionable warnings. Remove the stray # noqa.

-        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        attention_mask: Optional[torch.Tensor] = None,
@@
-            warnings.warn(
+            warnings.warn(
                 "`DeltaFormerModel` does not support output attention weights now, "
                 "so `output_attentions` is set to `False`."
-            )
+            , stacklevel=2)

291-304: Default logits_to_keep to 1 to avoid slicing corner case.

With 0, x[:, -0:] == x[:, 0:] (full sequence). Use 1 for typical LM training/inference.

-        logits_to_keep: Optional[int] = 0,
+        logits_to_keep: Optional[int] = 1,

343-351: Prefer starred unpack over tuple concatenation.

Cleaner and satisfies Ruff RUF005.

-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
+            output = (logits, *outputs[1:])
+            return ((loss, *output) if loss is not None else output)

212-213: Annotate _tied_weights_keys as ClassVar.

-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys: ClassVar[list[str]] = ["lm_head.weight"]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a340ea and dfb54ec.

📒 Files selected for processing (15)
  • README.md (2 hunks)
  • fla/__init__.py (4 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/deltaformer.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/deltaformer/__init__.py (1 hunks)
  • fla/models/deltaformer/configuration_deltaformer.py (1 hunks)
  • fla/models/deltaformer/modeling_deltaformer.py (1 hunks)
  • fla/ops/deltaformer/__init__.py (1 hunks)
  • fla/ops/deltaformer/fused_chunk.py (1 hunks)
  • fla/ops/deltaformer/invcum.py (1 hunks)
  • fla/ops/deltaformer/naive.py (1 hunks)
  • tests/models/test_modeling_deltaformer.py (1 hunks)
  • tests/models/test_modeling_utils.py (1 hunks)
  • tests/ops/test_deltaformer.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
fla/models/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (11-107)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (210-351)
  • DeltaFormerModel (121-207)
fla/layers/__init__.py (1)
fla/layers/deltaformer.py (1)
  • DeltaFormerAttention (29-214)
tests/models/test_modeling_deltaformer.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (11-107)
tests/models/test_modeling_base.py (2)
  • run_test_generation (67-126)
  • run_test_model_forward_backward (27-61)
tests/ops/test_deltaformer.py (3)
fla/ops/deltaformer/fused_chunk.py (2)
  • delta_pre_attn (1073-1099)
  • backward (835-954)
fla/ops/deltaformer/naive.py (1)
  • delta_pre_attn_naive (39-87)
fla/utils.py (1)
  • assert_close (77-88)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/fused_chunk.py (1)
  • delta_pre_attn (1073-1099)
fla/ops/deltaformer/naive.py (1)
  • delta_pre_attn_naive (39-87)
fla/ops/deltaformer/invcum.py (1)
fla/ops/deltaformer/fused_chunk.py (3)
  • forward (16-31)
  • forward (802-832)
  • backward (835-954)
fla/layers/deltaformer.py (5)
fla/layers/utils.py (3)
  • get_unpad_data (75-98)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/modules/rotary.py (1)
  • RotaryEmbedding (306-500)
fla/ops/deltaformer/fused_chunk.py (3)
  • delta_pre_attn (1073-1099)
  • forward (16-31)
  • forward (802-832)
fla/ops/utils/index.py (1)
  • prepare_lens_from_mask (43-44)
fla/__init__.py (2)
fla/layers/deltaformer.py (1)
  • DeltaFormerAttention (29-214)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (210-351)
  • DeltaFormerModel (121-207)
fla/models/deltaformer/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (11-107)
fla/models/deltaformer/modeling_deltaformer.py (2)
  • DeltaFormerForCausalLM (210-351)
  • DeltaFormerModel (121-207)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
  • forward (105-214)
fla/ops/deltaformer/invcum.py (4)
  • forward (7-13)
  • backward (29-38)
  • backward_x (20-26)
  • forward_inplace (16-17)
fla/models/deltaformer/modeling_deltaformer.py (8)
fla/utils.py (1)
  • deprecate_kwarg (190-334)
fla/layers/deltaformer.py (2)
  • DeltaFormerAttention (29-214)
  • forward (105-214)
fla/models/deltaformer/configuration_deltaformer.py (1)
  • DeltaFormerConfig (11-107)
fla/modules/fused_cross_entropy.py (1)
  • FusedCrossEntropyLoss (344-419)
fla/modules/fused_linear_cross_entropy.py (1)
  • FusedLinearCrossEntropyLoss (493-567)
fla/modules/mlp.py (1)
  • GatedMLP (26-69)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/models/modeling_layers.py (1)
  • GradientCheckpointingLayer (11-71)
🪛 Ruff (0.13.1)
fla/models/deltaformer/configuration_deltaformer.py

13-13: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


79-81: Avoid specifying long messages outside the exception class

(TRY003)


83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


91-91: Avoid specifying long messages outside the exception class

(TRY003)


93-93: Avoid specifying long messages outside the exception class

(TRY003)


95-95: Avoid specifying long messages outside the exception class

(TRY003)

fla/layers/deltaformer.py

91-91: Avoid specifying long messages outside the exception class

(TRY003)


110-110: Unused method argument: output_attentions

(ARG002)


111-111: Unused method argument: use_cache

(ARG002)

fla/ops/deltaformer/fused_chunk.py

25-25: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


198-198: Unused function argument: B

(ARG001)


364-364: Unused function argument: B

(ARG001)


477-477: Unused function argument: B

(ARG001)


604-604: Unused function argument: B

(ARG001)


842-842: Unpacked variable T_max is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

fla/models/deltaformer/modeling_deltaformer.py

99-99: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


108-108: Unused method argument: prenorm_residual_strategy

(ARG002)


109-109: Unused method argument: num_residuals_per_layer

(ARG002)


145-145: Unused blanket noqa directive

Remove unused noqa directive

(RUF100)


155-155: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


166-166: Avoid specifying long messages outside the exception class

(TRY003)


168-168: Avoid specifying long messages outside the exception class

(TRY003)


212-212: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


248-254: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


248-254: Avoid specifying long messages outside the exception class

(TRY003)


256-256: Use raise without specifying exception name

Remove exception name

(TRY201)


267-267: Unused method argument: kwargs

(ARG002)


343-343: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


344-344: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

🪛 markdownlint-cli2 (0.18.1)
README.md

32-32: Unordered list style
Expected: asterisk; Actual: dash

(MD004, ul-style)

🔇 Additional comments (17)
fla/layers/__init__.py (1)

10-10: Public export looks good.

DeltaFormerAttention is imported and added to __all__.

Also applies to: 64-65

fla/__init__.py (1)

9-10: API surface and version bump OK.

Exports for DeltaFormer and bump to 0.3.2 look consistent with added modules.

If you publish to PyPI, ensure changelog/release notes reflect the KV-cache limitation.

Also applies to: 38-40, 89-89, 111-111

fla/ops/deltaformer/__init__.py (1)

3-9: Clean public ops surface.

Exporting fused and naive pre-attn helpers is clear and sufficient.

fla/models/__init__.py (1)

7-7: Models exports wired correctly.

DeltaFormerConfig/Model/ForCausalLM are now available at package level.

Also applies to: 41-41

tests/models/test_modeling_deltaformer.py (1)

15-24: Modeling params LGTM; covers key dtypes and D.

Good coverage across use_l2warp and D∈{64,128}; Hopper gating is handled in the base.

Also applies to: 25-35

fla/layers/deltaformer.py (3)

21-25: Guard both flash-attn symbols on ImportError.

flash_attn_varlen_func is undefined if the import fails. Mirror the sentinel.

Apply:

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
-    flash_attn_func = None
+    flash_attn_func = None
+    flash_attn_varlen_func = None

115-121: Don’t use assert for runtime validation.

Asserts can be stripped; raise an explicit exception instead.

Apply:

-        if attention_mask is not None:
-            assert len(attention_mask.shape) == 2, (
-                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-                "for padding purposes (0 indicating padding). "
-                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-            )
+        if attention_mask is not None and attention_mask.dim() != 2:
+            raise ValueError(
+                "attention_mask must be [batch_size, seq_len] with 1=valid and 0=pad; pairwise masks are unsupported."
+            )

78-84: Validate head divisibility and GQA grouping.

Missing checks can yield silent shape bugs when hidden_size % num_heads != 0 or num_heads % num_kv_heads != 0.

Apply:

         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
-        self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
         self.kv_dim = self.num_kv_heads * self.head_dim
fla/ops/deltaformer/invcum.py (1)

16-18: Name these helpers for what they actually do (clarify math and call sites).

The current names are too generic and were previously flagged. Suggest more explicit names to reduce cognitive load.

-def forward(u, w):
+def solve_unit_lower(u, w):
@@
-def forward_inplace(u, w):
-    u.copy_(forward(u, w))
+def solve_unit_lower_inplace(u, w):
+    u.copy_(solve_unit_lower(u, w))
@@
-def backward_x(do, w):
+def solve_unit_upper_T_x(do, w):
@@
-def backward(do, w, x):
+def backward_wrt_inputs(do, w, x):

Follow up: update call sites (invcum.forward_inplacesolve_unit_lower_inplace, etc.).

Also applies to: 20-26, 29-38

fla/models/deltaformer/__init__.py (2)

8-10: Registry calls: confirm Transformers version compatibility.

APIs and exist_ok behavior vary across versions. Please ensure our minimum transformers version supports these exact signatures.
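If a guard is wanted, one version-tolerant pattern looks roughly like this (the exist_ok keyword is assumed to be available only in newer transformers releases, hence the fallback):

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig
from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel

try:
    AutoConfig.register('deltaformer', DeltaFormerConfig, exist_ok=True)
    AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True)
    AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True)
except TypeError:
    # older transformers releases do not accept exist_ok
    AutoConfig.register('deltaformer', DeltaFormerConfig)
    AutoModel.register(DeltaFormerConfig, DeltaFormerModel)
    AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM)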


12-12: LGTM – clean public surface.

fla/ops/deltaformer/fused_chunk.py (2)

885-891: Rename qi/ki to match actual tensors (k-chunk vs q-tail).

Names are inverted and hard to follow. Adopt descriptive names; same pattern in fixed-length backward below.

-                        qi = k_seq[:, i0:i1, :]
-                        ki = q_seq[:, i1:L, :]
+                        k_chunk = k_seq[:, i0:i1, :]
+                        q_tail = q_seq[:, i1:L, :]
@@
-                        du_tail = backward_u_chunk(qi, ki, lse_tail, gv_seq[:, i1:L, :], fa_scale, beta_tail)
+                        du_tail = backward_u_chunk(k_chunk, q_tail, lse_tail, gv_seq[:, i1:L, :], fa_scale, beta_tail)
-                    qi = ko[b, :, i:i + Ci, :]
-                    ki = qo[b, :, i + Ci:, :]
+                    k_chunk = ko[b, :, i:i + Ci, :]
+                    q_tail = qo[b, :, i + Ci:, :]
@@
-                    du = backward_u_chunk(qi, ki, lse, grad_v_seq[:, i + Ci:, :], fa_scale, beta_single)
+                    du = backward_u_chunk(k_chunk, q_tail, lse, grad_v_seq[:, i + Ci:, :], fa_scale, beta_single)

Also applies to: 924-933


1073-1105: Public API: good shape checking; maintains inference fast path.

Please run the provided throughput benchmark once more after these minor renames/guards to ensure no regression.

fla/models/deltaformer/modeling_deltaformer.py (3)

82-88: Confirm fused RMSNorm return signature to avoid unpack errors.

self.mlp_norm(...) is unpacked into two tensors. Ensure the fused RMSNorm returns (normed, residual) when prenorm=True; otherwise this will throw.

If it returns a single tensor, gate by type/arity or switch to explicit residual add:

-        if self.config.fuse_norm:
-            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        if self.config.fuse_norm:
+            out = self.mlp_norm(hidden_states, residual, True)
+            hidden_states, residual = (out if isinstance(out, tuple) else (out, out))

210-257: Preserve exception context and avoid broad rethrow.

Chain with from exception, and use bare raise in the fallback. This mirrors prior feedback.

         except AttributeError as exception:
             if 'past_key_values' in str(exception):
-                raise AttributeError(
+                raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
             else:
-                raise exception
+                raise

325-333: Initialize criterion once in init (perf/stability).

Avoid per‑call instantiation; move selection to __init__ and reuse.

-        if labels is not None:
-            if getattr(self, 'criterion', None) is None:
-                if self.config.fuse_linear_cross_entropy:
-                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
-                elif self.config.fuse_cross_entropy:
-                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
-                else:
-                    criterion = nn.CrossEntropyLoss()
-            else:
-                criterion = self.criterion
+        if labels is not None:
+            criterion = self.criterion

Additional change outside this hunk (constructor):

 class DeltaFormerForCausalLM(DeltaFormerPreTrainedModel, GenerationMixin):
@@
-        self.criterion = None
+        if config.fuse_linear_cross_entropy:
+            self.criterion = FusedLinearCrossEntropyLoss(use_l2warp=config.use_l2warp)
+        elif config.fuse_cross_entropy:
+            self.criterion = FusedCrossEntropyLoss(inplace_backward=True)
+        else:
+            self.criterion = nn.CrossEntropyLoss()
fla/models/deltaformer/configuration_deltaformer.py (1)

92-100: Attn schema validated — keep the checks.

config.attn['layers'] and config.attn['num_heads'] are actually consumed by many model implementations and passed into Attention constructors (e.g. fla/models/deltaformer/modeling_deltaformer.py, fla/models/rodimus/modeling_rodimus.py, fla/models/samba/modeling_samba.py), so the validation in fla/models/deltaformer/configuration_deltaformer.py is appropriate.

Comment on lines 90 to 91
if flash_attn_func is None:
raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

⚠️ Potential issue

Require varlen FlashAttention too (avoid NameError later).

Init only checks flash_attn_func; varlen paths call flash_attn_varlen_func unconditionally.

Apply:

-        if flash_attn_func is None:
-            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+        if flash_attn_func is None or flash_attn_varlen_func is None:
+            raise ImportError(
+                "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+                "Install via: pip install flash-attn --no-build-isolation"
+            )
🧰 Tools
🪛 Ruff (0.13.1)

91-91: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 90 to 91, the init only checks
flash_attn_func but the code later unconditionally calls flash_attn_varlen_func
causing a NameError when the varlen implementation is not present; update the
import/initialization check to validate that both flash_attn_func and
flash_attn_varlen_func are available (or set flash_attn_varlen_func to a
fallback) and raise a clear ImportError if either required function is missing,
so downstream varlen code does not reference an undefined name.

Comment on lines +105 to +114
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
attentions = None

⚠️ Potential issue

KV cache not implemented; forward signature is misleading.

past_key_values/use_cache are accepted and returned unchanged. This breaks efficient generation.

Apply a minimal guard now; implement caching next:

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if use_cache:
+            raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
@@
-        return o, attentions, past_key_values
+        return o, attentions, past_key_values

Also consider removing output_attentions until supported.

Also applies to: 214-214

🧰 Tools
🪛 Ruff (0.13.1)

110-110: Unused method argument: output_attentions

(ARG002)


111-111: Unused method argument: use_cache

(ARG002)

🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 105-114 (and similarly at line 214),
the forward signature accepts past_key_values/use_cache/output_attentions but
does not implement KV caching; add a minimal guard: if use_cache is True or
past_key_values is not None, raise NotImplementedError (or ValueError) with a
clear message that KV cache is not yet supported and will be implemented later;
likewise, if output_attentions is True, raise NotImplementedError or ignore it
explicitly with a clear comment. This keeps the API honest now and prevents
silent breaks during generation; implement the same guard at the other affected
location (line ~214).

Comment on lines +138 to +141
if attention_mask is not None:
seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
max_seqlen = q_len + max(seqlen_offset)


⚠️ Potential issue

Bug: using Python max() on a tensor.

max(seqlen_offset) iterates the tensor element-wise and yields a 0-d tensor rather than a Python int (and fails outright on 0-d tensors); also ensure a plain scalar for cache sizing.

Apply:

-            if attention_mask is not None:
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q_len + max(seqlen_offset)
+            if attention_mask is not None:
+                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+                max_seqlen = q_len + int(seqlen_offset.max().item())
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 138 to 141, the code calls Python's
built-in max() on seqlen_offset (a torch.Tensor) which will raise an error and
may not produce a Python scalar for cache sizing; replace the built-in max with
a tensor-max and convert to a Python int, e.g. compute seqlen_offset =
seqlen_offset + prepare_lens_from_mask(attention_mask) -
attention_mask.shape[-1] as before, then set max_seqlen = q_len +
int(seqlen_offset.max().cpu().item()) (or use torch.max(seqlen_offset).item())
to ensure max_seqlen is a plain int for subsequent cache sizing.

Comment on lines +89 to +91
if attn is not None:
if not isinstance(attn, Dict):
raise ValueError("attn must be a dictionary")

⚠️ Potential issue

Bug: isinstance check against typing.Dict is unreliable.

Use Mapping (from collections.abc) or plain dict in the isinstance check; bare typing.Dict is deprecated, and subscripted forms such as Dict[str, int] raise TypeError at runtime.

-from typing import Dict, Optional
+from typing import Dict, Optional
+from collections.abc import Mapping
@@
-            if not isinstance(attn, Dict):
+            if not isinstance(attn, Mapping):
                 raise ValueError("attn must be a dictionary")
🧰 Tools
🪛 Ruff (0.13.1)

91-91: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In fla/models/deltaformer/configuration_deltaformer.py around lines 89 to 91,
the isinstance check uses typing.Dict which can raise TypeError; change the
check to use a concrete runtime check like isinstance(attn, Mapping) (from
collections.abc) or isinstance(attn, dict) and add the necessary import (from
collections.abc import Mapping) if using Mapping, so the validation works at
runtime without errors.

Comment on lines +258 to +289
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: bool = True,
logits_to_keep: Optional[int] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is not empty.
if past_key_values is not None and len(past_key_values) > 0:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and len(past_key_values) == 0:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
model_inputs = {'input_ids': input_ids.contiguous()}

if logits_to_keep is not None:
model_inputs['logits_to_keep'] = logits_to_keep

model_inputs.update({
'past_key_values': past_key_values,
'use_cache': use_cache,
'attention_mask': attention_mask,
})
return model_inputs


⚠️ Potential issue

Fix None-handling bug in prepare_inputs_for_generation.

len(past_key_values) is called when it can be None, causing a TypeError.

-        if inputs_embeds is not None and len(past_key_values) == 0:
+        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
             model_inputs = {'inputs_embeds': inputs_embeds}
🧰 Tools
🪛 Ruff (0.13.1)

267-267: Unused method argument: kwargs

(ARG002)

🤖 Prompt for AI Agents
In fla/models/deltaformer/modeling_deltaformer.py around lines 258 to 289, the
code calls len(past_key_values) when past_key_values can be None which raises a
TypeError; change the inputs_embeds check to verify past_key_values is None or
empty (e.g., replace "len(past_key_values) == 0" with "past_key_values is None
or len(past_key_values) == 0") and also guard input_ids.contiguous() by ensuring
input_ids is not None before calling contiguous() so you don't call methods on
None.

Comment on lines +319 to +323
hidden_states = outputs[0]
# For fused linear cross-entropy we do not materialize logits for the full sequence
logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

loss = None

⚠️ Potential issue

Don’t drop logits in eval when fuse_linear_cross_entropy is enabled.

Currently logits is None whenever fuse_linear_cross_entropy is True, even if labels is None. Compute logits for eval/inference.

-        # For fused linear cross-entropy we do not materialize logits for the full sequence
-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        # Only skip logits materialization during training with fused linear CE
+        materialize_logits = (labels is None) or (not self.config.fuse_linear_cross_entropy)
+        logits = self.lm_head(hidden_states[:, -logits_to_keep:]) if materialize_logits else None
🤖 Prompt for AI Agents
In fla/models/deltaformer/modeling_deltaformer.py around lines 319-323, the code
sets logits to None whenever config.fuse_linear_cross_entropy is True which
incorrectly drops logits during eval/inference; change the logic so that when
fuse_linear_cross_entropy is True you still compute logits if labels is None
(i.e., in eval/inference) but keep logits None during training when labels are
present so the fused loss can run; implement this by branching on both
self.config.fuse_linear_cross_entropy and whether labels is None and compute
logits = self.lm_head(hidden_states[:, -logits_to_keep:]) for the eval path.

Comment on lines 24 to 28
GENERATION_UNSUPPORTED = [
"ABCConfig", "LinearAttentionConfig", "LightNetConfig",
"Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config",
"DeltaFormerConfig",
]

⚠️ Potential issue

DeltaFormer added to GENERATION_UNSUPPORTED – track as temporary.

This aligns with missing KV cache, but please add a TODO and issue to re-enable once caching lands.

I can open a follow-up ticket and patch tests to auto-enable when DeltaFormerAttention implements cache.

🤖 Prompt for AI Agents
In tests/models/test_modeling_utils.py around lines 24 to 28, DeltaFormerConfig
was added to GENERATION_UNSUPPORTED due to missing KV cache but lacks a TODO and
tracking issue; add a concise TODO comment immediately next to the
DeltaFormerConfig entry stating this is temporary due to missing KV cache,
create a follow-up issue (or use the provided issue number) to re-enable
generation support once DeltaFormerAttention implements KV caching, and include
the issue number or URL in the TODO so the test can be easily auto-enabled
later.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
fla/layers/deltaformer.py (1)

157-181: Consider extracting common deltaformer_attn call to reduce duplication.

All three branches call deltaformer_attn with identical arguments except for cu_seqlens. This duplication can be simplified.

Apply this diff:

-if attention_mask is not None:
-    u = deltaformer_attn(
-        # kk similarity
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-        cu_seqlens=cu_seqlens,
-    )
-elif cu_seqlens is not None:
-    u = deltaformer_attn(
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-        cu_seqlens=cu_seqlens,
-    )
-else:
-    u = deltaformer_attn(
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-    )
+# Use KK similarity for pre-attention
+u = deltaformer_attn(
+    rearrange(k_full, 'b t h d -> b h t d'),
+    rearrange(k_full, 'b t h d -> b h t d'),
+    rearrange(v_full, 'b t h d -> b h t d'),
+    beta_full,
+    cu_seqlens=cu_seqlens,  # None when not using varlen
+)
fla/ops/deltaformer/naive.py (1)

59-70: Replace asserts with explicit exceptions for runtime validation.

Runtime asserts can be stripped with Python optimization and shouldn't be used for input validation.

Apply this diff:

-assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,H,T,D]"
+if q.dim() != 4 or k.dim() != 4 or v.dim() != 4:
+    raise ValueError("q, k, v must have 4 dimensions [B, H, T, D]")
 B, H, T, D = q.shape
-assert k.shape == (B, H, T, D) and v.shape == (B, H, T, D)
+if k.shape != (B, H, T, D) or v.shape != (B, H, T, D):
+    raise ValueError(f"Shape mismatch: q={q.shape}, k={k.shape}, v={v.shape}. All must be [{B}, {H}, {T}, {D}]")
 orig_dtype = q.dtype
 qf = q.float()
 kf = k.float()
 vf = v.float()
 if beta is None:
     betaf = torch.ones((B, H, T), dtype=torch.float32, device=q.device)
 else:
-    assert beta.shape == (B, H, T)
+    if beta.shape != (B, H, T):
+        raise ValueError(f"beta shape {beta.shape} does not match expected [{B}, {H}, {T}]")
     betaf = beta.float()
fla/ops/deltaformer/parallel.py (1)

162-169: Consider expanding autotune search space for better performance.

The autotune configuration uses a limited search space. Per PR comments, autotune was added to address performance issues with default block sizes being too large.

Consider adding more block size options to the search space:

 def _config_delta_flash_attn():
     return [
         triton.Config({'BLOCK_C': BC, 'BLOCK_T': BT}, num_stages=ns, num_warps=nw)
-        for BC in [128, 64]
-        for BT in [64, 32]
+        for BC in [256, 128, 64, 32]
+        for BT in [128, 64, 32, 16]
         for ns in [3, 2]
         for nw in [8, 4]
     ]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dfb54ec and 0a23799.

📒 Files selected for processing (5)
  • fla/layers/deltaformer.py (1 hunks)
  • fla/ops/deltaformer/__init__.py (1 hunks)
  • fla/ops/deltaformer/naive.py (1 hunks)
  • fla/ops/deltaformer/parallel.py (1 hunks)
  • tests/ops/test_deltaformer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/test_deltaformer.py
🧰 Additional context used
🧬 Code graph analysis (3)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (3)
  • get_unpad_data (75-98)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/ops/deltaformer/parallel.py (3)
  • deltaformer_attn (1073-1099)
  • forward (16-31)
  • forward (802-832)
fla/ops/utils/index.py (1)
  • prepare_lens_from_mask (43-44)
fla/ops/deltaformer/parallel.py (2)
fla/layers/deltaformer.py (1)
  • forward (105-214)
fla/ops/deltaformer/invcum.py (2)
  • backward_x (20-26)
  • forward_inplace (16-17)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/naive.py (1)
  • naive_deltaformer_attn (39-87)
fla/ops/deltaformer/parallel.py (1)
  • deltaformer_attn (1073-1099)
🪛 Ruff (0.13.1)
fla/layers/deltaformer.py

91-91: Avoid specifying long messages outside the exception class

(TRY003)


110-110: Unused method argument: output_attentions

(ARG002)


111-111: Unused method argument: use_cache

(ARG002)

fla/ops/deltaformer/parallel.py

25-25: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


198-198: Unused function argument: B

(ARG001)


364-364: Unused function argument: B

(ARG001)


477-477: Unused function argument: B

(ARG001)


604-604: Unused function argument: B

(ARG001)


842-842: Unpacked variable T_max is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (12)
fla/layers/deltaformer.py (7)

115-120: Replace assert with explicit ValueError.

Using assert for runtime validation can be stripped by Python optimization (-O flag).

Apply this diff:

 if attention_mask is not None:
-    assert len(attention_mask.shape) == 2, (
-        "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-        "for padding purposes (0 indicating padding). "
-        "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-    )
+    if attention_mask.dim() != 2:
+        raise ValueError(
+            "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; 2D pairwise masks are unsupported."
+        )

21-25: Set both to None to prevent NameError.

If the import fails, only flash_attn_func is set to None, leaving flash_attn_varlen_func undefined. This will cause a NameError when checking it later.

Apply this diff:

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     flash_attn_func = None
+    flash_attn_varlen_func = None

78-84: Add validation for head dimension divisibility.

The code assumes hidden_size is divisible by num_heads and num_heads by num_kv_heads, but doesn't validate these assumptions. This can lead to silent errors.

Apply this diff:

 self.hidden_size = hidden_size
 self.num_heads = num_heads
 self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+if self.hidden_size % self.num_heads != 0:
+    raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})")
+if self.num_heads % self.num_kv_heads != 0:
+    raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})")
 self.num_kv_groups = num_heads // self.num_kv_heads
 self.head_dim = self.hidden_size // self.num_heads
 self.kv_dim = self.num_kv_heads * self.head_dim

90-91: Check both flash_attn imports to avoid NameError later.

The code only checks flash_attn_func but calls flash_attn_varlen_func unconditionally in the forward pass, which will cause a NameError if varlen import failed.

Apply this diff:

-if flash_attn_func is None:
-    raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+if flash_attn_func is None or flash_attn_varlen_func is None:
+    raise ImportError(
+        "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+        "Install via: pip install flash-attn --no-build-isolation"
+    )

105-114: KV cache not implemented - fail explicitly.

The method accepts past_key_values and use_cache but doesn't implement caching, which will silently break generation.

Apply this diff:

 def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.LongTensor] = None,
     past_key_values: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_cache:
+        raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
     attentions = None

138-141: Fix: Using Python max() on a tensor is fragile.

Python's built-in max() iterates a PyTorch tensor element-wise and returns a 0-d tensor rather than a Python int (and fails on 0-d tensors). Use the tensor's .max() method and convert to a Python int.

Apply this diff:

 if attention_mask is not None:
     seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-    max_seqlen = q_len + max(seqlen_offset)
+    max_seqlen = q_len + int(seqlen_offset.max().item())

145-156: ROPE cu_seqlens mismatch for attention_mask path.

ROPE receives cu_seqlens_kw but should use the mask-derived cu_seqlens when attention_mask is present for correct position indexing.

Apply this diff to compute cu_seqlens first:

-q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)
 if attention_mask is not None:
     _, cu_seqlens, _ = get_unpad_data(attention_mask)
+    q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
     # q_full = q
     k_full = k
     v_full = v
     beta_full = beta
 else:
     cu_seqlens = cu_seqlens_kw
+    q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)
     # q_full, k_full, v_full, beta_full = q, k, v, beta
     k_full, v_full, beta_full = k, v, beta
fla/ops/deltaformer/__init__.py (1)

1-9: LGTM! Clean module exports.

The module properly exports both the optimized deltaformer_attn and the reference naive_deltaformer_attn implementation with clear all definition.

fla/ops/deltaformer/naive.py (2)

10-36: LGTM! Well-implemented causal softmax utility.

The tril_softmax function correctly implements row-wise causal softmax with proper numerical stability through max subtraction and handles the strict vs non-strict causal masking appropriately.
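For readers skimming the review, a minimal self-contained sketch of that idea (not the repository's tril_softmax; the strict flag and clamping choices are illustrative):

import torch


def causal_softmax(scores: torch.Tensor, strict: bool = False) -> torch.Tensor:
    # scores: [..., T, T]; row i may attend to columns j <= i (j < i when strict)
    T = scores.size(-1)
    disallow = torch.ones(T, T, dtype=torch.bool, device=scores.device).triu(0 if strict else 1)
    scores = scores.masked_fill(disallow, float('-inf'))
    # subtract the row max for stability; the clamp keeps fully masked rows finite
    row_max = scores.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(scores.dtype).min)
    probs = (scores - row_max).exp()
    return probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-20)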


76-87: LGTM! Correct iterative DeltaFormer computation.

The iterative computation correctly implements the DeltaFormer formula u[i] = v[i] - beta[i] * sum_{j<i} softmax(q[i] @ k[:i]^T) @ u[:i] with proper handling of the base case and accumulation of previous u values.
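Spelled out for a single (batch, head) slice, that recurrence reads roughly as follows (softmax scaling omitted to match the formula as quoted; a reference sketch, not the file's exact code):

import torch


def per_head_u(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # q, k, v: [T, D]; beta: [T]
    T = q.size(0)
    u = torch.empty_like(v)
    if T == 0:
        return u
    u[0] = v[0]                                                      # base case: no prior rows
    for i in range(1, T):
        attn = torch.softmax(q[i] @ k[:i].transpose(0, 1), dim=-1)   # weights over rows j < i
        u[i] = v[i] - beta[i] * (attn @ u[:i])                       # the recurrence above
    return u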

fla/ops/deltaformer/parallel.py (2)

1073-1099: LGTM! Clean public API with varlen support.

The deltaformer_attn function provides a clean interface supporting both fixed-length and variable-length sequences with proper gradient handling through the custom autograd function.
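A usage sketch based on the shapes exercised in tests/ops/test_deltaformer.py (q, k, v as [B, H, T, D], beta as [B, H, T], run on a CUDA device); the int32 dtype for cu_seqlens is an assumption:

import torch

from fla.ops.deltaformer import deltaformer_attn

B, H, T, D = 2, 4, 256, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))
beta = torch.rand(B, H, T, device='cuda', dtype=torch.bfloat16)

u = deltaformer_attn(q, k, v, beta)    # fixed-length path

# varlen path: pack the sequences along time into a single batch row
q_p, k_p, v_p = (torch.cat([t[0], t[1]], dim=1).unsqueeze(0) for t in (q, k, v))  # [1, H, 2T, D]
beta_p = torch.cat([beta[0], beta[1]], dim=1).unsqueeze(0)                        # [1, H, 2T]
cu_seqlens = torch.tensor([0, T, 2 * T], device='cuda', dtype=torch.int32)
u_p = deltaformer_attn(q_p, k_p, v_p, beta_p, cu_seqlens=cu_seqlens)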


839-905: Varlen backward assumes BS==1 - add explicit guard.

The backward pass indexes tensors with [0, ...] which assumes batch size 1. This will silently produce incorrect gradients for BS>1.

Apply this diff:

 if getattr(ctx, 'cu_seqlens', None) is not None:
     cu = ctx.cu_seqlens
     qo, ko, vo, u_full, ws, lses, betao = ctx.saved_tensors
     BS, NH, T_max, D = ko.size()
+    if BS != 1:
+        raise NotImplementedError(
+            "Varlen deltaformer backward currently requires packed inputs with BS==1. "
+            f"Got BS={BS}. Please pack sequences into a single batch dimension."
+        )
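
For context, the packing the guard asks for looks like this (a hypothetical illustration): sequences are concatenated along the time axis of one batch entry and the boundaries are recorded in cu_seqlens.

```python
import torch

H, D = 4, 64
lens = [4, 2, 3]
seqs = [torch.randn(H, t, D) for t in lens]                 # per-sequence tensors [H, t_i, D]
packed = torch.cat(seqs, dim=1).unsqueeze(0)                # [1, H, 9, D], i.e. BS == 1
cu_seqlens = torch.tensor([0, 4, 6, 9], dtype=torch.int32)  # prefix sums of the lengths
```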

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
fla/ops/deltaformer/parallel.py (1)

17-20: Add stacklevel to warning for better diagnostics.

Static analyzer suggests adding stacklevel=2 to point the warning to the calling code rather than this module.

     warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-        category=ImportWarning
+        category=ImportWarning,
+        stacklevel=2
     )
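
For reference, how stacklevel changes the attribution in the general case (a hypothetical helper, not the module-level warning itself):

```python
import warnings

def _warn_missing_flash_attn():
    # With stacklevel=2 the warning is reported at the caller's line rather than
    # at this helper's warnings.warn(...) call, which makes the source easier to find.
    warnings.warn("Flash Attention is not installed.", ImportWarning, stacklevel=2)
```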
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0a23799 and e6f09c3.

📒 Files selected for processing (2)
  • fla/layers/deltaformer.py (1 hunks)
  • fla/ops/deltaformer/parallel.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/layers/deltaformer.py (5)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/modules/rotary.py (1)
  • RotaryEmbedding (306-500)
fla/ops/deltaformer/parallel.py (2)
  • deltaformer_attn (941-986)
  • forward (664-694)
fla/ops/utils/index.py (1)
  • prepare_lens_from_mask (43-44)
fla/models/deltaformer/modeling_deltaformer.py (3)
  • forward (63-91)
  • forward (142-207)
  • forward (291-351)
fla/ops/deltaformer/parallel.py (3)
fla/layers/utils.py (2)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/layers/deltaformer.py (1)
  • forward (96-153)
fla/ops/deltaformer/invcum.py (2)
  • backward_x (20-26)
  • forward_inplace (16-17)
🪛 Ruff (0.13.1)
fla/layers/deltaformer.py

102-102: Unused method argument: use_cache

(ARG002)

fla/ops/deltaformer/parallel.py

17-17: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


37-37: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


673-673: Unpacked variable D is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


704-704: Unpacked variable T_max is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


951-951: Avoid specifying long messages outside the exception class

(TRY003)


953-953: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


953-953: Unpacked variable D is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (6)
fla/layers/deltaformer.py (4)

1-154: LGTM! Clean and comprehensive DeltaFormer attention layer implementation.

This implementation provides a solid foundation for the DeltaFormer attention mechanism with proper module structure, comprehensive docstrings, and appropriate parameter handling. The forward method correctly integrates with the deltaformer ops and handles both fixed-length and variable-length sequences appropriately.


101-103: Add explicit guard for unused use_cache parameter.

The static analyzer correctly identifies unused use_cache parameter. Based on past review feedback, this should be guarded explicitly to prevent misleading the API consumer.

Apply this guard to make the API behavior explicit:

 def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.LongTensor] = None,
     past_key_values: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_cache:
+        raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
+    if output_attentions:
+        raise NotImplementedError("DeltaFormerAttention does not support outputting attention weights.")
     attentions = None

106-111: Replace assertion with explicit exception.

As noted in past reviews, runtime assertions can be stripped with -O and should be replaced with explicit exceptions.

-    if attention_mask is not None:
-        assert len(attention_mask.shape) == 2, (
-            "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-            "for padding purposes (0 indicating padding). "
-            "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-        )
+    if attention_mask is not None:
+        if attention_mask.dim() != 2:
+            raise ValueError(
+                "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; full [batch, seq_len, seq_len] pairwise masks are unsupported."
+            )
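
A quick demonstration of why the assert is fragile (hypothetical file name):

```python
# guard_demo.py
assert False, "this check disappears under python -O"
print("reached: the assert was stripped")

# $ python guard_demo.py     -> AssertionError: this check disappears under python -O
# $ python -O guard_demo.py  -> prints "reached: the assert was stripped"
```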

129-131: Fix tensor max() usage for seqlen_offset calculation.

The code calls Python's built-in max() on a tensor, which yields a 0-dim tensor instead of the Python int needed for proper cache sizing.

     if attention_mask is not None:
         seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-        max_seqlen = q_len + max(seqlen_offset)
+        max_seqlen = q_len + int(seqlen_offset.max().item())
fla/ops/deltaformer/parallel.py (2)

950-951: Check for both flash attention functions in error guard.

The function check only validates flash_attn_func but the code later unconditionally calls flash_attn_varlen_func, which could cause NameError if varlen function is missing.

-    if flash_attn_func is None:
-        raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+    if flash_attn_func is None or flash_attn_varlen_func is None:
+        raise ImportError(
+            "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+            "Install via: pip install flash-attn --no-build-isolation"
+        )

1-991: LGTM! Comprehensive Triton-accelerated DeltaFormer implementation.

This is a well-structured implementation of the DeltaFormer parallel attention with proper autotuning, autograd support, and both fixed-length and variable-length sequence handling. The Triton kernels are appropriately organized and the public API provides clean abstraction over the complex internals.

Comment on lines +14 to +22
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None

Contributor


⚠️ Potential issue

Missing flash_attn_varlen_func in except block.

flash_attn import failures are common, so both imported names need sentinel fallbacks. The except block only sets flash_attn_func = None but leaves flash_attn_varlen_func undefined, which will cause a NameError when the varlen path is used.

Apply this fix to define both sentinel values:

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
         category=ImportWarning
     )
     flash_attn_func = None
+    flash_attn_varlen_func = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
         category=ImportWarning
     )
     flash_attn_func = None
+    flash_attn_varlen_func = None
🧰 Tools
🪛 Ruff (0.13.1)

17-17: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around lines 14 to 22, the except ImportError
only sets flash_attn_func = None but leaves flash_attn_varlen_func undefined;
update the except block to also assign flash_attn_varlen_func = None so both
sentinels are defined when the import fails, and keep the existing warning
message intact.

@yzhangcs yzhangcs merged commit 195b74d into fla-org:main Sep 22, 2025
2 of 4 checks passed