enable autocast + compile + FSDP + Float8Linear #172
test/test_fsdp_compile.py:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Test autocast + torch.compile + FSDP + Float8Linear
"""

import os
import warnings

import fire

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

torch.manual_seed(0)

B, M, K, N = 8, 8, 32, 32
lr = 0.01
N_ITER = 1


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
    m = nn.Sequential(
        nn.Linear(K, N, dtype=base_dtype),
        nn.ReLU(),
    )
    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
    return m


# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
# and modified
def fsdp_main(rank, world_size, args):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    (emulate,) = args

    # composability of torch.compile + FSDP + autocast + Float8Linear
    # as of 2023-12-30

    # without any changes to Float8Linear, we get this error:
    # https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730

    # if we initialize Float8Linear with is_amax_initialized=True and
    # amax_and_scale_synced=True, we get
    # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55
    # to get around this, we can disable amax init
    config.enable_amax_init = False

    # finally, if we remove the usage of self.bias_dtype, then
    # things work e2e. Note that FSDP does not support full-graph compile
    # regardless of float8.

    model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
        rank
    )

    # To compile FSDP, we need to set use_orig_params to True
    model = FSDP(model, use_orig_params=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
    # note: the last input dimension must equal K (here K == N == 32)
    input_local = torch.randn(B, M, K, N, device="cuda")
    sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)

    model = torch.compile(model)

    for _iter in range(N_ITER):
        optimizer.zero_grad()
        with torch.autocast("cuda"):
            y_local = model(input_local)
        y_local.sum().backward()
        sync_float8_func(model)
        optimizer.step()

    print("done!")
    cleanup()


def run():
    emulate = False
    if not torch.cuda.is_available():
        warnings.warn("CUDA not available, running in emulation mode", stacklevel=2)
        emulate = True
    elif torch.cuda.get_device_capability() < (9, 0):
        warnings.warn(
            f"CUDA capability {torch.cuda.get_device_capability()} < (9, 0), running in emulation mode",
            stacklevel=2,
        )
        emulate = True

    WORLD_SIZE = torch.cuda.device_count()
    args = (emulate,)
    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":
    fire.Fire(run)
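
For reference, the float8-specific pieces being composed here, minus FSDP and multiprocessing, reduce to the single-device sketch below. This is a minimal sketch assuming the same float8_experimental APIs used in the test above; emulate=True lets it run on hardware without native float8 support.

# minimal single-device sketch of the float8 pieces exercised above,
# without FSDP; assumes the same float8_experimental APIs as the test
import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

m = nn.Sequential(nn.Linear(32, 32, dtype=torch.bfloat16), nn.ReLU()).cuda()
swap_linear_with_float8_linear(m, Float8Linear, emulate=True)
m = torch.compile(m)
sync_fn = torch.compile(sync_float8_amax_and_scale_history)
opt = torch.optim.SGD(m.parameters(), lr=0.01)

x = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16)
opt.zero_grad()
with torch.autocast("cuda"):
    y = m(x)
y.sum().backward()
sync_fn(m)  # sync amax/scale history after backward, before the optimizer step
opt.step()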
Run script:
#!/bin/bash

# terminate script on first error
set -e

NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py
Review comment:
This change is not called out in the configs, but it just removes the need to store a module attribute, which also makes this code more friendly to full-graph compile. Modifying module attributes seems to graph break if autocast and FSDP are both enabled.
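
A hypothetical before/after sketch of the pattern this comment describes (not the actual Float8Linear code; the class names and bias-dtype logic here are illustrative assumptions):

import torch
import torch.nn as nn

# hypothetical "before": the bias dtype is cached on the module inside
# forward(), which writes module state and can graph break under
# torch.compile when autocast and FSDP are both enabled
class CachesBiasDtype(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(32))
        self.bias_dtype = None

    def forward(self, x):
        if self.bias_dtype is None:
            self.bias_dtype = x.dtype  # module attribute mutation in forward
        return x + self.bias.to(self.bias_dtype)

# hypothetical "after": the dtype is derived locally on every call, so
# forward() never writes a module attribute and compiles more cleanly
class DerivesBiasDtype(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(32))

    def forward(self, x):
        bias_dtype = x.dtype  # local value, no module state written
        return x + self.bias.to(bias_dtype)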