Error when creating ModernBert model with flash attention: TypeError: RotaryEmbedding.__init__() got an unexpected keyword argument 'pos_idx_in_fp32' #38843

@KabaevAnton

Description

System Info

Linux Ubuntu 22.04
Python 3.12.4
transformers 4.52.4
flash-attn 2.8.0.post2

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. pip install flash-attn
  2. config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base")
  3. model = AutoModelForMaskedLM.from_config(config)
  4. TypeError: RotaryEmbedding.__init__() got an unexpected keyword argument 'pos_idx_in_fp32'
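
For convenience, a minimal self-contained script that triggers the error (assuming flash-attn 2.8.0.post2 and transformers 4.52.4 from the system info above are installed, so the flash attention path is selected automatically):

```python
# Minimal reproduction. Assumes `pip install flash-attn` has already been run
# with the versions listed in the system info.
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base")

# Raises:
# TypeError: RotaryEmbedding.__init__() got an unexpected keyword argument 'pos_idx_in_fp32'
model = AutoModelForMaskedLM.from_config(config)
```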

The issue is as follows: when you try to create a ModernBert model with flash attention installed, it uses
```python
class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
```
ModernBertUnpaddedRotaryEmbedding passes the pos_idx_in_fp32 parameter to its superclass, RotaryEmbedding, but the only parameter that class accepts is dim:

```python
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    RoFormer. Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
```
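
As a possible workaround until this is fixed, forcing a non-flash attention implementation should avoid instantiating ModernBertUnpaddedRotaryEmbedding at all. This is only a sketch and assumes attn_implementation is accepted by from_config in this transformers version:

```python
# Possible workaround (untested): select the SDPA attention implementation so the
# flash-attn-based rotary embedding class is never constructed.
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModelForMaskedLM.from_config(config, attn_implementation="sdpa")
```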

Expected behavior

Works without error
