With torch.autocast, Phi-tiny-MoE-instruct raises a dtype mismatch error #43828

@vinnamkim

System Info

  • transformers version: 5.1.0
  • Platform: Linux-5.15.0-168-generic-x86_64-with-glibc2.35
  • Python version: 3.11.14
  • Huggingface_hub version: 1.4.1
  • Safetensors version: 0.7.0
  • Accelerate version: 1.12.0
  • Accelerate config: not found
  • DeepSpeed version: 0.18.5
  • PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The following minimal Python script reproduces the error:

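import torch
from transformers import AutoModelForCausalLM

# Load the model with bfloat16 weights on the GPU.
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-tiny-MoE-instruct", device_map="cuda", dtype="bfloat16")

# Random token IDs: batch size 1, sequence length 16.
x = torch.randint(0, 100, [1, 16], device="cuda")

# The forward pass fails only when run inside the autocast context.
with torch.autocast("cuda", torch.bfloat16):
    y = model(x)

print(y)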

Expected behavior

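The forward pass should complete without error. Instead, it fails with a dtype mismatch in the grouped GEMM (torch._grouped_mm): one operand is float32 and the other is bfloat16.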

Loading weights: 100%|██████████| 485/485 [00:02<00:00, 239.58it/s, Materializing param=model.norm.weight]
Traceback (most recent call last):
  File "/local/mnt/workspace/users/vinnkim/QAT-LlaMA/test.py", line 12, in <module>
    y = model(x)
        ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/generic.py", line 834, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/phimoe/modeling_phimoe.py", line 845, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/generic.py", line 1001, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/phimoe/modeling_phimoe.py", line 682, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/phimoe/modeling_phimoe.py", line 584, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/phimoe/modeling_phimoe.py", line 541, in forward
    final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/integrations/moe.py", line 349, in forward
    return experts_forward(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/integrations/moe.py", line 251, in grouped_mm_experts_forward
    gate_up_out = _grouped_linear(
                  ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/integrations/moe.py", line 192, in _grouped_linear
    out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
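
The error message suggests that the activations reach the matmul as float32 while the expert weights are bfloat16. A minimal sketch of that mismatch and of one possible workaround (casting the input to the weight dtype before the call) is below. Note the assumptions: `torch.mm` on CPU stands in for `torch._grouped_mm`, the tensor shapes are arbitrary, and `cast_to_weight_dtype` is a hypothetical helper, not part of transformers or PyTorch. Whether this is the right fix for `_grouped_linear` has not been confirmed.

```python
import torch


def cast_to_weight_dtype(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `inp` in the same dtype as `weight`."""
    return inp if inp.dtype == weight.dtype else inp.to(weight.dtype)


# Stand-ins for the tensors in the traceback (shapes are arbitrary).
hidden = torch.randn(4, 8)                         # float32 activations
weight = torch.randn(8, 8, dtype=torch.bfloat16)   # bfloat16 expert weight

# A plain matmul outside autocast rejects mixed dtypes,
# which mirrors the RuntimeError reported above.
try:
    torch.mm(hidden, weight)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# After aligning the dtypes, the same call succeeds.
out = torch.mm(cast_to_weight_dtype(hidden, weight), weight)
print(mismatch_raised, out.dtype, tuple(out.shape))
```

Casting to the weight dtype keeps the result in bfloat16, which is what the autocast context requests in the reproduction script.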
