Flexattention Implementation in Gemma Failing on CPU #40345

@amd-lalithnc

Description
System Info

  • transformers version: 4.55.3
  • Platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
  • Python version: 3.10.18
  • Huggingface_hub version: 0.34.4
  • Safetensors version: 0.6.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cpu (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running a flex-attention-supported model such as Gemma on CPU fails with the error below (detailed traceback in the "Expected behavior" section):

torch._inductor.exc.InductorError: LoweringException: NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.
  target: flex_attention

Minimal reproduction script:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def main():
    model_id = "google/gemma-2-2b"  # any flex-attn supported model is fine
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    text = "The quick brown fox jumps over the lazy dog. " * 10
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=256,
        truncation=True,
        padding=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,          # matches failing environment; use float32 if bf16 unsupported
        attn_implementation="flex_attention",
        output_attentions=False,
    )

    with torch.no_grad():
        _ = model(**inputs)

if __name__ == "__main__":
    main()
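
A possible workaround while this is unresolved (my assumption, not something verified in this issue) is to force a non-flex attention backend so the compiled flex-attention path is skipped entirely:

import torch
from transformers import AutoModelForCausalLM

# Assumption: the SDPA (or eager) backend avoids the CPU Inductor lowering of
# flex_attention that raises the error below.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager"
)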

Expected behavior

The forward pass should complete on CPU without raising. Instead, the following traceback is produced:

Traceback (most recent call last):
  File "/../../../../sample.py", line 28, in <module>
    main()
  File "/../../../../sample.py", line 25, in main
    _ = model(**inputs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
    output = func(self, *args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 548, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/utils/generic.py", line 1083, in wrapper
    outputs = func(self, *args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 452, in forward
    layer_outputs = decoder_layer(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 271, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 222, in forward
    attn_output, attn_weights = attention_interface(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 293, in flex_attention_forward
    attn_output, attention_weights = compile_friendly_flex_attention(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 97, in compile_friendly_flex_attention
    return flex_attention_compiled(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 923, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 907, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1377, in codegen_and_compile
    graph.run(*example_inputs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 921, in run
    return super().run(*args)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/fx/interpreter.py", line 173, in run
    self.env[node] = self.run_node(node)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1599, in run_node
    result = super().run_node(n)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/fx/interpreter.py", line 242, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1268, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1258, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 446, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 1294, in flex_attention
    return lower_cpu(
  File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 967, in lower_cpu
    raise NotImplementedError(
torch._inductor.exc.InductorError: LoweringException: NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 8, 102, 256], stride=[208896, 256, 2048, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 4, 102, 256], stride=[104448, 256, 1024, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 4, 102, 256], stride=[104448, 256, 1024, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (102, 102, TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg8_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg9_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg10_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg11_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg12_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
  )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.0625
  args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: (TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cpu', torch.int64, size=[], stride=[]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg6_1', layout=FixedLayout('cpu', torch.bool, size=[1, 102], stride=[102, 1]))
  )))

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
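
The failure seems to come from PyTorch's CPU Inductor lowering rather than from transformers itself: the lowering kwargs above show OUTPUT_LOGSUMEXP: True, and a compiled flex_attention call that requests the logsumexp should hit the same NotImplementedError. A minimal sketch (assuming PyTorch 2.8.0+cpu as in the environment above; I have not verified this on other builds):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Small self-attention shapes purely for illustration.
q = torch.randn(1, 2, 64, 32)
k = torch.randn(1, 2, 64, 32)
v = torch.randn(1, 2, 64, 32)

compiled_flex = torch.compile(flex_attention)
# return_lse=True mirrors OUTPUT_LOGSUMEXP: True in the lowering kwargs above;
# on CPU this is expected to raise the same NotImplementedError.
out, lse = compiled_flex(q, k, v, return_lse=True)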
