Description
System Info
Linux (Ubuntu 22.04)
Python 3.12.4
transformers 4.52.4
flash-attn 2.8.0.post2
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- `pip install flash-attn`
- `config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base")`
- `model = AutoModelForMaskedLM.from_config(config)`
- Raises: `TypeError: RotaryEmbedding.__init__() got an unexpected keyword argument 'pos_idx_in_fp32'`
The issue is as follows: when you create a ModernBERT model with flash attention installed, it uses `ModernBertUnpaddedRotaryEmbedding`:
```python
class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
        up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
        the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
```
`ModernBertUnpaddedRotaryEmbedding` passes the `pos_idx_in_fp32` keyword argument to its superclass `RotaryEmbedding`, but in the installed flash-attn the only parameter its `__init__` accepts is `dim`:
```python
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    RoFormer. Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
```
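The failure mode can be illustrated with a minimal, self-contained sketch (stand-in classes, not the actual flash-attn or transformers code): a subclass forwarding a keyword argument that the installed parent's `__init__` does not declare raises exactly this `TypeError`.

```python
class RotaryEmbedding:
    """Stand-in for the installed flash-attn class: only `dim` is accepted."""

    def __init__(self, dim: int):
        self.dim = dim


class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """Stand-in for the transformers subclass that forwards an extra keyword."""

    def __init__(self, dim: int):
        # The parent does not declare `pos_idx_in_fp32`, so this call fails.
        super().__init__(dim=dim, pos_idx_in_fp32=True)


try:
    ModernBertUnpaddedRotaryEmbedding(dim=64)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'pos_idx_in_fp32'
```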
Expected behavior
The model is created without error.
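One possible mitigation, sketched here with stand-in classes rather than the real flash-attn or transformers code, is for the subclass to forward only the keywords the installed parent's `__init__` actually declares, using `inspect.signature`. This keeps the call working across flash-attn versions with different constructor signatures; the helper name `_supported_kwargs` is hypothetical.

```python
import inspect


def _supported_kwargs(fn, **kwargs):
    """Keep only the keyword arguments that `fn` actually declares."""
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}


class RotaryEmbedding:
    """Stand-in for the installed flash-attn class: only `dim` is accepted."""

    def __init__(self, dim: int):
        self.dim = dim


class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    def __init__(self, dim: int, base: float = 10000.0):
        # Drop keywords (base, pos_idx_in_fp32, ...) the parent doesn't know about.
        super().__init__(**_supported_kwargs(
            RotaryEmbedding.__init__,
            dim=dim, base=base, pos_idx_in_fp32=True, interleaved=False,
        ))


emb = ModernBertUnpaddedRotaryEmbedding(dim=64)  # no TypeError
```

A signature check like this trades strictness for compatibility: silently dropped keywords mean silently changed behavior, so a version pin on flash-attn may be the safer fix.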