[DeltaFormer] Add Model #585
Conversation
Caution: Review failed. The pull request is closed.
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Adds DeltaFormer: new attention layer, Triton and naive ops, invcum helpers, full model/config and CausalLM wrapper with Transformers registration, package exports, README updates, and tests; package version bumped to 0.3.2.
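For orientation, a minimal usage sketch of what the walkthrough describes. The class names are the ones added by this PR; the constructor arguments and output fields shown are assumptions based on the usual transformers conventions, and running it needs a CUDA device with Triton and flash-attn installed.

import torch
from fla.models import DeltaFormerConfig, DeltaFormerForCausalLM

config = DeltaFormerConfig(hidden_size=256, num_heads=4)   # placeholder sizes
model = DeltaFormerForCausalLM(config).cuda()

input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
out = model(input_ids=input_ids, labels=input_ids)         # standard CausalLM-style call
print(out.loss, out.logits.shape)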
Sequence Diagram(s)
sequenceDiagram
    autonumber
    participant App as Application
    participant CausalLM as DeltaFormerForCausalLM
    participant Model as DeltaFormerModel
    participant Block as DeltaFormerBlock[n]
    participant Attn as DeltaFormerAttention
    participant Op as deltaformer_attn (Triton / naive)
    participant Inv as invcum
    participant FA as FlashAttention (optional)
    App->>CausalLM: forward(input_ids, attention_mask, ...)
    CausalLM->>Model: forward(..., use_cache)
    Model->>Block: for each layer: forward(hidden_states, mask, past_key_values)
    Block->>Attn: compute Q,K,V (+rotary)
    Attn->>Op: deltaformer_attn(q,k,v,beta, cu_seqlens?)
    alt training (requires grads)
        Op-->>Attn: u + saved tensors
    else inference (fast path)
        Op-->>Attn: u
    end
    opt triangular helpers
        Op->>Inv: invcum forward/backward
        Inv-->>Op: solves/results
    end
    Attn-->>Block: output projection
    Block-->>Model: updated hidden_states, PKV
    Model-->>CausalLM: logits / past_key_values
    CausalLM-->>App: logits / loss
sequenceDiagram
    autonumber
    participant Autograd as Autograd
    participant Op as deltaformer_attn
    participant Kern as Triton kernels
    participant Inv as invcum
    note over Op,Kern: Chunked backward (training)
    Autograd->>Op: backward(dL/du)
    Op->>Kern: backward_u_chunk / backward_qk (per-chunk)
    Kern-->>Op: partial dq, dk, dv, dbeta
    Op->>Inv: backward solves for triangular systems
    Inv-->>Op: du/dw pieces
    Op-->>Autograd: aggregated dq, dk, dv, dbeta
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120+ minutes
Possibly related PRs
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
 ✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
Summary of Changes
Hello @Nathancgy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates the DeltaFormer model architecture into the library, providing a novel approach to attention mechanisms. It includes the core model components, optimized pre-attention kernels for performance, and full integration with the existing model ecosystem, enabling its use for causal language modeling tasks.
Highlights
- New Model Introduction: Introduces the DeltaFormer layer and model, based on the paper "Understanding Transformer from the Perspective of Associative Memory" (https://arxiv.org/pdf/2505.19488).
- Optimized Pre-Attention: Implements optimized pre-attention using fused Triton kernels for efficient computation, specifically for chunked pre-attention. This includes dedicated Triton ops in fused_chunk.py, invcum.py, and naive.py.
- Backward Compatibility: Ensures backward compatibility for the optional beta parameter, which now defaults to all ones when not explicitly provided (see the sketch after this list).
- Integration and Testing: Includes comprehensive unit tests and integrates the new model into the existing framework, updating relevant __init__.py files and model utility lists.
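A minimal sketch of the beta fallback described in the third bullet above. This is a generic wrapper, not the repo's actual signature; the only claim taken from the PR is that a missing beta behaves like all ones.

import torch

def call_pre_attn(pre_attn_fn, q, k, v, beta=None):
    # q, k, v: [B, H, T, D]; beta: [B, H, T] or None
    if beta is None:
        beta = q.new_ones(q.shape[:-1])  # assumed default: every position gets weight 1
    return pre_attn_fn(q, k, v, beta)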
Code Review
This pull request introduces the DeltaFormer model, a significant addition to the library. The implementation includes the core attention layer, model configuration, and optimized Triton kernels for performance. The code is well-structured, and the inclusion of a naive reference implementation is a great practice for testing and validation.
However, there is a critical issue that needs to be addressed: the model currently does not support key-value caching for autoregressive generation. This is a fundamental feature for causal language models, and its absence severely limits the model's practical utility. I have also identified a couple of areas for code simplification and improved clarity. Please see the detailed comments for suggestions.
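To make that naive-vs-fused validation concrete, here is a hedged sketch of an equivalence check. The two function names are the ones exported by this PR; the device, shapes, tolerances, and the exact optional arguments are placeholders and may differ from the real API.

import torch
from fla.ops.deltaformer import delta_pre_attn, delta_pre_attn_naive

B, H, T, D = 1, 2, 128, 64
q = torch.randn(B, H, T, D, device='cuda')
k = torch.randn(B, H, T, D, device='cuda')
v = torch.randn(B, H, T, D, device='cuda')
beta = torch.rand(B, H, T, device='cuda')

u_fused = delta_pre_attn(q, k, v, beta)        # Triton path (needs a GPU)
u_naive = delta_pre_attn_naive(q, k, v, beta)  # pure-PyTorch reference
torch.testing.assert_close(u_fused, u_naive, rtol=1e-3, atol=1e-3)  # placeholder tolerances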
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
The forward method accepts past_key_values and use_cache arguments, which suggests that it should support incremental decoding for generation. However, these arguments are not used, and past_key_values is returned unmodified. Without implementing the key-value cache, autoregressive generation will be extremely inefficient as all previous tokens would need to be re-processed at every step. This is a critical feature for a model intended for causal language modeling, and its absence is confirmed by DeltaFormerConfig being added to GENERATION_UNSUPPORTED in the test suite. Please implement the caching mechanism to enable efficient generation.
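To make the requested behavior concrete, here is a generic sketch of per-step key/value caching. It is illustrative only: it uses a plain dict rather than the transformers Cache object, and it ignores DeltaFormer's pre-attention, which is what makes real caching non-trivial here.

import torch

def decode_step(attn_fn, q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: [batch, heads, 1, head_dim] for the newly generated position
    if "k" in cache:
        k = torch.cat([cache["k"], k_t], dim=2)
        v = torch.cat([cache["v"], v_t], dim=2)
    else:
        k, v = k_t, v_t
    cache["k"], cache["v"] = k, v      # persist for the next step
    return attn_fn(q_t, k, v)          # attend one new query over all cached keys/values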
GENERATION_UNSUPPORTED = [
    "ABCConfig", "LinearAttentionConfig", "LightNetConfig",
    "Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config",
    "DeltaFormerConfig",
Adding DeltaFormerConfig to the GENERATION_UNSUPPORTED list means that generation tests are being skipped for this new model. This is likely due to the lack of key-value caching implementation in the attention layer, which I've commented on separately. For a ForCausalLM model, supporting efficient generation is a core requirement. This test should be enabled once the caching mechanism is implemented.
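Roughly how such a gate typically looks in a test suite (an illustrative sketch; the repo's actual helper in tests/models/test_modeling_base.py may differ):

import pytest

GENERATION_UNSUPPORTED = {"DeltaFormerConfig"}  # mirrors the snippet quoted above

def maybe_skip_generation(config_cls) -> None:
    if config_cls.__name__ in GENERATION_UNSUPPORTED:
        pytest.skip(f"{config_cls.__name__} does not support cached generation yet")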
        
          
fla/layers/deltaformer.py (Outdated)
if attention_mask is not None:
    # Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
    q_full = q
    k_full = k
    v_full = v
    beta_full = beta
else:
    q_full, k_full, v_full, beta_full = q, k, v, beta
This if/else block is redundant. Both the if and else branches perform the same assignments. You can simplify the code by removing the conditional block and keeping only the assignments.
        # Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
        q_full, k_full, v_full, beta_full = q, k, v, beta        
          
fla/ops/deltaformer/fused_chunk.py (Outdated)
qi = k[:, i:i + C, :]
ki = q[:, i + C:, :]
lse = lses[:, i + C:]
beta_single = beta[:, i + C:]
du = backward_u_chunk(qi, ki, lse, grad_v[:, i + C:, :], fa_scale, beta_single)
The variable names qi and ki are misleading. qi is initialized with a slice of k, and ki with a slice of q. This is counter-intuitive and can make the code difficult to understand and maintain. Please consider renaming them to something more descriptive, for example, k_chunk and q_context, to improve readability.
Suggested change:

k_chunk = k[:, i:i + C, :]
q_context = q[:, i + C:, :]
lse = lses[:, i + C:]
beta_single = beta[:, i + C:]
du = backward_u_chunk(k_chunk, q_context, lse, grad_v[:, i + C:, :], fa_scale, beta_single)
Actionable comments posted: 7
🧹 Nitpick comments (22)
fla/ops/deltaformer/invcum.py (1)
16-18: Return the tensor and caution on in-place autograd.
forward_inplace mutates u without returning it. Returning u improves ergonomics and mirrors common PyTorch APIs. Also, if u.requires_grad and u is needed for backward, this write may error at runtime. Apply:

-def forward_inplace(u, w):
-    u.copy_(forward(u, w))
+def forward_inplace(u, w):
+    u.copy_(forward(u, w))
+    return u

fla/ops/deltaformer/naive.py (3)
10-36: Safer masked softmax for all-masked rows.
Row 0 is fully masked when strict=True. Your subsequent masked_fill cleans NaNs, but we can avoid producing them altogether and cut one exp. Apply:

-    masked = scores.masked_fill(~mask, float('-inf'))
-    max_per_row = masked.max(dim=-1, keepdim=True).values
-    exp = (masked - max_per_row).exp()
-    exp = exp.masked_fill(~mask, 0.0)
+    masked = scores.masked_fill(~mask, float('-inf'))
+    max_per_row = masked.max(dim=-1, keepdim=True).values
+    logits = (masked - max_per_row).masked_fill(~mask, 0.0)
+    exp = logits.exp()

This prevents NaN creation and removes one masked_fill.
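A standalone illustration (plain PyTorch, made-up values) of the failure mode discussed above: a softmax over a fully masked row produces NaN, which the existing masked_fill then has to clean up.

import torch

scores = torch.randn(3, 3)
mask = torch.tril(torch.ones(3, 3, dtype=torch.bool), diagonal=-1)  # strict causal: row 0 has no valid entries

masked = scores.masked_fill(~mask, float('-inf'))
print(masked.softmax(dim=-1)[0])   # tensor([nan, nan, nan]) for the all-masked row

exp = (masked - masked.max(dim=-1, keepdim=True).values).exp()
exp = exp.masked_fill(~mask, 0.0)  # the cleanup the comment refers to
print(exp[0])                      # tensor([0., 0., 0.])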
59-66: Replace asserts with explicit validation (kept in optimized runs).
assert is stripped under -O. Raise exceptions instead. Apply:

-    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,H,T,D]"
+    if not (q.dim() == 4 and k.dim() == 4 and v.dim() == 4):
+        raise ValueError("q,k,v must be [B,H,T,D]")
     B, H, T, D = q.shape
-    assert k.shape == (B, H, T, D) and v.shape == (B, H, T, D)
+    if k.shape != (B, H, T, D) or v.shape != (B, H, T, D):
+        raise ValueError(f"Expected k,v shape {(B,H,T,D)}, got {k.shape} and {v.shape}")
     if beta is None:
         beta = q.new_ones((B, H, T))
     else:
-        assert beta.shape == (B, H, T)
+        if beta.shape != (B, H, T):
+            raise ValueError(f"beta shape must be {(B,H,T)}, got {beta.shape}")
67-81: Reference path is clear; add a tiny correctness test to guard regressions.
Add a unit test that compares delta_pre_attn_naive vs fused delta_pre_attn on small tensors for fp32/fp16/bf16 within tolerance.
If you want, I can add tests/ops/test_delta_pre_attn_equiv.py with randomized seeds and tolerances.
fla/models/deltaformer/configuration_deltaformer.py (4)
10-13: Annotate class vars to please linters.
Mark model_type and keys_to_ignore_at_inference as ClassVar[...]. Apply:

-from typing import Dict, Optional
+from typing import Dict, Optional, ClassVar
@@
-class DeltaFormerConfig(PretrainedConfig):
-    model_type = 'deltaformer'
-    keys_to_ignore_at_inference = ['past_key_values']
+class DeltaFormerConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'deltaformer'
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']
14-46: Add basic shape/hyperparameter validation to fail fast.
Avoid runtime crashes in attention by checking divisibility and GQA consistency.
Apply near the end of __init__ (before super().__init__()):

+        if hidden_size % num_heads != 0:
+            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+        kvh = num_kv_heads if num_kv_heads is not None else num_heads
+        if num_heads % kvh != 0:
+            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({kvh})")
+        if attn_mode not in {"chunk"}:
+            raise ValueError(f"Unsupported attn_mode: {attn_mode}")
25-29: Type of elementwise_affine should be bool.
Optional[bool] suggests None is meaningful, but you never use None. Simplify to bool.
75-89: Exception messages are fine; optional: shorten per linter TRY003.
Purely stylistic; feel free to ignore if you prefer clarity.
fla/ops/deltaformer/__init__.py (1)
fla/ops/deltaformer/__init__.py (1)
3-9: Clean re-exports; consider exposing tril_softmax for testing.
Optional: re-export tril_softmax for test utilities, otherwise this is fine.
fla/__init__.py (1)
37-39: Top-level re-exports OK; consider also exposing DeltaFormerConfig.
Keeps parity with other models' configs at top-level if that's a convention you follow. Apply:

-from fla.models import DeltaFormerForCausalLM, DeltaFormerModel
+from fla.models import DeltaFormerConfig, DeltaFormerForCausalLM, DeltaFormerModel
@@
-    'DeltaFormerForCausalLM', 'DeltaFormerModel',
+    'DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel',

tests/models/test_modeling_deltaformer.py (1)
40-58: Generation test is solid; add a varlen-specific assertion if feasible.
Given pre-attn currently ignores attention masks (fixed-length kernel), consider asserting tolerance is slightly looser when attention_mask is present, or add a small test that runs without mask to ensure exactness.
I can add an extra param case with attention_mask=None to validate the exact path side-by-side.
fla/layers/deltaformer.py (4)
80-82: Tighten the ImportError message (Ruff TRY003).
Shorten the message to satisfy the linter.

-    if flash_attn_func is None:
-        raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+    if flash_attn_func is None:
+        raise ImportError("flash-attn is required; install via `pip install flash-attn --no-build-isolation`.")
117-120: KV grouping replication OK; consider safer API.
Optionally use repeat_interleave(dim=2) for clarity; the current repeat is fine performance-wise.

-    k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
-    v = repeat(v, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
+    k = k.repeat_interleave(self.num_kv_groups, dim=2)
+    v = v.repeat_interleave(self.num_kv_groups, dim=2)
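A small standalone check that the two forms above are equivalent for GQA head replication (einops is already used by this layer; shapes are arbitrary):

import torch
from einops import repeat

b, t, h, d, g = 2, 5, 3, 4, 2
k = torch.randn(b, t, h, d)
via_einops = repeat(k, 'b t h d -> b t (h g) d', g=g)
via_torch = k.repeat_interleave(g, dim=2)
print(torch.equal(via_einops, via_torch))  # True: each head is duplicated g times consecutively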
142-156: Varlen path: rename for clarity and guard availability.
Names are "unpadded", not "padded"; also guard when varlen kernel is missing.

-        if attention_mask is not None:
-            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
-            cu_seqlens_q, cu_seqlens_k = cu_seqlens
-            max_seqlen_q, max_seqlen_k = max_seq_lens
-            o = flash_attn_varlen_func(
-                q_padded, k_padded, u_padded,
+        if attention_mask is not None:
+            if flash_attn_varlen_func is None:
+                raise ImportError("flash-attn varlen kernel not available; please upgrade flash-attn.")
+            q_unpadded, (k_unpadded, u_unpadded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
+            cu_seqlens_q, cu_seqlens_k = cu_seqlens
+            max_seqlen_q, max_seqlen_k = max_seq_lens
+            o = flash_attn_varlen_func(
+                q_unpadded, k_unpadded, u_unpadded,
                 cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
                 causal=True, window_size=(-1, -1)
             )
             o = pad_input(o, indices_q, batch_size, q_len)
93-101: Unused parameters (output_attentions, use_cache, kwargs).
If the interface must keep them, explicitly mark unused to silence linters.

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        _ = (output_attentions, use_cache, kwargs)

fla/ops/deltaformer/fused_chunk.py (5)
27-27: Consider using explicit error handling instead of assertions.
Assertions can be disabled with the -O flag in production, potentially leading to silent failures. Consider using explicit error checking for better production robustness.

-    assert B == _B and D == _D and B == __B and __C == C
+    if not (B == _B and D == _D and B == __B and __C == C):
+        raise ValueError(f"Shape mismatch: q={q.shape}, k={k.shape}, v={v.shape}, beta={beta.shape}")
232-234: Magic number should be a named constant.
The value -1e6 used for masking should be defined as a constant for better maintainability.

+MASK_VALUE = -1e6
+
 def flash_attn_kernel(
     ...
 ):
     ...
     if kv_i >= T - C:
         mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
-        qk = tl.where(mask, -1e6, qk)
+        qk = tl.where(mask, MASK_VALUE, qk)
271-271: Potential numerical instability when computing normalized attention.
The division acc / rowsum[:, None] could lead to division by zero if rowsum contains zeros. While the forward kernel initializes rowsum to 1 and only adds positive values from exp2, consider adding a small epsilon for numerical stability.

-    acc = acc / rowsum[:, None]
+    acc = acc / (rowsum[:, None] + 1e-12)
819-831: Backward pass could be optimized for memory usage.
The backward implementation creates multiple temporary tensors (du, grad_v) at each chunk iteration. Consider reusing buffers to reduce memory allocations.
Would you like me to propose a memory-optimized version of the backward pass that reuses buffers across chunk iterations?
891-894: Redundant gradient checking for beta parameter.
The condition checks beta is not None and beta.requires_grad, but if beta is None, it gets replaced with a ones tensor in _forward_impl. Consider simplifying the gradient check logic.

-    if k.requires_grad or q.requires_grad or v.requires_grad or (beta is not None and beta.requires_grad):
-        return _DeltaPreAttnFunction.apply(q, k, v, beta, C)
-    u, _, _ = _DeltaPreAttnFunction._forward_impl(q, k, v, beta, C, need_aux=False)
-    return u
+    # Check if any input requires gradient
+    requires_grad = k.requires_grad or q.requires_grad or v.requires_grad
+    if beta is not None:
+        requires_grad = requires_grad or beta.requires_grad
+
+    if requires_grad:
+        return _DeltaPreAttnFunction.apply(q, k, v, beta, C)
+    else:
+        u, _, _ = _DeltaPreAttnFunction._forward_impl(q, k, v, beta, C, need_aux=False)
+        return u

fla/models/deltaformer/modeling_deltaformer.py (2)
106-108: Remove unused method parameters.
The parameters prenorm_residual_strategy and num_residuals_per_layer are not used in the method implementation.

 def _init_weights(
     self,
     module: nn.Module,
-    prenorm_residual_strategy: Optional[str] = None,
-    num_residuals_per_layer: int = 2,
 ):
152-157: Set proper stacklevel for warning.
The warning should include stacklevel=2 to show the correct location in user code.

-    warnings.warn(
-        "`DeltaFormerModel` does not support output attention weights now, "
-        "so `output_attentions` is set to `False`."
-    )
+    warnings.warn(
+        "`DeltaFormerModel` does not support output attention weights now, "
+        "so `output_attentions` is set to `False`.",
+        stacklevel=2
+    )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- fla/__init__.py(2 hunks)
- fla/layers/__init__.py(2 hunks)
- fla/layers/deltaformer.py(1 hunks)
- fla/models/__init__.py(2 hunks)
- fla/models/deltaformer/__init__.py(1 hunks)
- fla/models/deltaformer/configuration_deltaformer.py(1 hunks)
- fla/models/deltaformer/modeling_deltaformer.py(1 hunks)
- fla/ops/deltaformer/__init__.py(1 hunks)
- fla/ops/deltaformer/fused_chunk.py(1 hunks)
- fla/ops/deltaformer/invcum.py(1 hunks)
- fla/ops/deltaformer/naive.py(1 hunks)
- tests/models/test_modeling_deltaformer.py(1 hunks)
- tests/models/test_modeling_utils.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
tests/models/test_modeling_deltaformer.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(10-96)
tests/models/test_modeling_base.py (2)
run_test_generation(67-126)
run_test_model_forward_backward(27-61)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/fused_chunk.py (1)
delta_pre_attn(878-894)
fla/ops/deltaformer/naive.py (1)
delta_pre_attn_naive(39-81)
fla/models/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(10-96)
fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(208-349)
DeltaFormerModel(119-205)
fla/ops/deltaformer/invcum.py (1)
fla/ops/deltaformer/fused_chunk.py (3)
forward(16-31)
forward(784-798)
backward(801-834)
fla/__init__.py (1)
fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(208-349)
DeltaFormerModel(119-205)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (2)
pad_input(176-197)
unpad_input(101-173)
fla/modules/layernorm.py (1)
RMSNorm(1060-1107)
fla/ops/deltaformer/fused_chunk.py (3)
delta_pre_attn(878-894)
forward(16-31)
forward(784-798)
fla/models/deltaformer/modeling_deltaformer.py (3)
forward(61-89)
forward(140-205)
forward(289-349)
fla/layers/__init__.py (1)
fla/layers/deltaformer.py (1)
DeltaFormerAttention(28-163)
fla/models/deltaformer/modeling_deltaformer.py (7)
fla/layers/deltaformer.py (2)
DeltaFormerAttention(28-163)
forward(93-163)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(10-96)
fla/modules/fused_cross_entropy.py (1)
FusedCrossEntropyLoss(344-419)
fla/modules/fused_linear_cross_entropy.py (1)
FusedLinearCrossEntropyLoss(493-567)
fla/modules/mlp.py (1)
GatedMLP(26-69)
fla/modules/layernorm.py (1)
RMSNorm(1060-1107)
fla/models/modeling_layers.py (1)
GradientCheckpointingLayer(11-71)
fla/models/deltaformer/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(10-96)
fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(208-349)
DeltaFormerModel(119-205)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
forward(93-163)
fla/ops/deltaformer/invcum.py (4)
forward(7-13)
backward(29-38)
backward_x(20-26)
forward_inplace(16-17)
🪛 Ruff (0.12.2)
fla/models/deltaformer/configuration_deltaformer.py
12-12: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
76-76: Avoid specifying long messages outside the exception class
(TRY003)
80-80: Avoid specifying long messages outside the exception class
(TRY003)
82-82: Avoid specifying long messages outside the exception class
(TRY003)
84-84: Avoid specifying long messages outside the exception class
(TRY003)
fla/ops/deltaformer/naive.py
59-59: Use of assert detected
(S101)
61-61: Use of assert detected
(S101)
65-65: Use of assert detected
(S101)
fla/layers/deltaformer.py
81-81: Avoid specifying long messages outside the exception class
(TRY003)
98-98: Unused method argument: output_attentions
(ARG002)
99-99: Unused method argument: use_cache
(ARG002)
100-100: Unused method argument: kwargs
(ARG002)
104-104: Use of assert detected
(S101)
fla/models/deltaformer/modeling_deltaformer.py
97-97: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
106-106: Unused method argument: prenorm_residual_strategy
(ARG002)
107-107: Unused method argument: num_residuals_per_layer
(ARG002)
143-143: Unused blanket noqa directive
Remove unused noqa directive
(RUF100)
153-153: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
164-164: Avoid specifying long messages outside the exception class
(TRY003)
166-166: Avoid specifying long messages outside the exception class
(TRY003)
210-210: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
246-252: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
246-252: Avoid specifying long messages outside the exception class
(TRY003)
254-254: Use raise without specifying exception name
Remove exception name
(TRY201)
265-265: Unused method argument: kwargs
(ARG002)
341-341: Consider (logits, *outputs[1:]) instead of concatenation
Replace with (logits, *outputs[1:])
(RUF005)
342-342: Consider (loss, *output) instead of concatenation
Replace with (loss, *output)
(RUF005)
fla/ops/deltaformer/fused_chunk.py
27-27: Use of assert detected
(S101)
192-192: Unused function argument: B
(ARG001)
352-352: Unused function argument: B
(ARG001)
459-459: Unused function argument: B
(ARG001)
581-581: Unused function argument: B
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (14)
fla/ops/deltaformer/invcum.py (3)
7-13: Numerical intent is clear; please confirm unit-lower-triangular assumption for w.
unitriangular=True ignores the diagonal of w. Ensure upstream constructs w as unit-lower-triangular (or at least that its diagonal is not relied upon). If not guaranteed, results may silently diverge.
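For concreteness, a small PyTorch check (illustrative values only) of what the unit-lower-triangular contract means: with unitriangular=True, solve_triangular treats the diagonal as ones and never reads whatever is stored on w's diagonal.

import torch

T = 4
w = torch.randn(T, T).tril(-1)   # strictly lower part carries the system
b = torch.randn(T, 1)
A = torch.eye(T) + w             # implied matrix: I + tril(w, -1), unit diagonal

x1 = torch.linalg.solve_triangular(w, b, upper=False, unitriangular=True)
x2 = torch.linalg.solve(A, b)
print(torch.allclose(x1, x2, atol=1e-5))  # True: the diagonal of w is ignored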
20-27: Backprop solve uses strictly-upper assumption; verify ws construction.
w.tril(-1).mH with upper=True, unitriangular=True expects a strictly-upper unit-triangular system after conjugate-transpose. Please verify ws[i] is produced accordingly in fused_chunk; otherwise gradients w.r.t. RHS can be wrong.
29-38: Cast dw to the parameter's dtype and zero its diagonal after the cast.
Add .to(w.dtype) immediately after .tril(-1) to ensure mixed-precision doesn't change the dtype, and confirm (e.g. via an assertion in fused_chunk) that dw.dtype == w.dtype and the diagonal remains zero.
fla/layers/__init__.py (2)
10-10: Public export looks good.
64-65: __all__ update is correct.
fla/models/__init__.py (2)
7-7: DeltaFormer symbols wired correctly.
64-65: __all__ includes the DeltaFormer trio. Good for downstream imports.
tests/models/test_modeling_deltaformer.py (1)
14-24: Param IDs aid debugging; good coverage across head dims.
fla/models/deltaformer/__init__.py (2)
12-12: Exports look good.
8-10: Wrap registrations in try/except for backward compatibility.
Older transformers releases don't support the exist_ok parameter on AutoConfig.register / AutoModel.register / AutoModelForCausalLM.register. Fall back to calling .register(...) without exist_ok inside a TypeError handler to avoid import-time failures.
tests/models/test_modeling_utils.py (2)
13-13: LGTM: gating varlen testing for DeltaFormer. Consistent with current layer limitations.
27-27: LGTM: skip generation tests for DeltaFormer. Matches the current generate constraints.
fla/models/deltaformer/modeling_deltaformer.py (2)
80-87: LGTM! Clean implementation of fused norm branching.
The conditional logic for fused vs non-fused normalization is well-structured, maintaining the correct residual connections in both paths.
333-333: Label shape aligns with hidden_states—no mismatch. The shifting logic mirrors other causal‐LM implementations in this repo (e.g., Rodimus, Transformer) and, with the default logits_to_keep = 0 (full sequence), both labels and hidden_states share the same sequence length.
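For readers unfamiliar with the shifting logic referenced here, a generic causal-LM label shift (standard pattern, not this repo's exact code):

import torch
import torch.nn.functional as F

def shifted_lm_loss(logits, labels):
    # predict token t+1 from position t: drop the last logit and the first label
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # conventional "ignore" label id in transformers
    )

logits = torch.randn(2, 8, 100)
labels = torch.randint(0, 100, (2, 8))
print(shifted_lm_loss(logits, labels))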
        
          
fla/layers/deltaformer.py (Outdated)
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    flash_attn_func = None
🛠️ Refactor suggestion
Avoid potential NameError; mirror sentinel for varlen import.
If the import fails, flash_attn_varlen_func is undefined. Set it to None too.
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
-    flash_attn_func = None
+    flash_attn_func = None
+    flash_attn_varlen_func = None
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias
Validate head divisibility and GQA grouping early.
Prevent silent shape mismatches when hidden_size % num_heads != 0 or num_heads % num_kv_heads != 0.
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
-        self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.qkv_bias = qkv_bias
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception
🛠️ Refactor suggestion
Improve exception chaining for better debugging.
When re-raising exceptions, use from to preserve the original exception context.
-            if 'past_key_values' in str(exception):
-                raise AttributeError(
-                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
-                    f"which is not supported for {self.__class__.__name__}. "
-                    f"Try another generation strategy instead. "
-                    f"For the available generation strategies, check this doc: "
-                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
-            else:
-                raise exception
+            if 'past_key_values' in str(exception):
+                raise AttributeError(
+                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                    f"which is not supported for {self.__class__.__name__}. "
+                    f"Try another generation strategy instead. "
+                    f"For the available generation strategies, check this doc: "
+                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                ) from exception
+            else:
+                raise

🧰 Tools
🪛 Ruff (0.12.2)
246-252: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
246-252: Avoid specifying long messages outside the exception class
(TRY003)
254-254: Use raise without specifying exception name
Remove exception name
(TRY201)
            if getattr(self, 'criterion', None) is None:
                if self.config.fuse_linear_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
🛠️ Refactor suggestion
Criterion initialization should happen in init.
The criterion is lazily initialized in the forward pass, which could lead to unexpected behavior when the model is used in distributed settings or when the configuration changes after model creation.
 def __init__(self, config: DeltaFormerConfig):
     super().__init__(config)
     self.model = DeltaFormerModel(config)
     self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-    self.criterion = None
+    
+    # Initialize criterion based on config
+    if config.fuse_linear_cross_entropy:
+        self.criterion = FusedLinearCrossEntropyLoss(use_l2warp=config.use_l2warp)
+    elif config.fuse_cross_entropy:
+        self.criterion = FusedCrossEntropyLoss(inplace_backward=True)
+    else:
+        self.criterion = nn.CrossEntropyLoss()
     
     self.post_init()And simplify the forward method:
-            if getattr(self, 'criterion', None) is None:
-                if self.config.fuse_linear_cross_entropy:
-                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
-                elif self.config.fuse_cross_entropy:
-                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
-                else:
-                    criterion = nn.CrossEntropyLoss()
-            else:
-                criterion = self.criterion
+            criterion = self.criterion
@Nathancgy Hi, thanks for the PR.
@yzhangcs Hi, thanks for helping with the minor fixes!
Actionable comments posted: 1
♻️ Duplicate comments (6)
fla/layers/deltaformer.py (5)
20-24: Fix incomplete flash_attn import handling.
The missing import sentinel creates a potential NameError for flash_attn_varlen_func when the import fails.

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     flash_attn_func = None
+    flash_attn_varlen_func = None
70-78: Add validation for head dimension compatibility.
Without validation, incompatible head dimensions could cause silent shape errors downstream.

 self.hidden_size = hidden_size
 self.num_heads = num_heads
 self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+if self.hidden_size % self.num_heads != 0:
+    raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})")
+if self.num_heads % self.num_kv_heads != 0:
+    raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})")
 self.num_kv_groups = num_heads // self.num_kv_heads
103-108: Replace assertion with explicit exception for runtime validation.
Using assertions for runtime validation can be stripped with -O optimization.

 if attention_mask is not None:
-    assert len(attention_mask.shape) == 2, (
-        "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-        "for padding purposes (0 indicating padding). "
-        "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-    )
+    if attention_mask.dim() != 2:
+        raise ValueError(
+            "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; 2D pairwise masks are unsupported."
+        )
124-197: Critical: Missing key-value cache implementation for generation.
The forward method accepts caching parameters but doesn't implement proper incremental decoding. This makes autoregressive generation extremely inefficient as all previous tokens need re-processing at each step. The complex caching logic in lines 147-196 shows an attempt at implementation, but without proper cache state management, generation will fail.

# The current implementation attempts caching but lacks proper state management.
# A complete cache implementation would require:
# 1. Proper cache initialization and state tracking
# 2. Efficient key-value updates for incremental decoding
# 3. Handling of attention mask updates during generation

This aligns with DeltaFormerConfig being added to GENERATION_UNSUPPORTED in the test suite.
128-143: Pre-attention doesn't support masked computation - validate the mask has no padding.
The pre-attention operation ignores attention_mask, allowing padded tokens to leak into computations and corrupt results.

 if attention_mask is not None:
+    if not torch.all(attention_mask):
+        raise NotImplementedError(
+            "DeltaFormer pre-attention does not support variable-length sequences; "
+            "use sequences without padding or implement masked pre-attention."
+        )
     _, cu_seqlens_k, _ = get_unpad_data(attention_mask)

fla/ops/deltaformer/fused_chunk.py (1)
948-962: Confusing variable names in backward implementation.
The variables qi and ki are misleadingly named - qi contains a slice of k while ki contains a slice of q. This makes the code difficult to understand and maintain.

-    qi = k[:, i:i + cb, :]
-    ki = q[:, i + cb:, :]
+    k_chunk = k[:, i:i + cb, :]
+    q_context = q[:, i + cb:, :]
     lse_slice = lses[:, i + cb:]
     beta_single = beta[:, i + cb:]
     du = backward_u_chunk(
-        qi,
-        ki,
+        k_chunk,
+        q_context,
         lse_slice,
         grad_v[:, i + cb:, :],
         fa_scale,
         beta_single,
         cu_seqlens=cu_flat,
         row_start=i,
     )
🧹 Nitpick comments (3)
fla/layers/deltaformer.py (2)
81-81: Consider extracting the error message to a constant.
The static analyzer suggests avoiding long messages inline. While this is a minor style issue, extracting improves maintainability.

+_FLASH_ATTN_INSTALL_MSG = "Please install Flash Attention via `pip install flash-attn --no-build-isolation` first"
+
 if flash_attn_func is None:
-    raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+    raise ImportError(_FLASH_ATTN_INSTALL_MSG)
86-86: Consider renaming b_proj to beta_proj for clarity.
The variable name b_proj is not immediately clear compared to other projection layers. Since the layer produces beta weights, beta_proj would be more descriptive.

-self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
+self.beta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

And update line 115:

-beta = rearrange(self.b_proj(hidden_states), 'b t h -> b h t')
+beta = rearrange(self.beta_proj(hidden_states), 'b t h -> b h t')

fla/ops/deltaformer/fused_chunk.py (1)
26-33: Add input tensor validation for shape compatibility.
The current shape assertions use raw assert statements. Consider adding more descriptive validation that would help debugging in production.

 def forward(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     u: torch.Tensor,
     qk_scale: float,
     beta: torch.Tensor,
     row_start: int,
     cu_seqlens: Optional[torch.Tensor] = None,
 ):
     B, C, D = q.size()
     _B, T, _D = k.size()
     __B, __C = beta.size()
-    assert B == _B and D == _D and B == __B and __C == C
+    if not (B == _B == __B and D == _D and __C == C):
+        raise ValueError(
+            f"Shape mismatch: q{q.shape}, k{k.shape}, beta{beta.shape} must have compatible batch/head/feature dims"
+        )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/layers/deltaformer.py(1 hunks)
- fla/ops/deltaformer/fused_chunk.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
forward(93-223)
fla/ops/deltaformer/invcum.py (4)
forward(7-13)
backward(29-38)
backward_x(20-26)
forward_inplace(16-17)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (3)
get_unpad_data(75-98)
pad_input(176-197)
unpad_input(101-173)
fla/modules/layernorm.py (1)
RMSNorm(1060-1107)
fla/ops/deltaformer/fused_chunk.py (3)
delta_pre_attn(1043-1060)
forward(16-33)
forward(892-908)
fla/models/deltaformer/modeling_deltaformer.py (3)
forward(61-89)
forward(140-205)
forward(289-349)
🪛 Ruff (0.12.2)
fla/ops/deltaformer/fused_chunk.py
29-29: Use of assert detected
(S101)
217-217: Unused function argument: B
(ARG001)
406-406: Unused function argument: B
(ARG001)
536-536: Unused function argument: B
(ARG001)
680-680: Unused function argument: B
(ARG001)
fla/layers/deltaformer.py
81-81: Avoid specifying long messages outside the exception class
(TRY003)
98-98: Unused method argument: output_attentions
(ARG002)
100-100: Unused method argument: kwargs
(ARG002)
104-104: Use of assert detected
(S101)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (3)
fla/ops/deltaformer/fused_chunk.py (3)
1043-1061: Well-structured public API with proper gradient routing.
The main entry point correctly routes between gradient-aware and gradient-free execution paths, and provides clear documentation. The function properly handles optional parameters and maintains backward compatibility.
890-988: Comprehensive autograd implementation with proper gradient flow.
The _DeltaPreAttnFunction class provides a complete PyTorch autograd implementation with:
- Proper forward pass with auxiliary data saving for backward
- Chunked backward pass that correctly accumulates gradients
- Proper handling of optional beta parameter
- Support for both fixed and variable-length sequences
The implementation correctly manages tensor flattening/reshaping and uses the inverse cumulative solver for gradient computation.
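For readers unfamiliar with the pattern being described, here is a generic torch.autograd.Function skeleton. It is illustrative only; the real _DeltaPreAttnFunction is far more involved, but the save-in-forward / reuse-in-backward structure is the same.

import torch

class ScaledDot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, scale):
        out = (q * k).sum(-1) * scale
        ctx.save_for_backward(q, k)   # stash tensors needed by backward
        ctx.scale = scale             # non-tensor state goes on ctx directly
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k = ctx.saved_tensors
        g = (grad_out * ctx.scale).unsqueeze(-1)
        return g * k, g * q, None     # one gradient per forward input (None for the float)

q = torch.randn(2, 4, requires_grad=True)
k = torch.randn(2, 4, requires_grad=True)
ScaledDot.apply(q, k, 0.5).sum().backward()
print(q.grad.shape, k.grad.shape)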
188-226: Proper autotune configuration for kernel optimization.
The autotune decorator with multiple block size configurations will help optimize performance across different input sizes. The heuristics correctly detect variable-length sequences.
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
🛠️ Refactor suggestion
Remove unused parameters or implement functionality.
The static analyzer correctly identifies unused parameters. Either implement the missing functionality or remove these parameters to avoid misleading the API consumer.
For output_attentions:
-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    attentions = None
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
+    attentions = None
🧰 Tools
🪛 Ruff (0.12.2)
98-98: Unused method argument: output_attentions
(ARG002)
100-100: Unused method argument: kwargs
(ARG002)
Actionable comments posted: 9
🧹 Nitpick comments (27)
fla/layers/deltaformer.py (4)
157-180: Deduplicate identical pre-attn calls; drive a single code path.
All three branches do the same, only differ by cu_seqlens. Reduce to one call. Apply:
- if attention_mask is not None: - u = delta_pre_attn( - # kk similarity - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(v_full, 'b t h d -> b h t d'), - beta_full, - cu_seqlens=cu_seqlens, - ) - elif cu_seqlens is not None: - u = delta_pre_attn( - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(v_full, 'b t h d -> b h t d'), - beta_full, - cu_seqlens=cu_seqlens, - ) - else: - u = delta_pre_attn( - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(k_full, 'b t h d -> b h t d'), - rearrange(v_full, 'b t h d -> b h t d'), - beta_full, - ) + u = delta_pre_attn( + # kk similarity + rearrange(k_full, 'b t h d -> b h t d'), + rearrange(k_full, 'b t h d -> b h t d'), + rearrange(v_full, 'b t h d -> b h t d'), + beta_full, + cu_seqlens=cu_seqlens, + )
36-44: Docstring is out of date with implementation.
- Code uses varlen pre-attn (via cu_seqlens) but the note claims fixed-length only.
- It also states K–K similarity; PR summary mentioned Q–K. Clarify here.
Proposed text:
- “Pre-attention supports both fixed-length and varlen via cu_seqlens; padded positions are ignored.”
- “Pre-attention currently uses K–K similarity; switching to Q–K or K–W requires changing inputs.”
148-156: Remove commented/unused variables.
q_full comments and duplicated assignments add noise.
26-26: Unused logger.Remove or use for warnings (e.g., when falling back paths are taken).
README.md (2)
32-32: Minor: list style nit (markdownlint MD004).Use consistent bullet markers if you enforce markdownlint; otherwise ignore.
88-89: Link and claim sanity-check.DeltaFormer is listed as available; generation in tests is currently marked unsupported. Add a note that KV-cache generation is WIP to avoid confusion.
tests/models/test_modeling_deltaformer.py (1)
49-58: This generation test will always skip.
Because DeltaFormerConfig is in GENERATION_UNSUPPORTED, run_test_generation skips. Either remove this test for now or mark with a TODO/xfailed reason tied to cache support to avoid confusion.
fla/ops/deltaformer/invcum.py (2)
7-13: Confirm W semantics (unit-lower triangular A = I + tril(w)).
This relies on unitriangular=True and w carrying strictly-lower entries of A (diag implicitly 1). Please confirm upstream guarantees that the diagonal of w is ignored and its strictly-upper part never contains NaNs/garbage. Consider documenting this contract here.
29-38: Cast dw to w.dtype for consistency.
du follows do.dtype, while dw currently inherits it too; if w differs, grads may be mixed dtype downstream.

-    dw = torch.bmm(-du, x.mH)
-    dw = dw.tril(-1)
+    dw = torch.bmm(-du, x.mH).to(w.dtype)
+    dw = dw.tril(-1)

tests/ops/test_deltaformer.py (3)
11-23: Add a beta=None case.Core API promises beta defaults to ones; please add a param set exercising
beta=Noneto guard this pathway.[ pytest.param(*test, id="B{}-T{}-H{}-D{}-{}".format(*test)) for test in [ (2, 128, 2, 64, torch.float32), # Test with bfloat16 (1, 256, 4, 64, torch.bfloat16), (2, 512, 4, 64, torch.bfloat16), (4, 1024, 4, 128, torch.bfloat16) ] ] ) +@pytest.mark.parametrize("beta_is_none", [False, True]) @@ -def test_delta_pre_attn( +def test_delta_pre_attn( B: int, T: int, H: int, D: int, - dtype: torch.dtype + dtype: torch.dtype, + beta_is_none: bool ): @@ - beta = torch.randn((B, H, T), dtype=dtype, device=device).sigmoid().requires_grad_(True) + beta = None if beta_is_none else torch.randn((B, H, T), dtype=dtype, device=device).sigmoid().requires_grad_(True)
61-66: Tolerances look tight; keep CI-friendly guardrails.
If CI noise appears on different GPUs, consider warning=True or slightly higher ratios for bf16 paths only.
67-75: Varlen pack: verify memory layout explicitly. Good approach. Add `contiguous()` on the concatenation results to avoid stride surprises under Dynamo/AOTAutograd.
-    q_packed = torch.cat([q[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+    q_packed = torch.cat([q[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-    k_packed = torch.cat([k[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+    k_packed = torch.cat([k[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-    v_packed = torch.cat([v[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+    v_packed = torch.cat([v[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
-    beta_packed = torch.cat([beta[i] for i in range(B)], dim=1).unsqueeze(0).detach().clone().requires_grad_(True)
+    beta_packed = torch.cat([beta[i] for i in range(B)], dim=1).contiguous().unsqueeze(0).detach().clone().requires_grad_(True)
fla/ops/deltaformer/naive.py (2)
10-36: Numerical stability nit: prefer `torch.where(mask, -inf, scores)` without materializing `~mask`. Minor, but it avoids one extra boolean tensor in tight loops; harmless to keep as-is.
76-87: Naive ref: good reference; add a quick early-exit for T==0.
-    u_list = []
+    if T == 0:
+        return vf.new_zeros(B, H, 0, D).to(orig_dtype)
+    u_list = []
fla/ops/deltaformer/fused_chunk.py (4)
971-1011: Do `beta` row scaling inside the kernel to save a write and keep `w` reusable. Currently `w = w * betai.unsqueeze(-1)` allocates and writes once per chunk. Consider moving the row scaling into `flash_attn_kernel` at the point where `w` is written, or emit a separate kernel that writes scaled `w`. Optional perf nit.
Also applies to: 1042-1070
127-160: Forward wrapper: add an assertion for `C <= T` in debug builds. Even though callers clamp `C`, a defensive check helps.
     B, C, D = q.size()
     _B, T, _D = k.size()
@@
-    w = torch.empty(B, C, C, device=q.device, dtype=q.dtype)
+    assert C <= T, "C (chunk) must be <= T (context)."
+    w = torch.empty(B, C, C, device=q.device, dtype=q.dtype)
162-170: Autotune configs: consider adding 16-warp configs for large D on H100. Optional; may help the reported throughput further on H100. Keep as a follow-up if benchmarks justify.
Also applies to: 329-336, 442-449, 559-566
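If that follow-up is picked up, a sketch of what the extra entries could look like; the helper name and the existing block sizes mirror the autotune list quoted later in this review, while the 16-warp values themselves are untested assumptions pending benchmarks:

```python
import triton

def _config_delta_flash_attn():
    configs = [
        triton.Config({'BLOCK_C': BC, 'BLOCK_T': BT}, num_stages=ns, num_warps=nw)
        for BC in [128, 64]
        for BT in [64, 32]
        for ns in [3, 2]
        for nw in [8, 4]
    ]
    # Tentative 16-warp entries aimed at large head dims on H100.
    configs += [
        triton.Config({'BLOCK_C': 128, 'BLOCK_T': 64}, num_stages=ns, num_warps=16)
        for ns in [3, 2]
    ]
    return configs
```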
604-604: Ruff hints: silence unused parameters to keep CI green.Prefix with
_or use them; applies to several kernel args and unpacked temps.Examples:
- Line 25:
_T- Line 198/364/477/604:
_B- Line 842:
_T_maxAlso applies to: 477-477, 364-364, 198-198, 25-25, 842-842
fla/models/deltaformer/configuration_deltaformer.py (2)
12-14: Annotate class attributes with ClassVar to satisfy Ruff and intent.
`model_type` and `keys_to_ignore_at_inference` are class-level constants. Annotate with `ClassVar[...]`. Apply this diff in place:
-from transformers.configuration_utils import PretrainedConfig
+from transformers.configuration_utils import PretrainedConfig
+from typing import ClassVar
@@
-class DeltaFormerConfig(PretrainedConfig):
-    model_type = 'deltaformer'
-    keys_to_ignore_at_inference = ['past_key_values']
+class DeltaFormerConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'deltaformer'
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']
78-87: Add stacklevel to warnings and fix grammar. Use `stacklevel=2` so the warning points at the caller; fix “can improves” → “can improve”.
-        if fuse_linear_cross_entropy:
-            warnings.warn(
-                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
-                "at the potential cost of reduced precision. "
-                "If you observe issues like loss divergence, consider disabling this setting."
-            )
+        if fuse_linear_cross_entropy:
+            warnings.warn(
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
+                "at the potential cost of reduced precision. "
+                "If you observe issues like loss divergence, consider disabling this setting.",
+                stacklevel=2,
+            )
fla/models/deltaformer/modeling_deltaformer.py (7)
71-91: Fix return type annotation of DeltaFormerBlock.forward. The function returns `(hidden_states, attentions, past_key_values)` but the annotation declares only two items.
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Union[Cache, List[torch.FloatTensor]]],
+    ]:
96-101: Annotate class-level mutables with ClassVar. Satisfy Ruff RUF012 for `_no_split_modules` and `_supports_cache_class`.
-from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.deprecation import deprecate_kwarg
+from typing import ClassVar
@@
-    _no_split_modules = ['DeltaFormerBlock']
-    _supports_cache_class = True
+    _no_split_modules: ClassVar[list[str]] = ['DeltaFormerBlock']
+    _supports_cache_class: ClassVar[bool] = True
105-120: Unused params in _init_weights.
`prenorm_residual_strategy` and `num_residuals_per_layer` aren’t used.
-    def _init_weights(
-        self,
-        module: nn.Module,
-        prenorm_residual_strategy: Optional[str] = None,
-        num_residuals_per_layer: int = 2,
-    ):
+    def _init_weights(self, module: nn.Module, *_: object) -> None:
142-161: Add stacklevel to warning; drop unused noqa. Set `stacklevel=2` for actionable warnings. Remove the stray `# noqa`.
-        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        attention_mask: Optional[torch.Tensor] = None,
@@
-            warnings.warn(
+            warnings.warn(
                 "`DeltaFormerModel` does not support output attention weights now, "
                 "so `output_attentions` is set to `False`."
-            )
+                , stacklevel=2)
291-304: Default logits_to_keep to 1 to avoid a slicing corner case. With 0, `x[:, -0:] == x[:, 0:]` (the full sequence). Use 1 for typical LM training/inference.
-        logits_to_keep: Optional[int] = 0,
+        logits_to_keep: Optional[int] = 1,
343-351: Prefer starred unpack over tuple concatenation. Cleaner and satisfies Ruff RUF005.
-        output = (logits,) + outputs[1:]
-        return ((loss,) + output) if loss is not None else output
+        output = (logits, *outputs[1:])
+        return ((loss, *output) if loss is not None else output)
212-213: Annotate _tied_weights_keys as ClassVar.
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys: ClassVar[list[str]] = ["lm_head.weight"]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- README.md(2 hunks)
- fla/__init__.py(4 hunks)
- fla/layers/__init__.py(2 hunks)
- fla/layers/deltaformer.py(1 hunks)
- fla/models/__init__.py(2 hunks)
- fla/models/deltaformer/__init__.py(1 hunks)
- fla/models/deltaformer/configuration_deltaformer.py(1 hunks)
- fla/models/deltaformer/modeling_deltaformer.py(1 hunks)
- fla/ops/deltaformer/__init__.py(1 hunks)
- fla/ops/deltaformer/fused_chunk.py(1 hunks)
- fla/ops/deltaformer/invcum.py(1 hunks)
- fla/ops/deltaformer/naive.py(1 hunks)
- tests/models/test_modeling_deltaformer.py(1 hunks)
- tests/models/test_modeling_utils.py(1 hunks)
- tests/ops/test_deltaformer.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
fla/models/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(11-107)fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(210-351)
DeltaFormerModel(121-207)
fla/layers/__init__.py (1)
fla/layers/deltaformer.py (1)
DeltaFormerAttention(29-214)
tests/models/test_modeling_deltaformer.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(11-107)tests/models/test_modeling_base.py (2)
run_test_generation(67-126)
run_test_model_forward_backward(27-61)
tests/ops/test_deltaformer.py (3)
fla/ops/deltaformer/fused_chunk.py (2)
delta_pre_attn(1073-1099)
backward(835-954)fla/ops/deltaformer/naive.py (1)
delta_pre_attn_naive(39-87)fla/utils.py (1)
assert_close(77-88)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/fused_chunk.py (1)
delta_pre_attn(1073-1099)fla/ops/deltaformer/naive.py (1)
delta_pre_attn_naive(39-87)
fla/ops/deltaformer/invcum.py (1)
fla/ops/deltaformer/fused_chunk.py (3)
forward(16-31)
forward(802-832)
backward(835-954)
fla/layers/deltaformer.py (5)
fla/layers/utils.py (3)
get_unpad_data(75-98)
pad_input(176-197)
unpad_input(101-173)fla/modules/layernorm.py (1)
RMSNorm(1060-1107)fla/modules/rotary.py (1)
RotaryEmbedding(306-500)fla/ops/deltaformer/fused_chunk.py (3)
delta_pre_attn(1073-1099)
forward(16-31)
forward(802-832)fla/ops/utils/index.py (1)
prepare_lens_from_mask(43-44)
fla/__init__.py (2)
fla/layers/deltaformer.py (1)
DeltaFormerAttention(29-214)fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(210-351)
DeltaFormerModel(121-207)
fla/models/deltaformer/__init__.py (2)
fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(11-107)fla/models/deltaformer/modeling_deltaformer.py (2)
DeltaFormerForCausalLM(210-351)
DeltaFormerModel(121-207)
fla/ops/deltaformer/fused_chunk.py (2)
fla/layers/deltaformer.py (1)
forward(105-214)fla/ops/deltaformer/invcum.py (4)
forward(7-13)
backward(29-38)
backward_x(20-26)
forward_inplace(16-17)
fla/models/deltaformer/modeling_deltaformer.py (8)
fla/utils.py (1)
deprecate_kwarg(190-334)fla/layers/deltaformer.py (2)
DeltaFormerAttention(29-214)
forward(105-214)fla/models/deltaformer/configuration_deltaformer.py (1)
DeltaFormerConfig(11-107)fla/modules/fused_cross_entropy.py (1)
FusedCrossEntropyLoss(344-419)fla/modules/fused_linear_cross_entropy.py (1)
FusedLinearCrossEntropyLoss(493-567)fla/modules/mlp.py (1)
GatedMLP(26-69)fla/modules/layernorm.py (1)
RMSNorm(1060-1107)fla/models/modeling_layers.py (1)
GradientCheckpointingLayer(11-71)
🪛 Ruff (0.13.1)
fla/models/deltaformer/configuration_deltaformer.py
13-13: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
79-81: Avoid specifying long messages outside the exception class
(TRY003)
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
91-91: Avoid specifying long messages outside the exception class
(TRY003)
93-93: Avoid specifying long messages outside the exception class
(TRY003)
95-95: Avoid specifying long messages outside the exception class
(TRY003)
fla/layers/deltaformer.py
91-91: Avoid specifying long messages outside the exception class
(TRY003)
110-110: Unused method argument: output_attentions
(ARG002)
111-111: Unused method argument: use_cache
(ARG002)
fla/ops/deltaformer/fused_chunk.py
25-25: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
198-198: Unused function argument: B
(ARG001)
364-364: Unused function argument: B
(ARG001)
477-477: Unused function argument: B
(ARG001)
604-604: Unused function argument: B
(ARG001)
842-842: Unpacked variable T_max is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
fla/models/deltaformer/modeling_deltaformer.py
99-99: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
108-108: Unused method argument: prenorm_residual_strategy
(ARG002)
109-109: Unused method argument: num_residuals_per_layer
(ARG002)
145-145: Unused blanket noqa directive
Remove unused noqa directive
(RUF100)
155-155: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
166-166: Avoid specifying long messages outside the exception class
(TRY003)
168-168: Avoid specifying long messages outside the exception class
(TRY003)
212-212: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
248-254: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
248-254: Avoid specifying long messages outside the exception class
(TRY003)
256-256: Use raise without specifying exception name
Remove exception name
(TRY201)
267-267: Unused method argument: kwargs
(ARG002)
343-343: Consider (logits, *outputs[1:]) instead of concatenation
Replace with (logits, *outputs[1:])
(RUF005)
344-344: Consider (loss, *output) instead of concatenation
Replace with (loss, *output)
(RUF005)
🪛 markdownlint-cli2 (0.18.1)
README.md
32-32: Unordered list style
Expected: asterisk; Actual: dash
(MD004, ul-style)
🔇 Additional comments (17)
fla/layers/__init__.py (1)
10-10: Public export looks good. DeltaFormerAttention is imported and added to `__all__`.
Also applies to: 64-65
fla/__init__.py (1)
9-10: API surface and version bump OK. Exports for DeltaFormer and the bump to 0.3.2 look consistent with the added modules.
If you publish to PyPI, ensure changelog/release notes reflect the KV-cache limitation.
Also applies to: 38-40, 89-89, 111-111
fla/ops/deltaformer/__init__.py (1)
3-9: Clean public ops surface. Exporting the fused and naive pre-attn helpers is clear and sufficient.
fla/models/__init__.py (1)
7-7: Models exports wired correctly. DeltaFormerConfig/Model/ForCausalLM are now available at package level.
Also applies to: 41-41
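As a smoke-test illustration of the package-level exports; the config field names mirror other fla models and are assumptions here, and the forward pass needs a CUDA device with flash-attn installed, per the import guard discussed below:

```python
import torch
from fla.models import DeltaFormerConfig, DeltaFormerForCausalLM

# Hypothetical tiny settings; adjust field names if the released config differs.
config = DeltaFormerConfig(hidden_size=128, num_hidden_layers=2, num_heads=4, vocab_size=1000)
model = DeltaFormerForCausalLM(config).cuda().to(torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (1, 32), device='cuda')
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)
```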
tests/models/test_modeling_deltaformer.py (1)
15-24: Modeling params LGTM; covers key dtypes and D. Good coverage across use_l2warp and D∈{64,128}; Hopper gating is handled in the base.
Also applies to: 25-35
fla/layers/deltaformer.py (3)
21-25: Guard both flash-attn symbols on ImportError.
`flash_attn_varlen_func` is undefined if the import fails. Mirror the sentinel. Apply:
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
-    flash_attn_func = None
+    flash_attn_func = None
+    flash_attn_varlen_func = None
115-121: Don’t use `assert` for runtime validation. Asserts can be stripped; raise an explicit exception instead.
Apply:
-        if attention_mask is not None:
-            assert len(attention_mask.shape) == 2, (
-                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-                "for padding purposes (0 indicating padding). "
-                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-            )
+        if attention_mask is not None and attention_mask.dim() != 2:
+            raise ValueError(
+                "attention_mask must be [batch_size, seq_len] with 1=valid and 0=pad; pairwise masks are unsupported."
+            )
78-84: Validate head divisibility and GQA grouping. Missing checks can yield silent shape bugs when `hidden_size % num_heads != 0` or `num_heads % num_kv_heads != 0`. Apply:
     self.hidden_size = hidden_size
     self.num_heads = num_heads
     self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
-    self.num_kv_groups = num_heads // self.num_kv_heads
-    self.head_dim = self.hidden_size // self.num_heads
+    if self.hidden_size % self.num_heads != 0:
+        raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+    if self.num_heads % self.num_kv_heads != 0:
+        raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+    self.num_kv_groups = num_heads // self.num_kv_heads
+    self.head_dim = self.hidden_size // self.num_heads
     self.kv_dim = self.num_kv_heads * self.head_dim
fla/ops/deltaformer/invcum.py (1)
16-18: Name these helpers for what they actually do (clarify math and call sites). The current names are too generic and were previously flagged. Suggest more explicit names to reduce cognitive load.
-def forward(u, w):
+def solve_unit_lower(u, w):
@@
-def forward_inplace(u, w):
-    u.copy_(forward(u, w))
+def solve_unit_lower_inplace(u, w):
+    u.copy_(solve_unit_lower(u, w))
@@
-def backward_x(do, w):
+def solve_unit_upper_T_x(do, w):
@@
-def backward(do, w, x):
+def backward_wrt_inputs(do, w, x):
Follow up: update call sites (`invcum.forward_inplace` → `solve_unit_lower_inplace`, etc.).
Also applies to: 20-26, 29-38
fla/models/deltaformer/__init__.py (2)
8-10: Registry calls: confirm Transformers version compatibility. The APIs and `exist_ok` behavior vary across versions. Please ensure our minimum `transformers` version supports these exact signatures.
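For context, a sketch of the kind of registration under discussion; `exist_ok` is only accepted by newer transformers releases, which is precisely the compatibility question raised here, so treat this as illustrative rather than the PR's exact code:

```python
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig
from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel

# Register the model type with the Auto* factories; exist_ok=True tolerates
# repeated imports but requires a sufficiently recent transformers version.
AutoConfig.register(DeltaFormerConfig.model_type, DeltaFormerConfig, exist_ok=True)
AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True)
AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True)
```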
12-12: LGTM – clean public surface.fla/ops/deltaformer/fused_chunk.py (2)
885-891: Rename `qi`/`ki` to match the actual tensors (k-chunk vs q-tail). Names are inverted and hard to follow. Adopt descriptive names; same pattern in the fixed-length backward below.
-    qi = k_seq[:, i0:i1, :]
-    ki = q_seq[:, i1:L, :]
+    k_chunk = k_seq[:, i0:i1, :]
+    q_tail = q_seq[:, i1:L, :]
@@
-    du_tail = backward_u_chunk(qi, ki, lse_tail, gv_seq[:, i1:L, :], fa_scale, beta_tail)
+    du_tail = backward_u_chunk(k_chunk, q_tail, lse_tail, gv_seq[:, i1:L, :], fa_scale, beta_tail)
-    qi = ko[b, :, i:i + Ci, :]
-    ki = qo[b, :, i + Ci:, :]
+    k_chunk = ko[b, :, i:i + Ci, :]
+    q_tail = qo[b, :, i + Ci:, :]
@@
-    du = backward_u_chunk(qi, ki, lse, grad_v_seq[:, i + Ci:, :], fa_scale, beta_single)
+    du = backward_u_chunk(k_chunk, q_tail, lse, grad_v_seq[:, i + Ci:, :], fa_scale, beta_single)
Also applies to: 924-933
1073-1105: Public API: good shape checking; maintains inference fast path.Please run the provided throughput benchmark once more after these minor renames/guards to ensure no regression.
fla/models/deltaformer/modeling_deltaformer.py (3)
82-88: Confirm fused RMSNorm return signature to avoid unpack errors.
`self.mlp_norm(...)` is unpacked into two tensors. Ensure the fused RMSNorm returns `(normed, residual)` when `prenorm=True`; otherwise this will throw. If it returns a single tensor, gate by type/arity or switch to an explicit residual add:
-        if self.config.fuse_norm:
-            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        if self.config.fuse_norm:
+            out = self.mlp_norm(hidden_states, residual, True)
+            hidden_states, residual = (out if isinstance(out, tuple) else (out, out))
210-257: Preserve exception context and avoid broad rethrow. Chain with `from exception`, and use a bare `raise` in the fallback. This mirrors prior feedback.
 except AttributeError as exception:
     if 'past_key_values' in str(exception):
-        raise AttributeError(
+        raise AttributeError(
             f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
             f"which is not supported for {self.__class__.__name__}. "
             f"Try another generation strategy instead. "
             f"For the available generation strategies, check this doc: "
             f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-        )
+        ) from exception
     else:
-        raise exception
+        raise
325-333: Initialize criterion once in init (perf/stability). Avoid per-call instantiation; move selection to `__init__` and reuse.
-        if labels is not None:
-            if getattr(self, 'criterion', None) is None:
-                if self.config.fuse_linear_cross_entropy:
-                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
-                elif self.config.fuse_cross_entropy:
-                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
-                else:
-                    criterion = nn.CrossEntropyLoss()
-            else:
-                criterion = self.criterion
+        if labels is not None:
+            criterion = self.criterion
Additional change outside this hunk (constructor):
 class DeltaFormerForCausalLM(DeltaFormerPreTrainedModel, GenerationMixin):
@@
-        self.criterion = None
+        if config.fuse_linear_cross_entropy:
+            self.criterion = FusedLinearCrossEntropyLoss(use_l2warp=config.use_l2warp)
+        elif config.fuse_cross_entropy:
+            self.criterion = FusedCrossEntropyLoss(inplace_backward=True)
+        else:
+            self.criterion = nn.CrossEntropyLoss()
fla/models/deltaformer/configuration_deltaformer.py (1)
92-100: Attn schema validated: keep the checks. config.attn['layers'] and config.attn['num_heads'] are actually consumed by many model implementations and passed into Attention constructors (e.g. fla/models/deltaformer/modeling_deltaformer.py, fla/models/rodimus/modeling_rodimus.py, fla/models/samba/modeling_samba.py), so the validation in fla/models/deltaformer/configuration_deltaformer.py is appropriate.
        
          
                fla/layers/deltaformer.py
              
                Outdated
          
        
      | if flash_attn_func is None: | ||
| raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") | 
Require varlen FlashAttention too (avoid NameError later).
Init only checks flash_attn_func; varlen paths call flash_attn_varlen_func unconditionally.
Apply:
-        if flash_attn_func is None:
-            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+        if flash_attn_func is None or flash_attn_varlen_func is None:
+            raise ImportError(
+                "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+                "Install via: pip install flash-attn --no-build-isolation"
+            )📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.13.1)
91-91: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 90 to 91, the init only checks
flash_attn_func but the code later unconditionally calls flash_attn_varlen_func
causing a NameError when the varlen implementation is not present; update the
import/initialization check to validate that both flash_attn_func and
flash_attn_varlen_func are available (or set flash_attn_varlen_func to a
fallback) and raise a clear ImportError if either required function is missing,
so downstream varlen code does not reference an undefined name.
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| attention_mask: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Cache] = None, | ||
| output_attentions: bool = False, | ||
| use_cache: bool = False, | ||
| **kwargs, | ||
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
| attentions = None | 
KV cache not implemented; forward signature is misleading.
past_key_values/use_cache are accepted and returned unchanged. This breaks efficient generation.
Apply a minimal guard now; implement caching next:
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if use_cache:
+            raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
@@
-        return o, attentions, past_key_values
+        return o, attentions, past_key_values
Also consider removing output_attentions until supported.
Also applies to: 214-214
🧰 Tools
🪛 Ruff (0.13.1)
110-110: Unused method argument: output_attentions
(ARG002)
111-111: Unused method argument: use_cache
(ARG002)
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 105-114 (and similarly at line 214),
the forward signature accepts past_key_values/use_cache/output_attentions but
does not implement KV caching; add a minimal guard: if use_cache is True or
past_key_values is not None, raise NotImplementedError (or ValueError) with a
clear message that KV cache is not yet supported and will be implemented later;
likewise, if output_attentions is True, raise NotImplementedError or ignore it
explicitly with a clear comment. This keeps the API honest now and prevents
silent breaks during generation; implement the same guard at the other affected
location (line ~214).
| if attention_mask is not None: | ||
| seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1] | ||
| max_seqlen = q_len + max(seqlen_offset) | ||
|  | 
Bug: using Python max() on a tensor.
max(seqlen_offset) on a torch.Tensor will error; also ensure scalar for cache sizing.
Apply:
-            if attention_mask is not None:
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q_len + max(seqlen_offset)
+            if attention_mask is not None:
+                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+                max_seqlen = q_len + int(seqlen_offset.max().item())📝 Committable suggestion
🤖 Prompt for AI Agents
In fla/layers/deltaformer.py around lines 138 to 141, the code calls Python's
built-in max() on seqlen_offset (a torch.Tensor) which will raise an error and
may not produce a Python scalar for cache sizing; replace the built-in max with
a tensor-max and convert to a Python int, e.g. compute seqlen_offset =
seqlen_offset + prepare_lens_from_mask(attention_mask) -
attention_mask.shape[-1] as before, then set max_seqlen = q_len +
int(seqlen_offset.max().cpu().item()) (or use torch.max(seqlen_offset).item())
to ensure max_seqlen is a plain int for subsequent cache sizing.
| if attn is not None: | ||
| if not isinstance(attn, Dict): | ||
| raise ValueError("attn must be a dictionary") | 
Bug: isinstance check against typing.Dict will raise or misbehave.
Use Mapping/dict instead of typing.Dict in isinstance. This can otherwise raise TypeError.
-from typing import Dict, Optional
+from typing import Dict, Optional
+from collections.abc import Mapping
@@
-            if not isinstance(attn, Dict):
+            if not isinstance(attn, Mapping):
                 raise ValueError("attn must be a dictionary")
🧰 Tools
🪛 Ruff (0.13.1)
91-91: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In fla/models/deltaformer/configuration_deltaformer.py around lines 89 to 91,
the isinstance check uses typing.Dict which can raise TypeError; change the
check to use a concrete runtime check like isinstance(attn, Mapping) (from
collections.abc) or isinstance(attn, dict) and add the necessary import (from
collections.abc import Mapping) if using Mapping, so the validation works at
runtime without errors.
| @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") | ||
| def prepare_inputs_for_generation( | ||
| self, | ||
| input_ids: torch.LongTensor = None, | ||
| past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| inputs_embeds: Optional[torch.Tensor] = None, | ||
| use_cache: bool = True, | ||
| logits_to_keep: Optional[int] = None, | ||
| **kwargs | ||
| ): | ||
| # only last token for `inputs_ids` if the `past_key_values` is not empty. | ||
| if past_key_values is not None and len(past_key_values) > 0: | ||
| input_ids = input_ids[:, -1:] | ||
| # if `inputs_embeds` are passed, we only want to use them in the 1st generation step | ||
| if inputs_embeds is not None and len(past_key_values) == 0: | ||
| model_inputs = {'inputs_embeds': inputs_embeds} | ||
| else: | ||
| # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise | ||
| # recompiles graphs as the stride of the inputs is a guard. | ||
| model_inputs = {'input_ids': input_ids.contiguous()} | ||
|  | ||
| if logits_to_keep is not None: | ||
| model_inputs['logits_to_keep'] = logits_to_keep | ||
|  | ||
| model_inputs.update({ | ||
| 'past_key_values': past_key_values, | ||
| 'use_cache': use_cache, | ||
| 'attention_mask': attention_mask, | ||
| }) | ||
| return model_inputs | ||
|  | 
Fix None-handling bug in prepare_inputs_for_generation.
len(past_key_values) is called when it can be None, causing a TypeError.
-        if inputs_embeds is not None and len(past_key_values) == 0:
+        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
             model_inputs = {'inputs_embeds': inputs_embeds}
🧰 Tools
🪛 Ruff (0.13.1)
267-267: Unused method argument: kwargs
(ARG002)
🤖 Prompt for AI Agents
In fla/models/deltaformer/modeling_deltaformer.py around lines 258 to 289, the
code calls len(past_key_values) when past_key_values can be None which raises a
TypeError; change the inputs_embeds check to verify past_key_values is None or
empty (e.g., replace "len(past_key_values) == 0" with "past_key_values is None
or len(past_key_values) == 0") and also guard input_ids.contiguous() by ensuring
input_ids is not None before calling contiguous() so you don't call methods on
None.
| hidden_states = outputs[0] | ||
| # For fused linear cross-entropy we do not materialize logits for the full sequence | ||
| logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) | ||
|  | ||
| loss = None | 
Don’t drop logits in eval when fuse_linear_cross_entropy is enabled.
Currently logits is None whenever fuse_linear_cross_entropy is True, even if labels is None. Compute logits for eval/inference.
-        # For fused linear cross-entropy we do not materialize logits for the full sequence
-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        # Only skip logits materialization during training with fused linear CE
+        materialize_logits = (labels is None) or (not self.config.fuse_linear_cross_entropy)
+        logits = self.lm_head(hidden_states[:, -logits_to_keep:]) if materialize_logits else None📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| hidden_states = outputs[0] | |
| # For fused linear cross-entropy we do not materialize logits for the full sequence | |
| logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) | |
| loss = None | |
| hidden_states = outputs[0] | |
| # Only skip logits materialization during training with fused linear CE | |
| materialize_logits = (labels is None) or (not self.config.fuse_linear_cross_entropy) | |
| logits = self.lm_head(hidden_states[:, -logits_to_keep:]) if materialize_logits else None | |
| loss = None | 
🤖 Prompt for AI Agents
In fla/models/deltaformer/modeling_deltaformer.py around lines 319-323, the code
sets logits to None whenever config.fuse_linear_cross_entropy is True which
incorrectly drops logits during eval/inference; change the logic so that when
fuse_linear_cross_entropy is True you still compute logits if labels is None
(i.e., in eval/inference) but keep logits None during training when labels are
present so the fused loss can run; implement this by branching on both
self.config.fuse_linear_cross_entropy and whether labels is None and compute
logits = self.lm_head(hidden_states[:, -logits_to_keep:]) for the eval path.
| GENERATION_UNSUPPORTED = [ | ||
| "ABCConfig", "LinearAttentionConfig", "LightNetConfig", | ||
| "Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config", | ||
| "DeltaFormerConfig", | ||
| ] | 
DeltaFormer added to GENERATION_UNSUPPORTED – track as temporary.
This aligns with missing KV cache, but please add a TODO and issue to re-enable once caching lands.
I can open a follow-up ticket and patch tests to auto-enable when DeltaFormerAttention implements cache.
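A sketch of the requested annotation (the issue reference is a placeholder, not a real ticket):

```python
GENERATION_UNSUPPORTED = [
    "ABCConfig", "LinearAttentionConfig", "LightNetConfig",
    "Mamba2Config", "MambaConfig", "NSAConfig", "SambaConfig", "RWKV6Config", "RWKV7Config",
    # TODO(<follow-up issue>): temporary until DeltaFormerAttention implements KV caching;
    # drop this entry and re-enable the generation test once caching lands.
    "DeltaFormerConfig",
]
```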
🤖 Prompt for AI Agents
In tests/models/test_modeling_utils.py around lines 24 to 28, DeltaFormerConfig
was added to GENERATION_UNSUPPORTED due to missing KV cache but lacks a TODO and
tracking issue; add a concise TODO comment immediately next to the
DeltaFormerConfig entry stating this is temporary due to missing KV cache,
create a follow-up issue (or use the provided issue number) to re-enable
generation support once DeltaFormerAttention implements KV caching, and include
the issue number or URL in the TODO so the test can be easily auto-enabled
later.
Actionable comments posted: 0
🧹 Nitpick comments (3)
fla/layers/deltaformer.py (1)
157-181: Consider extracting the common deltaformer_attn call to reduce duplication. All three branches call `deltaformer_attn` with identical arguments except for `cu_seqlens`. This duplication can be simplified. Apply this diff:
-if attention_mask is not None:
-    u = deltaformer_attn(
-        # kk similarity
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-        cu_seqlens=cu_seqlens,
-    )
-elif cu_seqlens is not None:
-    u = deltaformer_attn(
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-        cu_seqlens=cu_seqlens,
-    )
-else:
-    u = deltaformer_attn(
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(k_full, 'b t h d -> b h t d'),
-        rearrange(v_full, 'b t h d -> b h t d'),
-        beta_full,
-    )
+# Use KK similarity for pre-attention
+u = deltaformer_attn(
+    rearrange(k_full, 'b t h d -> b h t d'),
+    rearrange(k_full, 'b t h d -> b h t d'),
+    rearrange(v_full, 'b t h d -> b h t d'),
+    beta_full,
+    cu_seqlens=cu_seqlens,  # None when not using varlen
+)
fla/ops/deltaformer/naive.py (1)
59-70: Replace asserts with explicit exceptions for runtime validation. Runtime asserts can be stripped with Python optimization and shouldn't be used for input validation.
Apply this diff:
-assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,H,T,D]"
+if q.dim() != 4 or k.dim() != 4 or v.dim() != 4:
+    raise ValueError("q, k, v must have 4 dimensions [B, H, T, D]")
 B, H, T, D = q.shape
-assert k.shape == (B, H, T, D) and v.shape == (B, H, T, D)
+if k.shape != (B, H, T, D) or v.shape != (B, H, T, D):
+    raise ValueError(f"Shape mismatch: q={q.shape}, k={k.shape}, v={v.shape}. All must be [{B}, {H}, {T}, {D}]")
 orig_dtype = q.dtype
 qf = q.float()
 kf = k.float()
 vf = v.float()
 if beta is None:
     betaf = torch.ones((B, H, T), dtype=torch.float32, device=q.device)
 else:
-    assert beta.shape == (B, H, T)
+    if beta.shape != (B, H, T):
+        raise ValueError(f"beta shape {beta.shape} does not match expected [{B}, {H}, {T}]")
     betaf = beta.float()
fla/ops/deltaformer/parallel.py (1)
162-169: Consider expanding the autotune search space for better performance. The autotune configuration uses a limited search space. Per PR comments, autotune was added to address performance issues with default block sizes being too large.
Consider adding more block size options to the search space:
 def _config_delta_flash_attn():
     return [
         triton.Config({'BLOCK_C': BC, 'BLOCK_T': BT}, num_stages=ns, num_warps=nw)
-        for BC in [128, 64]
-        for BT in [64, 32]
+        for BC in [256, 128, 64, 32]
+        for BT in [128, 64, 32, 16]
         for ns in [3, 2]
         for nw in [8, 4]
     ]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- fla/layers/deltaformer.py(1 hunks)
- fla/ops/deltaformer/__init__.py(1 hunks)
- fla/ops/deltaformer/naive.py(1 hunks)
- fla/ops/deltaformer/parallel.py(1 hunks)
- tests/ops/test_deltaformer.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/ops/test_deltaformer.py
🧰 Additional context used
🧬 Code graph analysis (3)
fla/layers/deltaformer.py (4)
fla/layers/utils.py (3)
get_unpad_data(75-98)
pad_input(176-197)
unpad_input(101-173)fla/modules/layernorm.py (1)
RMSNorm(1060-1107)fla/ops/deltaformer/parallel.py (3)
deltaformer_attn(1073-1099)
forward(16-31)
forward(802-832)fla/ops/utils/index.py (1)
prepare_lens_from_mask(43-44)
fla/ops/deltaformer/parallel.py (2)
fla/layers/deltaformer.py (1)
forward(105-214)fla/ops/deltaformer/invcum.py (2)
backward_x(20-26)
forward_inplace(16-17)
fla/ops/deltaformer/__init__.py (2)
fla/ops/deltaformer/naive.py (1)
naive_deltaformer_attn(39-87)fla/ops/deltaformer/parallel.py (1)
deltaformer_attn(1073-1099)
🪛 Ruff (0.13.1)
fla/layers/deltaformer.py
91-91: Avoid specifying long messages outside the exception class
(TRY003)
110-110: Unused method argument: output_attentions
(ARG002)
111-111: Unused method argument: use_cache
(ARG002)
fla/ops/deltaformer/parallel.py
25-25: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
198-198: Unused function argument: B
(ARG001)
364-364: Unused function argument: B
(ARG001)
477-477: Unused function argument: B
(ARG001)
604-604: Unused function argument: B
(ARG001)
842-842: Unpacked variable T_max is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (12)
fla/layers/deltaformer.py (7)
115-120: Replace assert with explicit ValueError. Using `assert` for runtime validation can be stripped by Python optimization (`-O` flag). Apply this diff:
 if attention_mask is not None:
-    assert len(attention_mask.shape) == 2, (
-        "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-        "for padding purposes (0 indicating padding). "
-        "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-    )
+    if attention_mask.dim() != 2:
+        raise ValueError(
+            "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; 2D pairwise masks are unsupported."
+        )
21-25: Set both to None to prevent NameError. If the import fails, only `flash_attn_func` is set to `None`, leaving `flash_attn_varlen_func` undefined. This will cause a `NameError` when checking it later. Apply this diff:
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     flash_attn_func = None
+    flash_attn_varlen_func = None
78-84: Add validation for head dimension divisibility. The code assumes `hidden_size` is divisible by `num_heads` and `num_heads` by `num_kv_heads`, but doesn't validate these assumptions. This can lead to silent errors. Apply this diff:
 self.hidden_size = hidden_size
 self.num_heads = num_heads
 self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+if self.hidden_size % self.num_heads != 0:
+    raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})")
+if self.num_heads % self.num_kv_heads != 0:
+    raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})")
 self.num_kv_groups = num_heads // self.num_kv_heads
 self.head_dim = self.hidden_size // self.num_heads
 self.kv_dim = self.num_kv_heads * self.head_dim
90-91: Check both flash_attn imports to avoid NameError later. The code only checks `flash_attn_func` but calls `flash_attn_varlen_func` unconditionally in the forward pass, which will cause a NameError if the varlen import failed. Apply this diff:
-if flash_attn_func is None:
-    raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+if flash_attn_func is None or flash_attn_varlen_func is None:
+    raise ImportError(
+        "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+        "Install via: pip install flash-attn --no-build-isolation"
+    )
105-114: KV cache not implemented - fail explicitly. The method accepts `past_key_values` and `use_cache` but doesn't implement caching, which will silently break generation. Apply this diff:
 def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.LongTensor] = None,
     past_key_values: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_cache:
+        raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
     attentions = None
138-141: Fix: Using Python max() on a tensor will fail. Python's built-in `max()` on a PyTorch tensor will raise an error. Use the tensor's `.max()` method and convert to a Python int. Apply this diff:
 if attention_mask is not None:
     seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-    max_seqlen = q_len + max(seqlen_offset)
+    max_seqlen = q_len + int(seqlen_offset.max().item())
145-156: ROPE cu_seqlens mismatch for the attention_mask path. ROPE receives `cu_seqlens_kw` but should use the mask-derived `cu_seqlens` when `attention_mask` is present for correct position indexing. Apply this diff to compute cu_seqlens first:
-q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)
 if attention_mask is not None:
     _, cu_seqlens, _ = get_unpad_data(attention_mask)
+    q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
     # q_full = q
     k_full = k
     v_full = v
     beta_full = beta
 else:
     cu_seqlens = cu_seqlens_kw
+    q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)
     # q_full, k_full, v_full, beta_full = q, k, v, beta
     k_full, v_full, beta_full = k, v, beta
fla/ops/deltaformer/__init__.py (1)
1-9: LGTM! Clean module exports. The module properly exports both the optimized `deltaformer_attn` and the reference `naive_deltaformer_attn` implementation with a clear `__all__` definition.
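As a quick illustration of how the two exports pair up for a parity check; shapes follow the [B, H, T, D] convention used by the naive reference, the positional argument order mirrors the test snippets in this review, and the fused path assumes a CUDA device with flash-attn installed:

```python
import torch
from fla.ops.deltaformer import deltaformer_attn, naive_deltaformer_attn

B, H, T, D = 1, 2, 128, 64
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
beta = torch.rand(B, H, T, device='cuda', dtype=torch.bfloat16)

u_fused = deltaformer_attn(q, k, v, beta)        # Triton-accelerated path
u_naive = naive_deltaformer_attn(q, k, v, beta)  # float32 reference
print((u_fused.float() - u_naive.float()).abs().max())
```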
10-36: LGTM! Well-implemented causal softmax utility. The `tril_softmax` function correctly implements row-wise causal softmax with proper numerical stability through max subtraction and handles the strict vs non-strict causal masking appropriately.
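For readers skimming the review, a minimal re-implementation of the pattern being praised; this is illustrative and not the PR's `tril_softmax`:

```python
import torch

def causal_softmax(scores: torch.Tensor, strict: bool = True) -> torch.Tensor:
    """Row-wise softmax over [..., T, T] scores, masking future positions.
    strict=True also masks the diagonal (a position cannot attend to itself)."""
    T = scores.size(-1)
    idx = torch.arange(T, device=scores.device)
    keep = (idx[None, :] < idx[:, None]) if strict else (idx[None, :] <= idx[:, None])
    masked = scores.masked_fill(~keep, float('-inf'))
    # Subtract the (clamped) row max for numerical stability before exponentiating.
    masked = masked - masked.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(scores.dtype).min)
    probs = masked.softmax(dim=-1)
    # Rows with no allowed positions (row 0 when strict) come out as NaN; zero them.
    return torch.nan_to_num(probs, nan=0.0)
```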
76-87: LGTM! Correct iterative DeltaFormer computation. The iterative computation correctly implements the DeltaFormer formula `u[i] = v[i] - beta[i] * sum_{j<i} softmax(q[i] @ k[:i]^T) @ u[:i]` with proper handling of the base case and accumulation of previous u values.
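A compact sketch of that recurrence, written independently of the PR's reference implementation (single head, float32; the 1/sqrt(D) scaling follows the usual attention convention and may differ from the kernel's exact scaling):

```python
import torch

def deltaformer_u_reference(q, k, v, beta):
    """q, k, v: [T, D]; beta: [T]. Returns u with u[0] = v[0] and
    u[i] = v[i] - beta[i] * softmax(q[i] @ k[:i].T) @ u[:i]."""
    T, D = q.shape
    u = torch.zeros_like(v)
    u[0] = v[0]
    for i in range(1, T):
        scores = q[i] @ k[:i].T / D ** 0.5       # token i attends over previous tokens only
        probs = torch.softmax(scores, dim=-1)    # strictly causal by construction
        u[i] = v[i] - beta[i] * (probs @ u[:i])  # subtract the beta-scaled mixture of earlier u
    return u
```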
1073-1099: LGTM! Clean public API with varlen support. The `deltaformer_attn` function provides a clean interface supporting both fixed-length and variable-length sequences with proper gradient handling through the custom autograd function.
839-905: Varlen backward assumes BS==1 - add an explicit guard. The backward pass indexes tensors with `[0, ...]`, which assumes batch size 1. This will silently produce incorrect gradients for BS>1. Apply this diff:
 if getattr(ctx, 'cu_seqlens', None) is not None:
     cu = ctx.cu_seqlens
     qo, ko, vo, u_full, ws, lses, betao = ctx.saved_tensors
     BS, NH, T_max, D = ko.size()
+    if BS != 1:
+        raise NotImplementedError(
+            f"Varlen deltaformer backward currently requires packed inputs with BS==1. "
+            f"Got BS={BS}. Please pack sequences into a single batch dimension."
+        )
…n func integration
Actionable comments posted: 1
🧹 Nitpick comments (1)
fla/ops/deltaformer/parallel.py (1)
17-20: Add stacklevel to warning for better diagnostics. The static analyzer suggests adding `stacklevel=2` to point the warning to the calling code rather than this module.
 warnings.warn(
     "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-    category=ImportWarning
+    category=ImportWarning,
+    stacklevel=2
 )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/layers/deltaformer.py(1 hunks)
- fla/ops/deltaformer/parallel.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/layers/deltaformer.py (5)
fla/modules/layernorm.py (1)
RMSNorm(1060-1107)fla/modules/rotary.py (1)
RotaryEmbedding(306-500)fla/ops/deltaformer/parallel.py (2)
deltaformer_attn(941-986)
forward(664-694)fla/ops/utils/index.py (1)
prepare_lens_from_mask(43-44)fla/models/deltaformer/modeling_deltaformer.py (3)
forward(63-91)
forward(142-207)
forward(291-351)
fla/ops/deltaformer/parallel.py (3)
fla/layers/utils.py (2)
pad_input(176-197)
unpad_input(101-173)fla/layers/deltaformer.py (1)
forward(96-153)fla/ops/deltaformer/invcum.py (2)
backward_x(20-26)
forward_inplace(16-17)
🪛 Ruff (0.13.1)
fla/layers/deltaformer.py
102-102: Unused method argument: use_cache
(ARG002)
fla/ops/deltaformer/parallel.py
17-17: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
37-37: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
673-673: Unpacked variable D is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
704-704: Unpacked variable T_max is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
951-951: Avoid specifying long messages outside the exception class
(TRY003)
953-953: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
953-953: Unpacked variable D is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (6)
fla/layers/deltaformer.py (4)
1-154: LGTM! Clean and comprehensive DeltaFormer attention layer implementation. This implementation provides a solid foundation for the DeltaFormer attention mechanism with proper module structure, comprehensive docstrings, and appropriate parameter handling. The forward method correctly integrates with the deltaformer ops and handles both fixed-length and variable-length sequences appropriately.
101-103: Add an explicit guard for the unused use_cache parameter. The static analyzer correctly identifies the unused `use_cache` parameter. Based on past review feedback, this should be guarded explicitly to prevent misleading the API consumer. Apply this guard to make the API behavior explicit:
 def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.LongTensor] = None,
     past_key_values: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_cache:
+        raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
+    if output_attentions:
+        raise NotImplementedError("DeltaFormerAttention does not support outputting attention weights.")
     attentions = None
106-111: Replace assertion with explicit exception. As noted in past reviews, runtime assertions can be stripped with `-O` and should be replaced with explicit exceptions.
-    if attention_mask is not None:
-        assert len(attention_mask.shape) == 2, (
-            "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-            "for padding purposes (0 indicating padding). "
-            "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-        )
+    if attention_mask is not None:
+        if attention_mask.dim() != 2:
+            raise ValueError(
+                "attention_mask must be [batch, seq_len] with 1=valid and 0=pad; 2D pairwise masks are unsupported."
+            )
129-131: Fix tensor max() usage for the seqlen_offset calculation. The code uses Python's built-in `max()` on a tensor, which will raise an error. This needs to be fixed for proper cache sizing.
 if attention_mask is not None:
     seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-    max_seqlen = q_len + max(seqlen_offset)
+    max_seqlen = q_len + int(seqlen_offset.max().item())
950-951: Check for both flash attention functions in the error guard. The check only validates `flash_attn_func`, but the code later unconditionally calls `flash_attn_varlen_func`, which could cause a NameError if the varlen function is missing.
-    if flash_attn_func is None:
-        raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
+    if flash_attn_func is None or flash_attn_varlen_func is None:
+        raise ImportError(
+            "DeltaFormer requires FlashAttention (both flash_attn_func and flash_attn_varlen_func). "
+            "Install via: pip install flash-attn --no-build-isolation"
+        )
1-991: LGTM! Comprehensive Triton-accelerated DeltaFormer implementation. This is a well-structured implementation of the DeltaFormer parallel attention with proper autotuning, autograd support, and both fixed-length and variable-length sequence handling. The Triton kernels are appropriately organized and the public API provides clean abstraction over the complex internals.
| try: | ||
| from flash_attn import flash_attn_func, flash_attn_varlen_func | ||
| except ImportError: | ||
| warnings.warn( | ||
| "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", | ||
| category=ImportWarning | ||
| ) | ||
| flash_attn_func = None | ||
|  | 
Missing flash_attn_varlen_func in except block.
The search results confirm that flash_attn import errors are common and both functions need to be handled. The except block only sets flash_attn_func = None but leaves flash_attn_varlen_func undefined, which will cause NameError when the varlen path is used.
Apply this fix to define both sentinel values:
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 except ImportError:
     warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
         category=ImportWarning
     )
     flash_attn_func = None
+    flash_attn_varlen_func = None📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| try: | |
| from flash_attn import flash_attn_func, flash_attn_varlen_func | |
| except ImportError: | |
| warnings.warn( | |
| "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", | |
| category=ImportWarning | |
| ) | |
| flash_attn_func = None | |
| try: | |
| from flash_attn import flash_attn_func, flash_attn_varlen_func | |
| except ImportError: | |
| warnings.warn( | |
| "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", | |
| category=ImportWarning | |
| ) | |
| flash_attn_func = None | |
| flash_attn_varlen_func = None | 
🧰 Tools
🪛 Ruff (0.13.1)
17-17: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around lines 14 to 22, the except ImportError
only sets flash_attn_func = None but leaves flash_attn_varlen_func undefined;
update the except block to also assign flash_attn_varlen_func = None so both
sentinels are defined when the import fails, and keep the existing warning
message intact.
fla/ops/deltaformer/parallel.py, invcum.py, and naive.py. The pre-attention calculates the u vector, which is a substitute for v that can be directly integrated into the flash attention func. `beta` defaults to all ones when not provided.
Summary by CodeRabbit
New Features
Documentation
Tests
Chores