Bug description
I encountered NaN loss values when running the Llama 4 MoE experimental code. The errors come from here. As far as I know, `offsets` is defined as `torch.cumsum(num_local_tokens_per_expert)`, while `x` (`routed_input`) is permuted with a shape of `original_shape + num_experts * ALIGN_SIZE_M`. As a result, `x.shape[0]` differs from `offsets[-1]`.
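For concreteness, here is a minimal plain-Python sketch of that mismatch, using the per-expert counts implied by the `offsets` and `x.shape` values observed in this run (`torch.cumsum` is only mimicked here, without torch):

```python
# Per-expert token counts recovered from the observed offsets
# (offsets = torch.cumsum(num_local_tokens_per_expert) in the real code).
num_local_tokens_per_expert = [176, 240, 320, 256, 304, 288, 256, 256]

offsets = []
total = 0
for n in num_local_tokens_per_expert:
    total += n
    offsets.append(total)
print(offsets)  # [176, 416, 736, 992, 1296, 1584, 1840, 2096]

# x (routed_input) is padded, so its row count exceeds offsets[-1]:
x_rows = 2176                # x.shape[0] observed in this run
print(x_rows - offsets[-1])  # 80 padded rows not covered by any offset
```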
I'm not sure which expert the redundant rows of `x` are assigned to in `_grouped_mm`. I believe the expected behavior is that their outputs should always be 0, because those rows are filled with zeros. However, `_grouped_mm` sometimes produces large values in them, and the first index of the outputs ends up with `inf` elements (here).
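The expectation that zero-filled rows yield zero outputs holds for any ordinary matmul, which is why the observed garbage suggests those rows are not actually being multiplied against a valid expert weight. A tiny pure-Python check of that property (illustration only, not the torchtitan code):

```python
def matmul_row(row, w):
    """Multiply a single input row by a weight matrix w (list of rows)."""
    n_out = len(w[0])
    return [sum(row[k] * w[k][j] for k in range(len(row))) for j in range(n_out)]

w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
zero_row = [0.0, 0.0, 0.0, 0.0]
print(matmul_row(zero_row, w))  # [0.0, 0.0] -- a zero row can only produce zeros
```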
How to Reproduce?
- I used the Llama-3.2-1B tokenizer.
- I used `debug_model.toml`, but with a different batch size and seq_len on one H200 GPU. Here is the running script:
```shell
torchrun --nnodes 1 --nproc_per_node 1 ./torchtitan/train.py \
--job.config_file ./torchtitan/experiments/llama4/train_configs/debug_model.toml --job.dump_folder ./outputs/250528_grouped_mm_debug \
--profiling.save_traces_folder profile_trace --comm.trace_buf_size 0 --checkpoint.folder ./checkpoints/250528_grouped_mm_debug --checkpoint.interval 13000 \
--training.steps 114440 --training.batch_size 1 --training.seq_len 2048 \
--metrics.log_freq 100 --lr_scheduler.warmup_steps 1000 --optimizer.lr 6e-4 \
--parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1
```
- I added `x = x.to(torch.bfloat16)` and `..., dtype=torch.bfloat16)` for `self.w1`, `self.w2`, and `self.w3`, since the code automatically uses `torch.float32` on a single GPU, while `_grouped_mm` requires bf16 tensors on the GPU.
- I used `pdb` to get the intermediate outputs one by one.
Results and Expected Behavior
Routed outputs sometimes show the following results (at the first step or a few steps later):
```
offsets : tensor([ 176, 416, 736, 992, 1296, 1584, 1840, 2096], device='cuda:0', dtype=torch.int32)
x.shape : torch.Size([2176, 256])
h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) :
tensor([[ 3.7598e-02, -9.3262e-02, 1.3965e-01, ..., -1.7822e-02,
         -2.2949e-02, 2.0020e-02],
        [ 1.1572e-01, 2.2461e-01, 3.1641e-01, ..., 8.6060e-03,
         -5.3711e-02, -2.7100e-02],
        [ 1.4551e-01, 2.1973e-02, 1.3086e-01, ..., -2.5269e-02,
          3.7354e-02, -1.5503e-02],
        ...,
        [-0.0000e+00, 2.9297e-02, -0.0000e+00, ..., 5.2246e-02,
          7.7462e+18, -1.8066e-02],
        [ 2.8531e+26, 5.1025e-02, -0.0000e+00, ..., 1.1670e-01,
          3.2028e-28, 1.5076e-02],
        [ 6.3348e+26, 3.8818e-02, 4.0250e+01, ..., -2.8229e-03,
          2.4844e-32, -8.6670e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SiluBackward0>)
h = h * torch._grouped_mm(x, self.w3, offs=offsets)
tensor([[-1.8692e-03, -2.8992e-03, 1.6327e-03, ..., -1.5564e-03,
         -1.0681e-02, 5.1022e-05],
        [-5.5237e-03, 6.0425e-03, 1.0864e-02, ..., 9.8419e-04,
          3.0396e-02, -4.2152e-04],
        [-1.6785e-03, -4.5776e-04, -2.0142e-03, ..., 1.0193e-02,
         -4.6082e-03, -1.3733e-04],
        ...,
        [ 0.0000e+00, 1.2054e-03, -0.0000e+00, ..., -2.5177e-03,
          3.5863e+11, -1.7548e-03],
        [ -inf, 6.3705e-04, 0.0000e+00, ..., 9.5825e-03,
         -2.9000e+02, 3.2234e-04],
        [ 8.4410e+07, 4.0588e-03, -1.0379e+31, ..., 3.7432e-05,
          1.2387e-07, -1.3733e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)
out = torch._grouped_mm(h, self.w2, offs=offsets)
tensor([[ 6.3782e-03, 4.0894e-03, -1.3672e-02, ..., -8.4839e-03,
         -2.8229e-03, -3.9978e-03],
        [-1.9379e-03, -4.6387e-03, 8.5449e-03, ..., -4.8523e-03,
         -4.4861e-03, -1.4114e-03],
        [-3.1128e-03, -2.5177e-03, -3.4332e-03, ..., 1.3062e-02,
         -6.7139e-03, -7.6904e-03],
        ...,
        [-1.6251e-03, -1.3279e-10, -7.3787e+19, ..., -5.1659e-10,
         -3.8780e+34, -3.5834e-10],
        [ 4.7055e+34, -1.6735e-09, 6.0889e+18, ..., -1.1205e-09,
          7.1024e+24, 3.1287e-10],
        [-2.4087e-21, -2.1682e-09, 3.0898e+20, ..., 2.9831e-09,
          2.4898e-30, 5.5297e-10]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<GroupedMmBackward0>)
```
We expect the rows at sequence positions 2096 through 2176 to always be zero. Instead, these garbage values cause the hidden states to contain NaN values, and the loss eventually becomes NaN.
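Until `_grouped_mm` guarantees zeros for the padded rows, one possible workaround (a hedged sketch only, not an existing torchtitan helper; the function name is hypothetical) is to force every output row at index `offsets[-1]` or beyond to zero, shown here on plain Python lists:

```python
def zero_padded_rows(out_rows, last_offset):
    # Rows at index >= last_offset were never assigned to any expert,
    # so force them to zero instead of trusting what _grouped_mm left there.
    return [
        row if i < last_offset else [0.0] * len(row)
        for i, row in enumerate(out_rows)
    ]

out = [[1.5, -2.0], [3.0, 4.0], [7.7e18, -1.0]]  # last row is padding garbage
print(zero_padded_rows(out, 2))  # [[1.5, -2.0], [3.0, 4.0], [0.0, 0.0]]
```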
Versions
Python 3.13 with the following packages:
absl-py==2.2.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
attrs==25.3.0
beautifulsoup4==4.13.4
bleach==6.2.0
blessed==1.21.0
blobfile==3.0.0
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.2.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
debugpy @ file:///croot/debugpy_1736267418885/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
defusedxml==0.7.1
dill==0.3.8
docker-pycreds==0.4.0
docstring_parser==0.16
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
fastjsonschema==2.21.1
filelock==3.16.1
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2024.10.0
gitdb==4.0.12
GitPython==3.1.44
gpustat==1.1.1
grpcio==1.71.0
huggingface-hub==0.31.4
idna==3.10
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.4
jsonschema==4.24.0
jsonschema-specifications==2025.4.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
jupyterlab_pygments==0.3.0
kiwisolver==1.4.8
lxml==5.4.0
Markdown==3.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mdurl==0.1.2
mistune==3.1.3
mpmath==1.3.0
multidict==6.4.4
multiprocess==0.70.16
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu12==2.26.5
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==25.0
pandas==2.2.3
pandocfilters==1.5.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.2.1
platformdirs==4.3.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
propcache==0.3.1
protobuf==6.31.0
psutil==7.0.0
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
pyarrow==20.0.0
pycryptodomex==3.23.0
pydantic==2.11.4
pydantic_core==2.33.2
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-triton==3.3.0+git96316ce5
pytz==2025.2
PyYAML==6.0.2
pyzmq @ file:///croot/pyzmq_1734687138743/work
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
rpds-py==0.25.1
safetensors==0.5.3
sentry-sdk==2.29.1
setproctitle==1.3.6
setuptools==70.2.0
shtab==1.7.2
six==1.17.0
smmap==5.0.2
soupsieve==2.7
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
sympy==1.13.3
tabulate==0.9.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tiktoken==0.9.0
tinycss2==1.4.0
tokenizers==0.21.1
torch==2.8.0.dev20250519+cu126
torchdata==0.11.0
tornado @ file:///croot/tornado_1747918059467/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.52.1
triton==3.3.0
typeguard==4.4.2
typing-inspection==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
tyro==0.9.20
tzdata==2025.2
urllib3==2.4.0
wandb==0.19.11
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
webencodings==0.5.1
Werkzeug==3.1.3
wheel==0.45.1
xxhash==3.5.0
yarl==1.20.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work