System Info
- transformers version: 4.55.3
- Platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
- Python version: 3.10.18
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.6.2
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cpu (NA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Running a flex-attention-supported model such as Gemma on a CPU-only setup fails with the error below (full traceback at the end, a sketch of a possible workaround after it):
torch._inductor.exc.InductorError: LoweringException: NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.
target: flex_attention

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main():
    model_id = "google/gemma-2-2b"  # any flex-attn supported model is fine
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    text = "The quick brown fox jumps over the lazy dog. " * 10
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=256,
        truncation=True,
        padding=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # matches failing environment; use float32 if bf16 unsupported
        attn_implementation="flex_attention",
        output_attentions=False,
    )

    with torch.no_grad():
        _ = model(**inputs)


if __name__ == "__main__":
    main()

Expected behavior
Traceback (most recent call last):
File "/../../../../sample.py", line 28, in <module>
main()
File "/../../../../sample.py", line 25, in main
_ = model(**inputs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
output = func(self, *args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 548, in forward
outputs: BaseModelOutputWithPast = self.model(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/utils/generic.py", line 1083, in wrapper
outputs = func(self, *args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 452, in forward
layer_outputs = decoder_layer(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/modeling_layers.py", line 93, in __call__
return super().__call__(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 271, in forward
hidden_states, self_attn_weights = self.self_attn(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 222, in forward
attn_output, attn_weights = attention_interface(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 293, in flex_attention_forward
attn_output, attention_weights = compile_friendly_flex_attention(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 97, in compile_friendly_flex_attention
return flex_attention_compiled(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 923, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 907, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1377, in codegen_and_compile
graph.run(*example_inputs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 921, in run
return super().run(*args)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/fx/interpreter.py", line 173, in run
self.env[node] = self.run_node(node)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1599, in run_node
result = super().run_node(n)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/fx/interpreter.py", line 242, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1268, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1258, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 446, in wrapped
out = decomp_fn(*args, **kwargs)
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 1294, in flex_attention
return lower_cpu(
File "<home>/anaconda3/envs/flex-attn-hf/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 967, in lower_cpu
raise NotImplementedError(
torch._inductor.exc.InductorError: LoweringException: NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 8, 102, 256], stride=[208896, 256, 2048, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 4, 102, 256], stride=[104448, 256, 1024, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1, 4, 102, 256], stride=[104448, 256, 1024, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
args[4]: (102, 102, TensorBox(StorageBox(
InputBuffer(name='arg4_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg7_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg8_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg9_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg10_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg11_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg12_1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]))
)), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
args[5]: 0.0625
args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
args[7]: ()
args[8]: (TensorBox(StorageBox(
InputBuffer(name='arg5_1', layout=FixedLayout('cpu', torch.int64, size=[], stride=[]))
)), TensorBox(StorageBox(
InputBuffer(name='arg6_1', layout=FixedLayout('cpu', torch.bool, size=[1, 102], stride=[102, 1]))
)))
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
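
For context, the failure comes from PyTorch's inductor lowering of flex_attention on CPU, which rejects the logsumexp output; the traceback shows OUTPUT_LOGSUMEXP: True in the kernel options, which I believe corresponds to the integration requesting the LSE. A minimal, torch-only sketch that should hit the same NotImplementedError (assumptions: a CPU-only torch 2.8 build, and that compiling flex_attention with return_lse=True matches what the transformers integration does):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Small CPU tensors; shapes and dtype are arbitrary illustration values, not taken from the issue.
q = torch.randn(1, 2, 8, 64)
k = torch.randn(1, 2, 8, 64)
v = torch.randn(1, 2, 8, 64)

compiled_flex = torch.compile(flex_attention)

with torch.no_grad():
    # Requesting the logsumexp (return_lse=True) mirrors OUTPUT_LOGSUMEXP=True above;
    # on a CPU-only build this is expected to raise the same
    # "torch.compile on CPU only supports inference and `return_lse` is not supported yet."
    out, lse = compiled_flex(q, k, v, return_lse=True)

Until the integration handles this case, a possible workaround (it sidesteps the bug rather than fixes it) is to load the model with a different attention backend on CPU, e.g. sdpa:

import torch
from transformers import AutoModelForCausalLM

# Hedged workaround sketch: avoid the compiled flex-attention path on CPU by
# picking another attention implementation; "sdpa" and "eager" are standard options.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager"
)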