
[gpt-oss] MoE routing bug in the mxfp4 implementation (in distributed setting) #40031

@kitft


System Info

- `transformers` version: 4.55.0
- Platform: Linux-6.11.11+-x86_64-with-glibc2.35
- Python version: 3.11.13
- Huggingface_hub version: 0.34.3
- Safetensors version: 0.6.1
- Accelerate version: 1.10.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA B200


Who can help?

@SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running data-parallel on a multi-GPU node and launching with torchrun, I get nonsense generations from the new gpt-oss model with mxfp4, using the recommended Triton kernels and the latest version of PyTorch. This happens on an 8x node of H100s or B200s; I have not tested other settings. I have no such issues if I run the model with kernels=true (which dequantises the model to bf16).
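For the comparison run, I believe the dequantised bf16 path can also be selected explicitly through the quantization config; the sketch below assumes `Mxfp4Config(dequantize=True)` is the relevant switch (treat that as an assumption on my part), with the rest of the script unchanged.

```
# Sketch of the comparison run that works for me: load gpt-oss dequantised
# to bf16 instead of using the mxfp4 Triton-kernel path.
# Assumption: Mxfp4Config(dequantize=True) is the switch that selects it.
import os

from transformers import AutoModelForCausalLM, Mxfp4Config

rank = int(os.environ["LOCAL_RANK"])

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=Mxfp4Config(dequantize=True),  # bf16 weights, no mxfp4 kernels
    device_map={"": rank},
)
```

The failing mxfp4 reproduction follows: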

```
torchrun --standalone --nproc_per_node=2 test_gpt.py
```

```
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

rank = int(os.environ["LOCAL_RANK"])

# Initialize the process group (world_size matches --nproc_per_node)
dist.init_process_group("nccl", rank=rank, world_size=2)

model_id = "openai/gpt-oss-20b"
torch.cuda.set_device(rank)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": rank},  # "auto" also triggers the issue
)

messages = [
    {"role": "user", "content": "How many rs are in the word 'strawberry'?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

generated = model.generate(**inputs, max_new_tokens=100, do_sample=True)
print(f'rank {rank} ------ {tokenizer.decode(generated[0][inputs["input_ids"].shape[-1]:])}')
```



The output is:

```
rank 1 ------ <|message|> 1 that there can. that no you there, if for many or, no for many: (no). with, it are not, or.

 not. and. etc..

 or. or; The answer. not, and, etc. etc. or; ChatGPT,. A of Open AI, 3.  3. The Answer. not, and, do, no, where, to, no, and? This language...

. For many, and
rank 0 ------ <|channel|>analysis<|message|>We need to determine number of 'r's in word 'strawberry'. The word 'strawberry? Actually the word 'strawberry. The letters: s t r a w b e e? Wait no: 's'. Let's write: s t r a w b e? I'm going to check.

Word: 'strawberry': letters 's' 't' 'r' 'a' 'w' 'b' 'e' 'r'
```

I have localised the bug to https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/mxfp4.py, where the routing logic appears to incorrectly switch to the expert-parallel path whenever dist.is_initialized(). In my setting, forcibly disabling this branch resolves the issue:


```
if dist.is_available() and dist.is_initialized():
    routing = routing_torch_dist
else:
    routing = triton_kernels_hub.routing.routing
```
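For concreteness, forcibly disabling that branch amounts to something like the following local hack (a workaround used to confirm the diagnosis, not a proposed fix):

```
# Workaround sketch: always take the non-distributed Triton routing path,
# even when torch.distributed is initialised. This only confirms where the
# bug lives; the proper fix presumably belongs in the routing_torch_dist path.
if False:  # was: dist.is_available() and dist.is_initialized()
    routing = routing_torch_dist
else:
    routing = triton_kernels_hub.routing.routing
```

With this change, generations on my setup are coherent on every rank.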



Expected behavior

Generations should be coherent and of equal quality across ranks.
