
[Bug] Potential bugs in "_grouped_mm" in Llama4 MoE codes #1237


Bug description

I encountered NaN loss values when running the experimental Llama 4 MoE code.
The error comes from here.

As far as I know, offsets is defined as torch.cumsum(num_local_tokens_per_expert), while x (routed_input) is permuted and padded to a shape of original_shape + num_experts * ALIGN_SIZE_M.
As a result, x.shape[0] and offsets[-1] differ.
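
For concreteness, a sketch of the mismatch with the numbers from my run (ALIGN_SIZE_M = 16 and 2048 routed tokens from batch_size 1 x seq_len 2048 are my assumptions; the per-expert counts are inferred from the offsets shown below):

import torch

ALIGN_SIZE_M = 16                      # assumed alignment constant
num_experts = 8
num_tokens = 2048                      # batch_size 1 * seq_len 2048

# x is padded to the worst case: ALIGN_SIZE_M extra rows reserved per expert.
padded_rows = num_tokens + num_experts * ALIGN_SIZE_M
print(padded_rows)                     # 2176 == x.shape[0]

# offsets, by contrast, are the cumulative sum of the aligned per-expert counts:
num_local_tokens_per_expert = torch.tensor(
    [176, 240, 320, 256, 304, 288, 256, 256])
offsets = torch.cumsum(num_local_tokens_per_expert, dim=0)
print(offsets[-1].item())              # 2096 != 2176: 80 zero-filled tail rows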

I'm not sure which expert the redundant padding rows of x are assigned to inside _grouped_mm.
I would expect their outputs to always be 0, since those rows are filled with zeros.
However, _grouped_mm sometimes produces large values there, and the outputs even contain inf elements at the first index (here).
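
A quick way to see this from inside the forward pass (a sketch of the pdb checks I ran; x, self.w1, and offsets are the intermediates printed below):

valid = offsets[-1].item()             # 2096 in my run
assert (x[valid:] == 0).all()          # the tail rows of x really are zero-filled
h = torch._grouped_mm(x, self.w1, offs=offsets)
print(h[valid:].abs().max())           # expected 0, but huge values / inf appear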

How to Reproduce?

  1. I used the Llama-3.2-1B tokenizer.
  2. I used debug_model.toml, but with a different batch size and seq_len, on 1 H200 GPU. Here is the run script:
torchrun --nnodes 1 --nproc_per_node 1  ./torchtitan/train.py  \
--job.config_file ./torchtitan/experiments/llama4/train_configs/debug_model.toml --job.dump_folder ./outputs/250528_grouped_mm_debug  \
--profiling.save_traces_folder profile_trace --comm.trace_buf_size 0  --checkpoint.folder ./checkpoints/250528_grouped_mm_debug --checkpoint.interval 13000   \
--training.steps 114440 --training.batch_size 1 --training.seq_len 2048   \
--metrics.log_freq 100 --lr_scheduler.warmup_steps 1000 --optimizer.lr 6e-4  \
--parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1
  3. Add x = x.to(torch.bfloat16) and ..., dtype=torch.bfloat16) for self.w1, self.w2, and self.w3 (see the sketch after this list), since the code automatically uses torch.float32 on 1 GPU, and _grouped_mm requires bfloat16 tensors on GPU.
  4. I used pdb to inspect the intermediate outputs one by one.
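
For step 3, the change looks roughly like this (a sketch only; hidden_dim and the exact parameter shapes are my assumptions, with dim = 256 taken from the x.shape printed below):

import torch
from torch import nn

num_experts, dim, hidden_dim = 8, 256, 768     # hidden_dim is a placeholder

# Cast the expert weights to bfloat16 at construction time ...
w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim, dtype=torch.bfloat16, device="cuda"))
w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim, dtype=torch.bfloat16, device="cuda"))
w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim, dtype=torch.bfloat16, device="cuda"))

# ... and the routed input right before the grouped matmuls, since the
# debug config trains in torch.float32 on a single GPU:
x = torch.zeros(2176, dim, device="cuda")
x = x.to(torch.bfloat16)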

Results and expected behavior

Routed outputs sometimes show the following results (at the first step or a few steps later):

offsets : tensor([ 176,  416,  736,  992, 1296, 1584, 1840, 2096], device='cuda:0', dtype=torch.int32)

x.shape : torch.Size([2176, 256])

h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) :
tensor([[ 3.7598e-02, -9.3262e-02,  1.3965e-01,  ..., -1.7822e-02,
         -2.2949e-02,  2.0020e-02],
        [ 1.1572e-01,  2.2461e-01,  3.1641e-01,  ...,  8.6060e-03,
         -5.3711e-02, -2.7100e-02],
        [ 1.4551e-01,  2.1973e-02,  1.3086e-01,  ..., -2.5269e-02,
          3.7354e-02, -1.5503e-02],
        ...,
        [-0.0000e+00,  2.9297e-02, -0.0000e+00,  ...,  5.2246e-02,
          7.7462e+18, -1.8066e-02],
        [ 2.8531e+26,  5.1025e-02, -0.0000e+00,  ...,  1.1670e-01,
          3.2028e-28,  1.5076e-02],
        [ 6.3348e+26,  3.8818e-02,  4.0250e+01,  ..., -2.8229e-03,
          2.4844e-32, -8.6670e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SiluBackward0>)

h = h * torch._grouped_mm(x, self.w3, offs=offsets)
tensor([[-1.8692e-03, -2.8992e-03,  1.6327e-03,  ..., -1.5564e-03,
         -1.0681e-02,  5.1022e-05],
        [-5.5237e-03,  6.0425e-03,  1.0864e-02,  ...,  9.8419e-04,
          3.0396e-02, -4.2152e-04],
        [-1.6785e-03, -4.5776e-04, -2.0142e-03,  ...,  1.0193e-02,
         -4.6082e-03, -1.3733e-04],
        ...,
        [ 0.0000e+00,  1.2054e-03, -0.0000e+00,  ..., -2.5177e-03,
          3.5863e+11, -1.7548e-03],
        [       -inf,  6.3705e-04,  0.0000e+00,  ...,  9.5825e-03,
         -2.9000e+02,  3.2234e-04],
        [ 8.4410e+07,  4.0588e-03, -1.0379e+31,  ...,  3.7432e-05,
          1.2387e-07, -1.3733e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)

out = torch._grouped_mm(h, self.w2, offs=offsets)
tensor([[ 6.3782e-03,  4.0894e-03, -1.3672e-02,  ..., -8.4839e-03,
         -2.8229e-03, -3.9978e-03],
        [-1.9379e-03, -4.6387e-03,  8.5449e-03,  ..., -4.8523e-03,
         -4.4861e-03, -1.4114e-03],
        [-3.1128e-03, -2.5177e-03, -3.4332e-03,  ...,  1.3062e-02,
         -6.7139e-03, -7.6904e-03],
        ...,
        [-1.6251e-03, -1.3279e-10, -7.3787e+19,  ..., -5.1659e-10,
         -3.8780e+34, -3.5834e-10],
        [ 4.7055e+34, -1.6735e-09,  6.0889e+18,  ..., -1.1205e-09,
          7.1024e+24,  3.1287e-10],
        [-2.4087e-21, -2.1682e-09,  3.0898e+20,  ...,  2.9831e-09,
          2.4898e-30,  5.5297e-10]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<GroupedMmBackward0>)

We expect the tensors at sequence positions 2096 to 2176 to always be zero.
Instead, these garbage values cause the hidden states to contain NaNs, and eventually the loss becomes NaN as well.
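
One possible mitigation is to slice x down to offsets[-1] before the grouped matmuls and zero-pad the result back afterwards. This is only a sketch I have not validated as the intended fix; it assumes the tail rows are pure alignment padding that downstream code expects to be zero:

import torch.nn.functional as F

valid = offsets[-1].item()                        # 2096
x_valid = x[:valid]                               # drop the 80 all-zero tail rows
h = F.silu(torch._grouped_mm(x_valid, self.w1, offs=offsets))
h = h * torch._grouped_mm(x_valid, self.w3, offs=offsets)
out = torch._grouped_mm(h, self.w2, offs=offsets)
out = F.pad(out, (0, 0, 0, x.shape[0] - valid))   # restore the padded length with zeros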

Versions

Python 3.13 with the following packages:

absl-py==2.2.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
attrs==25.3.0
beautifulsoup4==4.13.4
bleach==6.2.0
blessed==1.21.0
blobfile==3.0.0
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.2.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
debugpy @ file:///croot/debugpy_1736267418885/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
defusedxml==0.7.1
dill==0.3.8
docker-pycreds==0.4.0
docstring_parser==0.16
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
fastjsonschema==2.21.1
filelock==3.16.1
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2024.10.0
gitdb==4.0.12
GitPython==3.1.44
gpustat==1.1.1
grpcio==1.71.0
huggingface-hub==0.31.4
idna==3.10
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.4
jsonschema==4.24.0
jsonschema-specifications==2025.4.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
jupyterlab_pygments==0.3.0
kiwisolver==1.4.8
lxml==5.4.0
Markdown==3.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mdurl==0.1.2
mistune==3.1.3
mpmath==1.3.0
multidict==6.4.4
multiprocess==0.70.16
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu12==2.26.5
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==25.0
pandas==2.2.3
pandocfilters==1.5.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.2.1
platformdirs==4.3.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
propcache==0.3.1
protobuf==6.31.0
psutil==7.0.0
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
pyarrow==20.0.0
pycryptodomex==3.23.0
pydantic==2.11.4
pydantic_core==2.33.2
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-triton==3.3.0+git96316ce5
pytz==2025.2
PyYAML==6.0.2
pyzmq @ file:///croot/pyzmq_1734687138743/work
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
rpds-py==0.25.1
safetensors==0.5.3
sentry-sdk==2.29.1
setproctitle==1.3.6
setuptools==70.2.0
shtab==1.7.2
six==1.17.0
smmap==5.0.2
soupsieve==2.7
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
sympy==1.13.3
tabulate==0.9.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tiktoken==0.9.0
tinycss2==1.4.0
tokenizers==0.21.1
torch==2.8.0.dev20250519+cu126
torchdata==0.11.0
tornado @ file:///croot/tornado_1747918059467/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.52.1
triton==3.3.0
typeguard==4.4.2
typing-inspection==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
tyro==0.9.20
tzdata==2025.2
urllib3==2.4.0
wandb==0.19.11
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
webencodings==0.5.1
Werkzeug==3.1.3
wheel==0.45.1
xxhash==3.5.0
yarl==1.20.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
