This repository was archived by the owner on Aug 7, 2024. It is now read-only.

enable autocast + compile + FSDP + Float8Linear #172

Closed
wants to merge 1 commit into from
15 changes: 15 additions & 0 deletions float8_experimental/config.py
@@ -16,3 +16,18 @@
# according to their microbatching/pipeline parallel setup.
# Note: this is currently a global flag for simplicity and dynamo performance.
weight_cache_enabled = False

#
# Other
#

# If True, on the first iteration of Float8Linear the amaxes will be
# initialized with the incoming data. As of 2023-12-30, this doesn't work
# with autocast + torch.compile + FSDP. Enabling this option is nice for
# testing, but this is not necessary for real training jobs.
enable_amax_init = True

# If True, pre-forward and post-forward functions are run. As of 2023-12-30,
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True
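
A minimal usage sketch of the two new flags (not part of the diff), assuming the same import paths and swap helper that test/test_fsdp_compile.py below uses; the toy nn.Sequential model and emulate=True are placeholders for illustration:

import torch
import torch.nn as nn
from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# flip the flags before swapping in Float8Linear, mirroring the new test below
config.enable_amax_init = False               # skip data-dependent amax init on the first iteration
# config.enable_pre_and_post_forward = False  # optional; disables the pre/post-forward checks

model = nn.Sequential(nn.Linear(32, 32, dtype=torch.bfloat16), nn.ReLU())
swap_linear_with_float8_linear(model, Float8Linear, emulate=True)  # emulate=True avoids needing H100 hardware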
19 changes: 12 additions & 7 deletions float8_experimental/float8_linear.py
@@ -156,14 +156,14 @@ def __init__(self, *args, **kwargs):
# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
# TODO(future PR): add serialization for this flag
self.is_amax_initialized = False
self.is_amax_initialized = not config.enable_amax_init

# Syncing of amaxes and scales happens outside of this function. This
# flag is here to enforce that the user does not forget to do this.
self.amax_and_scale_synced = False
self.amax_and_scale_synced = not config.enable_amax_init

# This is needed to properly handle autocast in the amax/scale
# update function
# update function for torch.float16
self.last_seen_input_dtype = None

# If true, this enables TP+SP style distributed comms in TP primitives
@@ -177,6 +177,10 @@ def __init__(self, *args, **kwargs):
# will access the scale when it has ensured that it is on GPU.
self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

# pre_forward and post_forward are currently broken with FSDP
# and torch.compile, this option can disable them
self.enable_pre_and_post_forward = config.enable_pre_and_post_forward

if config.allocate_float8_weight_cache_buffers:
# this is a buffer to get `to(dtype)` for free
# TODO(future): hide this from serialization
@@ -212,7 +216,6 @@ def cast_x_to_float8(
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)
self.bias_dtype = autocast_dtype

scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
@@ -288,6 +291,8 @@ def cast_y_to_float8_in_bw(
return y

def float8_pre_forward(self, x):
if not self.enable_pre_and_post_forward:
return
if (
self.is_amax_initialized
and (not self.amax_and_scale_synced)
@@ -299,6 +304,8 @@ def float8_pre_forward(self, x):
self.last_seen_input_dtype = x.dtype

def float8_post_forward(self):
if not self.enable_pre_and_post_forward:
return
# Ensure that calling forward again will fail until the user syncs
# amaxes and scales
self.is_amax_initialized = True
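
For context, a condensed sketch (not part of the diff) of the training-loop contract that these pre/post-forward checks enforce, modeled on the loop in test/test_fsdp_compile.py below; the toy model, optimizer, and input are placeholders, and a CUDA device is assumed:

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU()).cuda()
swap_linear_with_float8_linear(model, Float8Linear, emulate=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(8, 32, device="cuda")

for _ in range(2):
    optimizer.zero_grad()
    with torch.autocast("cuda"):
        y = model(x)                            # float8_pre_forward checks the sync flag here
    y.sum().backward()
    sync_float8_amax_and_scale_history(model)   # satisfies the check before the next forward
    optimizer.step()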
@@ -335,7 +342,7 @@ def forward(self, x):
y = self.cast_y_to_float8_in_bw(y, self.emulate)

if self.bias is not None:
y = y + self.bias.to(self.bias_dtype)
y = y + self.bias.to(y.dtype)
Contributor Author:
this change is not called out in the configs, but it just removes the need to store a module attribute, which also makes this code more full-graph-compile friendly. Modifying module attributes seems to cause a graph break if autocast and FSDP are both enabled.


self.float8_post_forward()
return y
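
A condensed restatement of the before/after pattern described in the comment above (for readability only; the graph-break behavior under autocast + FSDP is the author's observation, not a documented torch.compile guarantee):

# before (removed in this PR): the forward path wrote a module attribute
#     self.bias_dtype = autocast_dtype        # set inside cast_x_to_float8
#     y = y + self.bias.to(self.bias_dtype)   # read inside forward
#
# after (this PR): derive the dtype from the activation itself, keeping the
# forward path free of module-state mutation and friendlier to full-graph compile
#     y = y + self.bias.to(y.dtype)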
@@ -356,8 +363,6 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
if mod.bias is not None:
new_mod.bias_dtype = mod.bias.dtype
# I think its okay to send all params and buffers to device
new_mod.to(mod.weight.device)
new_mod.add_weight_tag()
1 change: 1 addition & 0 deletions test/test_everything.sh
@@ -7,6 +7,7 @@ pytest test/test_base.py
pytest test/test_sam.py
pytest test/test_compile.py
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_tp.sh

echo "all tests successful"
127 changes: 127 additions & 0 deletions test/test_fsdp_compile.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Test autocast + torch.compile + FSDP + Float8Linear
"""

import os
import warnings

import fire

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)

torch.manual_seed(0)

B, M, K, N = 8, 8, 32, 32
lr = 0.01
N_ITER = 1


def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
m = nn.Sequential(
nn.Linear(K, N, dtype=base_dtype),
nn.ReLU(),
)
swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
return m


# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
# and modified
def fsdp_main(rank, world_size, args):
setup(rank, world_size)
torch.cuda.set_device(rank)

(emulate,) = args

# composability of torch.compile + FSDP + autocast + Float8Linear
# as of 2023-12-30

# without any changes to the Float8Linear, we get this error:
# https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730

# if we initialize Float8Linear with is_amax_initialized=True and
# amax_and_scale_synced=True, we get
# https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55
# to get around this, we can disable amax init
config.enable_amax_init = False

# finally, if we remove the usage of self.bias_dtype, then
# things work e2e. Note that FSDP does not support full-graph compile
# regardless of float8.

model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
rank
)

# To compile FSDP, we need to set use_orig_params to True
model = FSDP(model, use_orig_params=True)

optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
input_local = torch.randn(B, M, K, N, device="cuda")
sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)

model = torch.compile(model)

for _iter in range(N_ITER):
optimizer.zero_grad()
with torch.autocast("cuda"):
y_local = model(input_local)
y_local.sum().backward()
sync_float8_func(model)
optimizer.step()

print("done!")
cleanup()


def run():
emulate = False
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2)
emulate = True
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode",
stacklevel=2,
)
emulate = True

WORLD_SIZE = torch.cuda.device_count()
args = (emulate,)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":
fire.Fire(run)
6 changes: 6 additions & 0 deletions test/test_fsdp_compile.sh
@@ -0,0 +1,6 @@
#!/bin/bash

# terminate script on first error
set -e

NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py