Iter | baseline (s) | torch compile (s) | fp8 (s) | torch compile + checkpoint (s) | fp8 + checkpoint (s) |
---|---|---|---|---|---|
0 | 8.8059 | 15.4885 | 16.1189 | 1.8728 | 1.8823 |
1 | 0.8304 | 1.2415 | 0.7548 | 1.2136 | 1.5298 |
2 | 3.0760 | 0.8356 | 0.3407 | 1.0668 | 1.2351 |
3 | 0.8182 | 0.7047 | 0.3387 | 0.7833 | 1.3250 |
4 | 3.3012 | 2.1133 | 0.3383 | 0.7250 | 1.1671 |
5 | 0.7889 | 1.1844 | 0.3384 | 0.7254 | 1.1676 |
6 | 2.9625 | 0.8600 | 0.3391 | 0.7260 | 1.1666 |
7 | 0.7782 | 0.7020 | 0.3384 | 0.7251 | 1.1650 |
8 | 3.2161 | 1.9663 | 0.3387 | 0.7258 | 1.1652 |
9 | 0.7957 | 1.0336 | 0.6325 | 0.7241 | 1.1706 |
Note: the fp8 columns also use torch.compile.
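
For easier comparison, here is a small helper (not part of the test itself) that averages the steady-state iterations 3-9 straight from the table above; it puts fp8 + checkpoint at roughly 1.19 s/iter versus roughly 0.73 s/iter for torch compile + checkpoint:

```python
# Per-iteration times (seconds) for iterations 3-9, copied from the table above.
steady_state = {
    "baseline": [0.8182, 3.3012, 0.7889, 2.9625, 0.7782, 3.2161, 0.7957],
    "torch compile": [0.7047, 2.1133, 1.1844, 0.8600, 0.7020, 1.9663, 1.0336],
    "fp8 (with compile)": [0.3387, 0.3383, 0.3384, 0.3391, 0.3384, 0.3387, 0.6325],
    "torch compile + checkpoint": [0.7833, 0.7250, 0.7254, 0.7260, 0.7251, 0.7258, 0.7241],
    "fp8 + checkpoint": [1.3250, 1.1671, 1.1676, 1.1666, 1.1650, 1.1652, 1.1706],
}

for name, times in steady_state.items():
    print(f"{name}: mean {sum(times) / len(times):.4f} s/iter")
```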
This is my environment:
- CUDA 12.4
- torch 2.6.0
- torchao 0.11.0
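
In case it helps with reproducing, the versions above can be printed like this (just a convenience check, not part of the test script):

```python
import torch
from importlib.metadata import version

print("CUDA (torch build):", torch.version.cuda)  # 12.4 here
print("torch:", torch.__version__)                # 2.6.0 here
print("torchao:", version("torchao"))             # 0.11.0 here
```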
This is my test code:
```python
import copy
import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import random
import time

import torch.distributed as dist
# from torch._dynamo.eval_frame import is_compiled_module
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import Float8ColwiseParallel, Float8RowwiseParallel
from torchao.testing.float8.dtensor_utils import ToyModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))
    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size,),
        mesh_dim_names=("fsdp",),
    )
    # seed must be the same in all processes
    torch.manual_seed(1)
    return device_mesh


def get_qwen_model(model_name="Qwen/Qwen3-8B", dtype=torch.bfloat16):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    model.config.use_cache = False
    return model, tokenizer


def create_fake_inputs(tokenizer, batch_size=1, seq_len=512, device="cuda"):
    vocab_size = tokenizer.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device=device)
    labels = input_ids.clone()
    return {"input_ids": input_ids, "labels": labels}


def _test_fp8_backend(mesh: DeviceMesh, use_fp8=True, compile=True, enable_gradient_checkpointing=False):
    device = mesh.device_type
    model_name = "Qwen/Qwen3-8B"
    print("Use FP8:", use_fp8)
    model, tokenizer = get_qwen_model(model_name=model_name)
    # model = model.to(device)
    print(model.device)
    if use_fp8:
        fp8_config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=True,
            force_recompute_fp8_weight_in_bwd=True,
            pad_inner_dim=True,
            round_scales_to_power_of_2=True,
        )
        # swap nn.Linear modules (except lm_head) for float8 training linears
        convert_to_float8_training(
            model,
            config=fp8_config,
            module_filter_fn=lambda mod, fqn: fqn != "lm_head",
        )
    print(model.device)
    if enable_gradient_checkpointing:
        gradient_checkpointing_kwargs = {"use_reentrant": True}
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    # apply FSDP2 (fully_shard) over the whole model
    fsdp_config = {
        "mesh": mesh,
        "offload_policy": False,
        "reshard_after_forward": True,
    }
    fully_shard(model, **fsdp_config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    if use_fp8:
        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        # precompute fp8 scales for the FSDP all-gathered weights after every optimizer step
        optimizer.register_step_post_hook(lambda *args, **kwargs: precompute_float8_dynamic_scale_for_fsdp(model))
    print(type(model))
    if compile:
        model = torch.compile(model)
    print(type(model))
    print(model.device)
    model.train()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # 10 timed iterations on random token ids; reported times are wall-clock per iteration
        for i in range(10):
            model.train()
            start_time = time.time()
            optimizer.zero_grad()
            inputs = create_fake_inputs(tokenizer, batch_size=1, seq_len=4096, device=device)
            loss = model(**inputs).loss
            loss.backward()
            optimizer.step()
            end_time = time.time()
            print(f"Iter {i} | Loss: {loss.item():.4f}, Time: {end_time - start_time:.4f}s")


def _test_baseline(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=False, enable_gradient_checkpointing=False)


def _test_torchcompile(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=True, enable_gradient_checkpointing=False)


def _test_fp8(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=True, compile=True, enable_gradient_checkpointing=False)


def _test_torchcompile_checkpoint(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=False, compile=True, enable_gradient_checkpointing=True)


def _test_fp8_checkpoint(mesh: DeviceMesh):
    _test_fp8_backend(mesh, use_fp8=True, compile=True, enable_gradient_checkpointing=True)


if __name__ == "__main__":
    # float8 only works on CUDA H100 so we only test cuda and we follow
    # other test files to not use TestCase but instead just add the test
    # cases in the main func.
    device_mesh = setup_distributed()
    tests = [
        _test_baseline,
        _test_torchcompile,
        _test_fp8,
        _test_torchcompile_checkpoint,
        _test_fp8_checkpoint,
    ]
    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e
    # torch.distributed.destroy_process_group()
```