
Gradient Checkpoint makes FP8 Training Slow #2445

Open
@Yzx835

Description

Per-iteration wall-clock time in seconds:

| Iter | baseline | torch compile | fp8 | torch compile + checkpoint | fp8 + checkpoint |
|------|----------|---------------|-----|----------------------------|------------------|
| 0 | 8.8059 | 15.4885 | 16.1189 | 1.8728 | 1.8823 |
| 1 | 0.8304 | 1.2415 | 0.7548 | 1.2136 | 1.5298 |
| 2 | 3.0760 | 0.8356 | 0.3407 | 1.0668 | 1.2351 |
| 3 | 0.8182 | 0.7047 | 0.3387 | 0.7833 | 1.3250 |
| 4 | 3.3012 | 2.1133 | 0.3383 | 0.7250 | 1.1671 |
| 5 | 0.7889 | 1.1844 | 0.3384 | 0.7254 | 1.1676 |
| 6 | 2.9625 | 0.8600 | 0.3391 | 0.7260 | 1.1666 |
| 7 | 0.7782 | 0.7020 | 0.3384 | 0.7251 | 1.1650 |
| 8 | 3.2161 | 1.9663 | 0.3387 | 0.7258 | 1.1652 |
| 9 | 0.7957 | 1.0336 | 0.6325 | 0.7241 | 1.1706 |

Note: the fp8 columns are measured with torch.compile enabled.

This is my environment:

- CUDA 12.4
- torch 2.6.0
- torchao 0.11.0

This is my test code:

```python

import copy
import os

import pytest
import torch
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import random
import time

import torch.distributed as dist

# from torch._dynamo.eval_frame import is_compiled_module
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import Float8ColwiseParallel, Float8RowwiseParallel
from torchao.testing.float8.dtensor_utils import ToyModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def setup_distributed():
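    # WORLD_SIZE is expected to be set by the launcher (e.g. torchrun);
    # it defaults to -1 if the script is run without one.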
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size, ),
        mesh_dim_names=("fsdp", ),
    )
    # seed must be the same in all processes
    torch.manual_seed(1)
    return device_mesh


def get_qwen_model(model_name="Qwen/Qwen3-8B", dtype=torch.bfloat16):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    model.config.use_cache = False
    return model, tokenizer


def create_fake_inputs(tokenizer, batch_size=1, seq_len=512, device="cuda"):
    vocab_size = tokenizer.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device=device)
    labels = input_ids.clone()
    return {"input_ids": input_ids, "labels": labels}


def _test_fp8_backend(mesh: DeviceMesh, use_fp8=True, compile=True, enable_gradient_checkpointing=False):
    device = mesh.device_type

    model_name = "Qwen/Qwen3-8B"
    print("Use FP8:", use_fp8)
    model, tokenizer = get_qwen_model(model_name=model_name)
    # model = model.to(device)

    print(model.device)

    if use_fp8:
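        # Convert nn.Linear modules (except lm_head) to torchao Float8Linear.
        # Per the torchao float8 docs: enable_fsdp_float8_all_gather casts
        # weights to fp8 before the FSDP all-gather, and
        # force_recompute_fp8_weight_in_bwd recomputes the fp8 weight cast in
        # the backward pass instead of saving it from forward.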
        fp8_config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=True,
            force_recompute_fp8_weight_in_bwd=True,
            pad_inner_dim=True,
            round_scales_to_power_of_2=True,
        )
        convert_to_float8_training(
            model,
            config=fp8_config,
            module_filter_fn=lambda mod, fqn: fqn != "lm_head",
        )

    print(model.device)

    if enable_gradient_checkpointing:

        gradient_checkpointing_kwargs = {"use_reentrant": True}
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    # apply FSDP
    fsdp_config = {
        "mesh": mesh,
        "offload_policy": False,
        "reshard_after_forward": True,
    }

    fully_shard(model, **fsdp_config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    if use_fp8:
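        # After each optimizer step, precompute the fp8 scales for all float8
        # weights in one batched call so they are ready for the next fp8
        # all-gather (see torchao's float8 FSDP docs).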
        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
        optimizer.register_step_post_hook(lambda *args, **kwargs: precompute_float8_dynamic_scale_for_fsdp(model))

    print(type(model))

    if compile:
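        # torch.compile wraps the model in an OptimizedModule, which is why
        # type(model) changes between the two prints around this block.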
        model = torch.compile(model)

    print(type(model))
    print(model.device)

    model.train()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        for i in range(10):
            model.train()
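            # Note: timings below are host-side wall clock; CUDA work is not
            # explicitly synchronized before reading time.time().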
            start_time = time.time()
            optimizer.zero_grad()
            inputs = create_fake_inputs(tokenizer, batch_size=1, seq_len=4096, device=device)

            loss = model(**inputs).loss
            loss.backward()
            optimizer.step()

            end_time = time.time()

            print(f"Iter {i} | Loss: {loss.item():.4f}, Time: {end_time - start_time:.4f}s")


def _test_baseline(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=False, enable_gradient_checkpointing=False)


def _test_torchcompile(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=True, enable_gradient_checkpointing=False)


def _test_fp8(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=True, compile=True, enable_gradient_checkpointing=False)


def _test_torchcompile_checkpoint(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=True, enable_gradient_checkpointing=True)


def _test_fp8_checkpoint(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=True, compile=True, enable_gradient_checkpointing=True)


if __name__ == "__main__":
    # float8 only works on CUDA H100 so we only test cuda and we follow
    # other test files to not use TestCase but instead just add the test
    # cases in the main func.
    device_mesh = setup_distributed()

    tests = [
        _test_baseline,
        _test_torchcompile,
        _test_fp8,
        _test_torchcompile_checkpoint,
        _test_fp8_checkpoint,
    ]

    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e

    # torch.distributed.destroy_process_group()
```
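As a side note, the times in the table are host-side `time.time()` deltas taken before the `loss.item()` call, so pending GPU work may not be fully drained at the measurement point. Below is a minimal sketch of device-accurate per-step timing using CUDA events; the `timed_step` helper and its `step_fn` argument are hypothetical and not part of the repro above:

```python
import torch


def timed_step(step_fn):
    """Run one training step and return (result, elapsed_seconds).

    step_fn is any callable that performs forward + backward + optimizer.step();
    CUDA events measure time on the device stream rather than the host clock.
    """
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    start_evt.record()
    result = step_fn()
    end_evt.record()

    # elapsed_time() requires both events to have completed on the GPU.
    torch.cuda.synchronize()
    return result, start_evt.elapsed_time(end_evt) / 1000.0  # ms -> s
```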
