ThunderFX fails with FP8 and Activation Checkpointing #1424

Open
@mpatel31415

Description

🐛 Bug

When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP, FP8, and activation checkpointing, we get KeyError: 'scaling_fwd'. This might also be an issue in Transformer Engine, so I'm happy to move this issue to TE if needed.

Full traceback:

[rank7]: Traceback (most recent call last):
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 974, in
7: [rank7]: CLI(benchmark_main)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
7: [rank7]: return _run_component(components, init)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
7: [rank7]: return component(**cfg)
7: [rank7]: ^^^^^^^^^^^^^^^^
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 871, in benchmark_main
7: [rank7]: benchmark.train()
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 765, in train
7: [rank7]: loss.backward()
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 624, in backward
7: [rank7]: torch.autograd.backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/init.py", line 347, in backward
7: [rank7]: _engine_run_backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
7: [rank7]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
7: [rank7]: return user_fn(self, *args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
7: [rank7]: outputs = fn(ctx, *args)
7: [rank7]: ^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
7: [rank7]: grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "thunder.backward_fn_13", line 28, in backward_fn
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
7: [rank7]: return self._call_impl(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in call_impl
7: [rank7]: return forward_call(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 205, in forward
7: [rank7]: weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 273, in get_fp8_weight_version_compat
7: [rank7]: weight_fp8 = self.get_fp8_workspace(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 1086, in get_fp8_workspace
7: [rank7]: out.quantize_(tensor, noop_flag=skip_update_flag)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 642, in quantize_
7: [rank7]: fp8_meta = dst._fp8_meta[fp8_meta_key]
7: [rank7]: ~~~~~~~~~~~~~^^^^^^^^^^^^^^
7: [rank7]: KeyError: 'scaling_fwd'
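
Judging from the last frames, the TE linear's forward (and hence its FP8 weight quantization) is re-run from thunder's backward_fn, i.e. during the activation-checkpoint recompute, and the per-module FP8 metadata lookup fails there. A rough, illustrative paraphrase of the failing access (not actual Transformer Engine source):

```python
# Illustrative only: the shape of the failure, not Transformer Engine source code.
_fp8_meta = {}                      # per-module FP8 metadata, normally populated during the forward pass
fp8_meta_key = "scaling_fwd"        # delayed-scaling state for the forward pass
fp8_meta = _fp8_meta[fp8_meta_key]  # raises KeyError: 'scaling_fwd', as in the traceback above
```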

To Reproduce

Please use:
1 node(s), each with 8 GPUs.
Image "INTERNAL_IMAGE:pjnl_20241107"
Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
  --model_name Mistral-7B-v0.2 \
  --distributed_mode fsdp \
  --shard_mode zero2 \
  --compile dynamo_thunder \
  --checkpoint_activations True \
  --low_precision_mode fp8-delayed-te \
  --micro_batch_size 1
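
For triage, below is a minimal single-GPU sketch of the same combination (ThunderFX via torch.compile, TE FP8 delayed scaling, and activation checkpointing) without FSDP. It is a hypothetical reproducer that follows the public thunder.dynamo and Transformer Engine APIs; it has not been verified to trigger this exact KeyError:

```python
# Hypothetical minimal reproducer (single GPU, no FSDP); not verified to hit the same KeyError.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from thunder.dynamo import ThunderCompiler

class Block(nn.Module):
    def __init__(self, hidden: int = 4096):
        super().__init__()
        self.fc1 = te.Linear(hidden, hidden)
        self.fc2 = te.Linear(hidden, hidden)

    def forward(self, x):
        # Recompute fc1 during backward, mimicking --checkpoint_activations True
        h = checkpoint(self.fc1, x, use_reentrant=False)
        return self.fc2(h)

model = Block().cuda()
backend = ThunderCompiler()                        # ThunderFX: dynamo frontend + thunder backend
compiled = torch.compile(model, backend=backend)   # corresponds to --compile dynamo_thunder

recipe = DelayedScaling(fp8_format=Format.HYBRID)  # delayed-scaling FP8, as in fp8-delayed-te
x = torch.randn(16, 4096, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = compiled(x)
out.sum().backward()  # backward triggers the checkpointed recompute of fc1
```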

Environment

system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.98.001
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.8
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gita9b4989
libraries.pip.torchao 0.6.1
libraries.pip.torchmetrics 1.5.1
libraries.pip.torchvision 0.19.0a0+d23a6e1

Labels

TransformerEngine, mixology (issues that the mixology team has surfaced), thunderfx (for things that could be applicable to the dynamo+thunder frontend)
