System Info
transformers version: 4.52.4
PyTorch version: 2.6
Who can help?
When running Llama4 with tensor parallelism, the torch.nn.Unfold module used in the Llama4 vision model is not compatible with DTensor, so I get this error:
NotImplementedError: Operator aten.im2col.default does not have a sharding strategy registered.
It looks like this happens because the latest transformers wraps the inputs of layers without a tp_plan entry as replicated DTensors, but Unfold (aten.im2col) has no DTensor sharding strategy.
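As a minimal illustration of what I believe is happening (my own sketch, not code from transformers; run under torchrun, with shapes and kernel size chosen arbitrarily), calling nn.Unfold on a replicated DTensor raises the same error:

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Build a 1-D mesh over all ranks and wrap a plain tensor as a replicated DTensor.
mesh = init_device_mesh("cpu", (int(os.environ["WORLD_SIZE"]),))
x = DTensor.from_local(torch.randn(1, 3, 336, 336), mesh, [Replicate()])

# nn.Unfold dispatches to aten.im2col, which has no DTensor sharding strategy,
# so this call fails with the NotImplementedError reported above.
unfold = nn.Unfold(kernel_size=14, stride=14)
unfold(x)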
To work around this error, I manually converted the DTensor input to a regular local tensor around the unfold call:
# Drop the DTensor to a local tensor, run unfold, then rebuild the DTensor
# with the original mesh and placements.
device_mesh = hidden_states.device_mesh if isinstance(hidden_states, DTensor) else None
placements = hidden_states.placements if isinstance(hidden_states, DTensor) else None
hidden_states = hidden_states.to_local()
hidden_states = self.unfold(hidden_states)
hidden_states = DTensor.from_local(hidden_states, device_mesh, placements)
After this change, I then got AttributeError: 'BaseModelOutput' object has no attribute 'to_local' when the vision model output is processed at
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama4/modeling_llama4.py#L1543.
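For context on this second failure: at that call site the vision model returns a BaseModelOutput (a ModelOutput dataclass), not a DTensor, so calling .to_local() on it directly cannot work. A hypothetical unwrap helper (my own sketch; the name and usage are made up and not code from transformers) would look like:

from torch.distributed.tensor import DTensor
from transformers.modeling_outputs import BaseModelOutput

def to_local_if_dtensor(output):
    # Unwrap the ModelOutput wrapper first, then drop to a local tensor if needed.
    hidden = output.last_hidden_state if isinstance(output, BaseModelOutput) else output
    return hidden.to_local() if isinstance(hidden, DTensor) else hidden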
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Minimal script to reproduce the error:
# test.py
import torch
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration

if __name__ == '__main__':
    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    B = 1
    S = 128
    input_ids = torch.randint(0, 1000, (B, S))
    attention_mask = torch.ones((B, S))
    pixel_values = torch.randn((5, 3, 336, 336)).to(torch.bfloat16)
    model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
Launched with:
torchrun --nproc-per-node=8 test.py
Expected behavior
I expect the program to finish successfully.