- 
                Notifications
    You must be signed in to change notification settings 
- Fork 280
[DeltaFormer] Add Model #585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6a82e9
              90f6bcb
              059105d
              566b28b
              6f5bf10
              8f99922
              c5fceb2
              c17dfec
              cab4361
              1744163
              0ea272d
              12ae828
              9805fa0
              b9c3be0
              dfb54ec
              0a23799
              e6f09c3
              edb3093
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,153 @@ | ||||||||||||||||||||||||||||||||||
| # -*- coding: utf-8 -*- | ||||||||||||||||||||||||||||||||||
| # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Optional, Tuple | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||
| from einops import rearrange | ||||||||||||||||||||||||||||||||||
| from transformers.utils import logging | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| from fla.modules import RMSNorm, RotaryEmbedding | ||||||||||||||||||||||||||||||||||
| from fla.ops.deltaformer import deltaformer_attn | ||||||||||||||||||||||||||||||||||
| from fla.ops.utils.index import prepare_lens_from_mask | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||
| from fla.models.utils import Cache | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| logger = logging.get_logger(__name__) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| class DeltaFormerAttention(nn.Module): | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||||||||||||
| The layer implementation for DeltaFormer, | ||||||||||||||||||||||||||||||||||
| [Understanding Transformer from the Perspective of Associative Memory] | ||||||||||||||||||||||||||||||||||
| (https://arxiv.org/pdf/2505.19488). | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| Notes | ||||||||||||||||||||||||||||||||||
| - DeltaFormer attention is implemented with Triton kernels in `fla.ops.deltaformer` and is tuned | ||||||||||||||||||||||||||||||||||
| for typical head dimensions (e.g., 64/128). It currently supports fixed-length inputs. | ||||||||||||||||||||||||||||||||||
| - For variable-length inputs (padding masks), the deltaformer computation falls back to using the | ||||||||||||||||||||||||||||||||||
| fixed-length path, while the second stage (softmax attention over U) uses FlashAttention's | ||||||||||||||||||||||||||||||||||
| varlen path when an attention mask is provided. | ||||||||||||||||||||||||||||||||||
| - K/V grouping (GQA) is supported natively by FlashAttention via `num_kv_heads`. | ||||||||||||||||||||||||||||||||||
| - Uses K-K similarity in deltaformer computation instead of Q-K similarity for better performance. | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||
| hidden_size (int, Optional): | ||||||||||||||||||||||||||||||||||
| The hidden size of the input. Default: 2048. | ||||||||||||||||||||||||||||||||||
| num_heads (int, Optional): | ||||||||||||||||||||||||||||||||||
| The number of attention heads. Default: 32. | ||||||||||||||||||||||||||||||||||
| num_kv_heads (int, Optional): | ||||||||||||||||||||||||||||||||||
| The number of key/value heads for grouped-query attention. If None, equals `num_heads`. | ||||||||||||||||||||||||||||||||||
| Default: None. | ||||||||||||||||||||||||||||||||||
| qkv_bias (bool, Optional): | ||||||||||||||||||||||||||||||||||
| Whether to use bias for Q/K/V projections. Default: False. | ||||||||||||||||||||||||||||||||||
| qk_norm (bool, Optional): | ||||||||||||||||||||||||||||||||||
| Whether to apply per-head RMSNorm to Q and K before attention. Default: False. | ||||||||||||||||||||||||||||||||||
| rope_theta (float, Optional): | ||||||||||||||||||||||||||||||||||
| The base frequency for rotary position embedding. Default: 10000. | ||||||||||||||||||||||||||||||||||
| max_position_embeddings (int, Optional): | ||||||||||||||||||||||||||||||||||
| The maximum position embeddings. Default: None. | ||||||||||||||||||||||||||||||||||
| layer_idx (int, Optional): | ||||||||||||||||||||||||||||||||||
| The index of the layer (used for cache compatibility). Default: None. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| hidden_size: int = 2048, | ||||||||||||||||||||||||||||||||||
| num_heads: int = 32, | ||||||||||||||||||||||||||||||||||
| num_kv_heads: Optional[int] = None, | ||||||||||||||||||||||||||||||||||
| qkv_bias: bool = False, | ||||||||||||||||||||||||||||||||||
| qk_norm: bool = False, | ||||||||||||||||||||||||||||||||||
| rope_theta: float = 10000., | ||||||||||||||||||||||||||||||||||
| max_position_embeddings: Optional[int] = None, | ||||||||||||||||||||||||||||||||||
| layer_idx: int | None = None, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| self.hidden_size = hidden_size | ||||||||||||||||||||||||||||||||||
| self.num_heads = num_heads | ||||||||||||||||||||||||||||||||||
| self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads | ||||||||||||||||||||||||||||||||||
| self.num_kv_groups = num_heads // self.num_kv_heads | ||||||||||||||||||||||||||||||||||
| self.head_dim = self.hidden_size // self.num_heads | ||||||||||||||||||||||||||||||||||
| self.kv_dim = self.num_kv_heads * self.head_dim | ||||||||||||||||||||||||||||||||||
| self.qkv_bias = qkv_bias | ||||||||||||||||||||||||||||||||||
| self.qk_norm = qk_norm | ||||||||||||||||||||||||||||||||||
| self.rope_theta = rope_theta | ||||||||||||||||||||||||||||||||||
| self.max_position_embeddings = max_position_embeddings | ||||||||||||||||||||||||||||||||||
| self.layer_idx = layer_idx | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||
| self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||
| self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||
| self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True) | ||||||||||||||||||||||||||||||||||
| self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| if qk_norm: | ||||||||||||||||||||||||||||||||||
| self.q_norm = RMSNorm(self.head_dim) | ||||||||||||||||||||||||||||||||||
| self.k_norm = RMSNorm(self.head_dim) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| hidden_states: torch.Tensor, | ||||||||||||||||||||||||||||||||||
| attention_mask: Optional[torch.LongTensor] = None, | ||||||||||||||||||||||||||||||||||
| past_key_values: Optional[Cache] = None, | ||||||||||||||||||||||||||||||||||
| output_attentions: bool = False, | ||||||||||||||||||||||||||||||||||
| use_cache: bool = False, | ||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +101
     to 
      +103
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Remove unused parameters or implement functionality. The static analyzer correctly identifies unused parameters. Either implement the missing functionality or remove these parameters to avoid misleading the API consumer. For  -def forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    attentions = None
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        raise NotImplementedError("DeltaFormer does not support outputting attention weights")
+    attentions = None📝 Committable suggestion
 
        Suggested change
       
 🧰 Tools🪛 Ruff (0.12.2)98-98: Unused method argument:  (ARG002) 100-100: Unused method argument:  (ARG002) 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +96
     to 
      +104
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The  | ||||||||||||||||||||||||||||||||||
| attentions = None | ||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +96
     to 
      +105
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. KV cache not implemented; forward signature is misleading. 
 Apply a minimal guard now; implement caching next:      def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if use_cache:
+            raise NotImplementedError("DeltaFormerAttention KV cache is not implemented yet.")
@@
-        return o, attentions, past_key_values
+        return o, attentions, past_key_valuesAlso consider removing  Also applies to: 214-214 🧰 Tools🪛 Ruff (0.13.1)110-110: Unused method argument:  (ARG002) 111-111: Unused method argument:  (ARG002) 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||||||||||||||||||||
| assert len(attention_mask.shape) == 2, ( | ||||||||||||||||||||||||||||||||||
| "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " | ||||||||||||||||||||||||||||||||||
| "for padding purposes (0 indicating padding). " | ||||||||||||||||||||||||||||||||||
| "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
|         
                  Nathancgy marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||||||||||||||||||||||||||||||||
| batch_size, q_len, _ = hidden_states.size() | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||
| k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||
| v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||
| beta = self.b_proj(hidden_states) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| if self.qk_norm: | ||||||||||||||||||||||||||||||||||
| q, k = self.q_norm(q), self.k_norm(k) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| cu_seqlens_kw = kwargs.get('cu_seqlens', None) | ||||||||||||||||||||||||||||||||||
| seqlen_offset, max_seqlen = 0, q_len | ||||||||||||||||||||||||||||||||||
| if past_key_values is not None: | ||||||||||||||||||||||||||||||||||
| seqlen_offset = past_key_values.get_seq_length(self.layer_idx) | ||||||||||||||||||||||||||||||||||
| max_seqlen = q_len + seqlen_offset | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||||||||||||||||||||
| seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1] | ||||||||||||||||||||||||||||||||||
| max_seqlen = q_len + max(seqlen_offset) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +129
     to 
      +132
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: using Python  
 Apply: -            if attention_mask is not None:
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q_len + max(seqlen_offset)
+            if attention_mask is not None:
+                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+                max_seqlen = q_len + int(seqlen_offset.max().item())📝 Committable suggestion
 
        Suggested change
       
 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||
| if self.max_position_embeddings is not None: | ||||||||||||||||||||||||||||||||||
| max_seqlen = max(max_seqlen, self.max_position_embeddings) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| o = deltaformer_attn( | ||||||||||||||||||||||||||||||||||
| q=q, | ||||||||||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||||||||||
| v=v, | ||||||||||||||||||||||||||||||||||
| beta=beta, | ||||||||||||||||||||||||||||||||||
| attention_mask=attention_mask, | ||||||||||||||||||||||||||||||||||
| cu_seqlens=cu_seqlens_kw | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| o = o.reshape(batch_size, q_len, -1) | ||||||||||||||||||||||||||||||||||
| o = self.o_proj(o) | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| if not output_attentions: | ||||||||||||||||||||||||||||||||||
| attentions = None | ||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||
| return o, attentions, past_key_values | ||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # -*- coding: utf-8 -*- | ||
|  | ||
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM | ||
|  | ||
| from fla.models.deltaformer.configuration_deltaformer import DeltaFormerConfig | ||
| from fla.models.deltaformer.modeling_deltaformer import DeltaFormerForCausalLM, DeltaFormerModel | ||
|  | ||
| AutoConfig.register(DeltaFormerConfig.model_type, DeltaFormerConfig, exist_ok=True) | ||
| AutoModel.register(DeltaFormerConfig, DeltaFormerModel, exist_ok=True) | ||
| AutoModelForCausalLM.register(DeltaFormerConfig, DeltaFormerForCausalLM, exist_ok=True) | ||
|  | ||
| __all__ = ['DeltaFormerConfig', 'DeltaFormerForCausalLM', 'DeltaFormerModel'] | 
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,107 @@ | ||||||||||||||||||||
| # -*- coding: utf-8 -*- | ||||||||||||||||||||
|  | ||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||
|  | ||||||||||||||||||||
| import warnings | ||||||||||||||||||||
| from typing import Dict, Optional | ||||||||||||||||||||
|  | ||||||||||||||||||||
| from transformers.configuration_utils import PretrainedConfig | ||||||||||||||||||||
|  | ||||||||||||||||||||
|  | ||||||||||||||||||||
| class DeltaFormerConfig(PretrainedConfig): | ||||||||||||||||||||
| model_type = 'deltaformer' | ||||||||||||||||||||
| keys_to_ignore_at_inference = ['past_key_values'] | ||||||||||||||||||||
|  | ||||||||||||||||||||
| def __init__( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| hidden_size: int = 2048, | ||||||||||||||||||||
| hidden_ratio: Optional[int] = 4, | ||||||||||||||||||||
| intermediate_size: Optional[int] = None, | ||||||||||||||||||||
| num_hidden_layers: int = 24, | ||||||||||||||||||||
| num_heads: int = 8, | ||||||||||||||||||||
| num_kv_heads: Optional[int] = None, | ||||||||||||||||||||
| attn_mode: str = "chunk", | ||||||||||||||||||||
| hidden_act: str = "swish", | ||||||||||||||||||||
| max_position_embeddings: int = 2048, | ||||||||||||||||||||
| elementwise_affine: Optional[bool] = True, | ||||||||||||||||||||
| norm_eps: float = 1e-6, | ||||||||||||||||||||
| qkv_bias: bool = False, | ||||||||||||||||||||
| qk_norm: bool = False, | ||||||||||||||||||||
| rope_theta: float = 10000., | ||||||||||||||||||||
| rope_max_position_embeddings: Optional[int] = None, | ||||||||||||||||||||
| attn: Optional[Dict] = None, | ||||||||||||||||||||
| use_cache: bool = True, | ||||||||||||||||||||
| pad_token_id: Optional[int] = None, | ||||||||||||||||||||
| bos_token_id: int = 1, | ||||||||||||||||||||
| eos_token_id: int = 2, | ||||||||||||||||||||
| tie_word_embeddings: bool = False, | ||||||||||||||||||||
| initializer_range: float = 0.02, | ||||||||||||||||||||
| fuse_norm: bool = True, | ||||||||||||||||||||
| fuse_swiglu: bool = True, | ||||||||||||||||||||
| fuse_cross_entropy: bool = True, | ||||||||||||||||||||
| fuse_linear_cross_entropy: bool = False, | ||||||||||||||||||||
| use_l2warp: bool = False, | ||||||||||||||||||||
| vocab_size: int = 32000, | ||||||||||||||||||||
| output_attentions: bool = False, | ||||||||||||||||||||
| output_hidden_states: bool = False, | ||||||||||||||||||||
| **kwargs | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| self.hidden_size = hidden_size | ||||||||||||||||||||
| self.hidden_ratio = hidden_ratio | ||||||||||||||||||||
| self.intermediate_size = intermediate_size | ||||||||||||||||||||
| self.num_hidden_layers = num_hidden_layers | ||||||||||||||||||||
| self.num_heads = num_heads | ||||||||||||||||||||
| self.num_kv_heads = num_kv_heads | ||||||||||||||||||||
| self.attn_mode = attn_mode | ||||||||||||||||||||
| self.hidden_act = hidden_act | ||||||||||||||||||||
| self.max_position_embeddings = max_position_embeddings | ||||||||||||||||||||
| self.elementwise_affine = elementwise_affine | ||||||||||||||||||||
| self.norm_eps = norm_eps | ||||||||||||||||||||
| self.qkv_bias = qkv_bias | ||||||||||||||||||||
| self.qk_norm = qk_norm | ||||||||||||||||||||
| self.rope_theta = rope_theta | ||||||||||||||||||||
| self.rope_max_position_embeddings = rope_max_position_embeddings | ||||||||||||||||||||
| self.attn = attn | ||||||||||||||||||||
| self.use_cache = use_cache | ||||||||||||||||||||
| self.initializer_range = initializer_range | ||||||||||||||||||||
|  | ||||||||||||||||||||
| self.fuse_norm = fuse_norm | ||||||||||||||||||||
| self.fuse_swiglu = fuse_swiglu | ||||||||||||||||||||
| self.fuse_cross_entropy = fuse_cross_entropy | ||||||||||||||||||||
| self.fuse_linear_cross_entropy = fuse_linear_cross_entropy | ||||||||||||||||||||
| self.use_l2warp = use_l2warp | ||||||||||||||||||||
| self.vocab_size = vocab_size | ||||||||||||||||||||
|  | ||||||||||||||||||||
| self.output_attentions = output_attentions | ||||||||||||||||||||
| self.output_hidden_states = output_hidden_states | ||||||||||||||||||||
|  | ||||||||||||||||||||
| if fuse_cross_entropy and fuse_linear_cross_entropy: | ||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if fuse_linear_cross_entropy: | ||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||
| "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency " | ||||||||||||||||||||
| "at the potential cost of reduced precision. " | ||||||||||||||||||||
| "If you observe issues like loss divergence, consider disabling this setting." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|  | ||||||||||||||||||||
| if attn is not None: | ||||||||||||||||||||
| if not isinstance(attn, Dict): | ||||||||||||||||||||
| raise ValueError("attn must be a dictionary") | ||||||||||||||||||||
| 
      Comment on lines
    
      +89
     to 
      +91
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: isinstance check against typing.Dict will raise or misbehave. Use  -from typing import Dict, Optional
+from typing import Dict, Optional
+from collections.abc import Mapping
@@
-            if not isinstance(attn, Dict):
+            if not isinstance(attn, Mapping):
                 raise ValueError("attn must be a dictionary")📝 Committable suggestion
 
        Suggested change
       
 🧰 Tools🪛 Ruff (0.13.1)91-91: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents | ||||||||||||||||||||
| if 'layers' not in attn: | ||||||||||||||||||||
| raise ValueError("Layer indices must be provided to initialize hybrid attention layers") | ||||||||||||||||||||
| if 'num_heads' not in attn: | ||||||||||||||||||||
| raise ValueError("Number of heads must be provided to initialize hybrid attention layers") | ||||||||||||||||||||
| attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) | ||||||||||||||||||||
| attn['qkv_bias'] = attn.get('qkv_bias', False) | ||||||||||||||||||||
| attn['window_size'] = attn.get('window_size', None) | ||||||||||||||||||||
| attn['rope_theta'] = attn.get('rope_theta', 10000.) | ||||||||||||||||||||
|  | ||||||||||||||||||||
| super().__init__( | ||||||||||||||||||||
| pad_token_id=pad_token_id, | ||||||||||||||||||||
| bos_token_id=bos_token_id, | ||||||||||||||||||||
| eos_token_id=eos_token_id, | ||||||||||||||||||||
| tie_word_embeddings=tie_word_embeddings, | ||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate head divisibility and GQA grouping early.
Prevent silent shape mismatches when
hidden_size % num_heads != 0ornum_heads % num_kv_heads != 0.📝 Committable suggestion
🤖 Prompt for AI Agents