# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm, RotaryEmbedding
from fla.ops.deltaformer import deltaformer_attn
from fla.ops.utils.index import prepare_lens_from_mask

if TYPE_CHECKING:
    from fla.models.utils import Cache

logger = logging.get_logger(__name__)


class DeltaFormerAttention(nn.Module):

    r"""
    The layer implementation for DeltaFormer,
    [Understanding Transformer from the Perspective of Associative Memory]
    (https://arxiv.org/pdf/2505.19488).

    Notes:
        - DeltaFormer attention is implemented with Triton kernels in `fla.ops.deltaformer` and is tuned
          for typical head dimensions (e.g., 64/128). The kernels currently support fixed-length inputs.
        - For variable-length inputs (padding masks), the deltaformer computation falls back to the
          fixed-length path, while the second stage (softmax attention over U) uses FlashAttention's
          varlen path when an attention mask is provided.
        - K/V grouping (GQA) is supported natively by FlashAttention via `num_kv_heads`.
        - Uses K-K similarity in the deltaformer computation instead of Q-K similarity for better performance.

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        num_heads (int, Optional):
            The number of attention heads. Default: 32.
        num_kv_heads (int, Optional):
            The number of key/value heads for grouped-query attention. If None, equals `num_heads`.
            Default: None.
        qkv_bias (bool, Optional):
            Whether to use bias for Q/K/V projections. Default: False.
        qk_norm (bool, Optional):
            Whether to apply per-head RMSNorm to Q and K before attention. Default: False.
        rope_theta (float, Optional):
            The base frequency for rotary position embedding. Default: 10000.
        max_position_embeddings (int, Optional):
            The maximum number of position embeddings. Default: None.
        layer_idx (int, Optional):
            The index of the layer (used for cache compatibility). Default: None.
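
    Example:
        A minimal usage sketch. The import path ``fla.layers.deltaformer`` below is an assumption,
        and a CUDA device with half precision is assumed since the underlying kernels are Triton-based:

        >>> import torch
        >>> from fla.layers.deltaformer import DeltaFormerAttention
        >>> layer = DeltaFormerAttention(hidden_size=2048, num_heads=32).to('cuda', torch.bfloat16)
        >>> x = torch.randn(2, 512, 2048, device='cuda', dtype=torch.bfloat16)
        >>> o, attentions, past_key_values = layer(x)
        >>> o.shape
        torch.Size([2, 512, 2048])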
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        rope_theta: float = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: int | None = None,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
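        # b_proj emits one beta coefficient per head and position, consumed by the DeltaFormer update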
        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        attentions = None
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        beta = self.b_proj(hidden_states)
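        # shapes: q -> [batch_size, q_len, num_heads, head_dim]; k, v -> [batch_size, q_len, num_kv_heads, head_dim];
        # beta -> [batch_size, q_len, num_heads], one coefficient per head and position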

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        cu_seqlens_kw = kwargs.get('cu_seqlens', None)
        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q_len + seqlen_offset

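            # for padded batches, shift each sequence's rotary offset by its true (unpadded) length
            # so that positions advance only over real tokens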
            if attention_mask is not None:
                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
                max_seqlen = q_len + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)

        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens_kw)

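        # fused DeltaFormer attention: the first stage builds U from k/v/beta via K-K similarity
        # (see the Notes above); the second stage runs causal softmax attention of q over U,
        # taking the varlen path when attention_mask or cu_seqlens is provided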
        o = deltaformer_attn(
            q=q,
            k=k,
            v=v,
            beta=beta,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens_kw
        )

        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values