This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4fa2654

enable autocast + compile + FSDP + Float8Linear
Summary:

This adds a couple of config options to unbreak autocast + compile + FSDP + Float8Linear. To enable these options, the user needs to do:

```
config.enable_amax_init = False
config.enable_pre_and_post_forward = False
```

The `enable_amax_init` config adds the option to disable amax initialization. The reasons this is currently broken are:

1. FSDP is not full-graph friendly (regardless of compile).
2. The amax init function has a graph break in distributed code because it uses in-place distributed collectives. I did try to use functional collectives, but that ran into numerical issues with compile, so for now we just work around it.
3. Graph breaks in Float8Linear code are not supported because of the issue documented in #166.
4. So, as a workaround for all of the above, we just skip amax init for now. We do know from NVIDIA that this path is not needed for model convergence, and TE does not support it at all. It was nice for testing but is not necessary for training jobs.

The second config option disables pre-forward and post-forward. I don't have a repro in a unit test for now, but this does unbreak LLaMa 7B on 8 GPUs with FSDP + compile. Specifically, the thing that is broken in pre-forward/post-forward is assignment to module attributes. My hunch is that this graph breaks when autocast + FSDP are on, and graph breaks are not supported due to (3) above.

Test Plan:

```
// unit / integration tests
with-proxy test/test_everything.sh

// run the LLaMa 7b trainer on 8 GPUs with autocast + compile + FSDP + Float8Linear, no compile errors
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 31fba04 commit 4fa2654

File tree

5 files changed, +160 -7 lines changed


float8_experimental/config.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -16,3 +16,18 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, on the first iteration of Float8Linear the amaxes will be
+# initialized with the incoming data. As of 2023-12-30, this doesn't work
+# with autocast + torch.compile + FSDP. Enabling this option is nice for
+# testing, but this is not necessary for real training jobs.
+enable_amax_init = True
+
+# If True, pre-forward and post-forward functions are run. As of 2023-12-30,
+# this doesn't work with autocast + torch.compile + FSDP. Enabling this
+# option is useful for safety, but not strictly necessary.
+enable_pre_and_post_forward = True
```
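Taken together, the two flags above are meant to be set to False before model construction when training with autocast + torch.compile + FSDP. The following is a minimal sketch of such a training script, assembled from the commit message and the new test/test_fsdp_compile.py below; the model shape, learning rate, and iteration count are illustrative placeholders, and it assumes CUDA is available and the default process group is already initialized (e.g. launched via torchrun).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# work around the current autocast + compile + FSDP limitations
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

# toy model; shapes are illustrative only
model = nn.Sequential(nn.Linear(32, 32, dtype=torch.bfloat16), nn.ReLU()).cuda()
swap_linear_with_float8_linear(model, Float8Linear, emulate=False)

# assumes torch.distributed is already initialized (e.g. via torchrun);
# use_orig_params=True is needed to compile FSDP-wrapped modules
model = FSDP(model, use_orig_params=True)
model = torch.compile(model)
sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for _ in range(10):
    optimizer.zero_grad()
    with torch.autocast("cuda"):
        y = model(torch.randn(8, 32, device="cuda", dtype=torch.bfloat16))
    y.sum().backward()
    sync_float8_func(model)  # sync amaxes/scales across ranks each iteration
    optimizer.step()
```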

float8_experimental/float8_linear.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -156,14 +156,14 @@ def __init__(self, *args, **kwargs):
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
         # TODO(future PR): add serialization for this flag
-        self.is_amax_initialized = False
+        self.is_amax_initialized = not config.enable_amax_init
 
         # Syncing of amaxes and scales happens outside of this function. This
         # flag is here to enforce that the user does not forget to do this.
-        self.amax_and_scale_synced = False
+        self.amax_and_scale_synced = not config.enable_amax_init
 
         # This is needed to properly handle autocast in the amax/scale
-        # update function
+        # update function for torch.float16
         self.last_seen_input_dtype = None
 
         # If true, this enables TP+SP style distributed comms in TP primitives
@@ -177,6 +177,10 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
 
+        # pre_forward and post_forward are currently broken with FSDP
+        # and torch.compile, this option can disable them
+        self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
+
         if config.allocate_float8_weight_cache_buffers:
             # this is a buffer to get `to(dtype)` for free
             # TODO(future): hide this from serialization
@@ -212,7 +216,6 @@ def cast_x_to_float8(
             # if we need CPU support in the future, we can add it
             autocast_dtype = torch.get_autocast_gpu_dtype()
             x = x.to(autocast_dtype)
-            self.bias_dtype = autocast_dtype
 
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -288,6 +291,8 @@ def cast_y_to_float8_in_bw(
         return y
 
     def float8_pre_forward(self, x):
+        if not self.enable_pre_and_post_forward:
+            return
         if (
             self.is_amax_initialized
             and (not self.amax_and_scale_synced)
@@ -299,6 +304,8 @@ def float8_pre_forward(self, x):
         self.last_seen_input_dtype = x.dtype
 
     def float8_post_forward(self):
+        if not self.enable_pre_and_post_forward:
+            return
         # Ensure that calling forward again will fail until the user syncs
         # amaxes and scales
         self.is_amax_initialized = True
@@ -335,7 +342,7 @@ def forward(self, x):
         y = self.cast_y_to_float8_in_bw(y, self.emulate)
 
         if self.bias is not None:
-            y = y + self.bias.to(self.bias_dtype)
+            y = y + self.bias.to(y.dtype)
 
         self.float8_post_forward()
         return y
@@ -356,8 +363,6 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
-        if mod.bias is not None:
-            new_mod.bias_dtype = mod.bias.dtype
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         new_mod.add_weight_tag()
```
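For reference, here is a tiny self-contained toy (not the real Float8Linear; class and flag names are hypothetical) illustrating the two patterns applied in the diff above: capturing a module-level config flag as an instance attribute at construction time and early-returning from pre-forward bookkeeping, and casting the bias to the output's dtype instead of caching a bias_dtype attribute inside forward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# stand-in for float8_experimental.config.enable_pre_and_post_forward
ENABLE_PRE_AND_POST_FORWARD = False


class ToyLinear(nn.Linear):
    """Hypothetical toy module mirroring the patterns in the diff above."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # capture the global flag once, so forward() only reads an attribute
        self.enable_pre_and_post_forward = ENABLE_PRE_AND_POST_FORWARD
        self.last_seen_input_dtype = None

    def pre_forward(self, x):
        if not self.enable_pre_and_post_forward:
            return  # early return: skip bookkeeping entirely when disabled
        # this kind of attribute mutation in forward is what can graph-break
        # under autocast + torch.compile + FSDP
        self.last_seen_input_dtype = x.dtype

    def forward(self, x):
        self.pre_forward(x)
        y = F.linear(x, self.weight)
        if self.bias is not None:
            # derive the cast dtype from y instead of a cached self.bias_dtype
            y = y + self.bias.to(y.dtype)
        return y


m = ToyLinear(32, 32)
print(m(torch.randn(8, 32)).shape)  # torch.Size([8, 32])
```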

test/test_everything.sh

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@ pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
 ./test/test_fsdp.sh
+./test/test_fsdp_compile.sh
 ./test/test_tp.sh
 
 echo "all tests successful"
```

test/test_fsdp_compile.py

Lines changed: 126 additions & 0 deletions
```diff
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Test autocast + torch.compile + FSDP + Float8Linear
+"""
+
+import os
+import warnings
+
+import fire
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear_utils import (
+    swap_linear_with_float8_linear,
+    sync_float8_amax_and_scale_history,
+)
+from float8_experimental import config
+from torch.distributed.fsdp import (
+    FullStateDictConfig,
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+)
+
+torch.manual_seed(0)
+
+B, M, K, N = 8, 8, 32, 32
+lr = 0.01
+N_ITER = 1
+
+
+def setup(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+    m = nn.Sequential(
+        nn.Linear(K, N, dtype=base_dtype),
+        nn.ReLU(),
+    )
+    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    return m
+
+
+# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
+# and modified
+def fsdp_main(rank, world_size, args):
+    setup(rank, world_size)
+    torch.cuda.set_device(rank)
+
+    emulate, = args
+
+    # composability of torch.compile + FSDP + autocast + Float8Linear
+    # as of 2023-12-30
+
+    # without any changes to the Float8Linear, we get this error:
+    # https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730
+
+    # if we initialize Float8Linear with is_amax_initialized=True and
+    # amax_and_scale_synced=True, we get
+    # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55
+    # to get around this, we can disable amax init
+    config.enable_amax_init = False
+
+    # finally, if we remove the usage of self.bias_dtype, then
+    # things work e2e. Note that FSDP does not support full-graph compile
+    # regardless of float8.
+
+    model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
+        rank
+    )
+
+    # To compile FSDP, we need use_orig_params to True
+    model = FSDP(model, use_orig_params=True)
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
+    input_local = torch.randn(B, M, K, N, device='cuda')
+    sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+    model = torch.compile(model)
+
+    for iter in range(N_ITER):
+        optimizer.zero_grad()
+        with torch.autocast('cuda'):
+            y_local = model(input_local)
+        y_local.sum().backward()
+        sync_float8_func(model)
+        optimizer.step()
+
+    print('done!')
+    cleanup()
+
+
+def run():
+    emulate = False
+    if not torch.cuda.is_available():
+        warnings.warn("CUDA not available, running in emulation_mode")
+        emulate = True
+    elif torch.cuda.get_device_capability() < (9, 0):
+        warnings.warn(
+            f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+        )
+        emulate = True
+
+    WORLD_SIZE = torch.cuda.device_count()
+    args = (emulate,)
+    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
+
+
+if __name__ == "__main__":
+    fire.Fire(run)
```

test/test_fsdp_compile.sh

Lines changed: 6 additions & 0 deletions
```diff
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py
```
