
DTensor issues when running Llama4ForConditionalGeneration with tensor parallel. #38803

@czkkkkkk

Description


System Info

transformers version: 4.52.4
pytorch version: 2.6


When running Llama4 with tensor parallelism, the `torch.nn.Unfold` used in the Llama4 vision model isn't compatible with DTensor, so I got this error:

NotImplementedError: Operator aten.im2col.default does not have a sharding strategy registered.

This seems to be because recent transformers versions wrap layers that have no tp_plan entry as replicated DTensors, and Unfold (which lowers to `aten.im2col`) has no DTensor sharding strategy.

To work around this error, I manually converted the input to a regular Tensor before calling Unfold and wrapped the result back into a DTensor:

```python
from torch.distributed.tensor import DTensor

# aten.im2col has no DTensor sharding strategy: unfold the local tensor, then re-wrap it (valid for Replicate placements).
if isinstance(hidden_states, DTensor):
    mesh, placements = hidden_states.device_mesh, hidden_states.placements
    hidden_states = DTensor.from_local(self.unfold(hidden_states.to_local()), mesh, placements)
else:
    hidden_states = self.unfold(hidden_states)
```
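Since the failing op is `aten.im2col`, another option is to avoid Unfold altogether and express it with reshape/permute, which is exact when patches don't overlap. This is a sketch, not the transformers implementation, and it assumes (1) the Llama4 vision encoder's Unfold uses `stride == kernel_size` with no padding or dilation, and (2) DTensor's view/permute support covers the replicated case. Both are worth verifying against `modeling_llama4.py`. `unfold_non_overlapping` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def unfold_non_overlapping(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical drop-in for nn.Unfold(kernel_size=patch_size, stride=patch_size)
    on 4-D input (B, C, H, W) whose H and W are divisible by patch_size.

    Uses only reshape/permute, so it never dispatches to aten.im2col.
    """
    b, c, h, w = x.shape
    if h % patch_size or w % patch_size:
        raise ValueError("H and W must be divisible by patch_size")
    hp, wp = h // patch_size, w // patch_size
    # (B, C, Hp, p, Wp, p): split each spatial dim into (patch index, offset in patch)
    x = x.reshape(b, c, hp, patch_size, wp, patch_size)
    # (B, C, p, p, Hp, Wp): Unfold orders features as (C, kh, kw) and patches row-major
    x = x.permute(0, 1, 3, 5, 2, 4)
    # (B, C*p*p, Hp*Wp): same layout as nn.Unfold's output
    return x.reshape(b, c * patch_size * patch_size, hp * wp)


if __name__ == "__main__":
    # Compare against nn.Unfold on a plain (non-DTensor) tensor.
    x = torch.randn(2, 3, 28, 42)
    ref = nn.Unfold(kernel_size=14, stride=14)(x)
    out = unfold_non_overlapping(x, 14)
    print(out.shape, torch.equal(out, ref))
```

Because Unfold with non-overlapping patches only copies values, the result should match `nn.Unfold` exactly rather than approximately.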

After this change, running the vision model fails with AttributeError: 'BaseModelOutput' object has no attribute 'to_local' at https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama4/modeling_llama4.py#L1543.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Minimal script to reproduce the error

```python
# test.py
import torch
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration

if __name__ == '__main__':
    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    B = 1
    S = 128
    input_ids = torch.randint(0, 1000, (B, S))
    attention_mask = torch.ones((B, S))
    pixel_values = torch.randn((5, 3, 336, 336)).to(torch.bfloat16)
    model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
```

Run it with:

```shell
torchrun --nproc-per-node=8 test.py
```

Expected behavior

Expect the program to finish successfully.
