This repository was archived by the owner on Aug 7, 2024. It is now read-only.

test FSDP with fp8 linear toy model #11

Merged 1 commit on Aug 3, 2023
float8_playground/float8_utils.py (12 changes: 10 additions & 2 deletions)
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -11,15 +12,22 @@
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
+@torch.no_grad()
 def amax_to_scale(amax, dtype):
-    amax = amax.detach()
     if dtype == torch.float8_e4m3fn:
         return E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
         return E5M2_MAX_POS / torch.clamp(amax, min=EPS)
 
+@torch.no_grad()
 def tensor_to_scale(x, dtype):
-    amax = torch.max(torch.abs(x.detach()))
+    amax = torch.max(torch.abs(x))
+    amax_copy = amax.detach().clone()
+    # Hack: inline the distributed logic, just for testing numerics
+    # with FSDP.
+    # TODO(future): better composability with distributed
+    if dist.is_initialized():
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
     return amax_to_scale(amax, dtype)
 
 def compute_error(x, y):
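For intuition, here is a minimal single-process sketch of the scale computation in this hunk. The tensor values are made up, and E4M3_MAX_POS is assumed to be 448.0 (the largest magnitude representable in float8_e4m3fn), since the constant's definition sits above the visible diff:

import torch

# assumed to match the repo's constants defined above this hunk
E4M3_MAX_POS = 448.0
EPS = 1e-12

x = torch.tensor([0.5, -3.0, 2.25])  # made-up activation values
amax = torch.max(torch.abs(x))       # 3.0
scale = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
print(scale)  # tensor(149.3333): multiplying x by scale fills the e4m3 range

# under FSDP, each rank sees only its shard of x, so the local amax must
# be replaced by the global max before computing the scale; that is what
# the dist.all_reduce(amax, op=dist.ReduceOp.MAX) branch above does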
tests/test_everything.sh (10 changes: 10 additions & 0 deletions)
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+python tests/test.py
+python tests/test_sam.py
+./tests/test_fsdp.sh
+
+echo "all tests successful"
tests/test_fsdp.py (54 changes: 45 additions & 9 deletions)
@@ -22,6 +22,11 @@
     StateDictType,
 )
 
+# set up float8 path
+import context
+
+from float8_linear import swap_linear_with_float8_linear
+
 torch.manual_seed(0)
 
 # assumes user is running the script from /data/users/{user}/float8_playground
@@ -47,19 +52,26 @@ def setup(rank, world_size):
 def cleanup():
     dist.destroy_process_group()
 
-def get_model(K, N):
-    return nn.Sequential(
+def get_model(K, N, is_fp8):
+    m = nn.Sequential(
         nn.Linear(K, N),
+        # torch.float8_e4m3fn is not serializeable yet, for now
+        # force model output to be float so we can serialize it
+        # TODO revert this once serialization works
         nn.ReLU(),
     )
+    if is_fp8:
+        swap_linear_with_float8_linear(m)
+    return m
 
 # taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
 # and modified
 def fsdp_main(rank, world_size, args):
     setup(rank, world_size)
 
     torch.cuda.set_device(rank)
 
-    model = get_model(K, N).to(rank)
+    is_fp8, = args
+    model = get_model(K, N, is_fp8=is_fp8).to(rank)
     model.load_state_dict(torch.load(sd_in_fname))
     model = FSDP(model)
     # Note: we need to multiply by world_size here to match single GPU
@@ -100,18 +112,18 @@ def fsdp_main(rank, world_size, args):
 
     cleanup()
 
-def run(mode):
+def run(mode: str, is_fp8: bool):
 
     if mode == 'generate':
         # generate reference input
         ref_input = torch.randn(B, M, K)
-        model = get_model(K, N)
+        model = get_model(K, N, is_fp8=is_fp8)
         torch.save(ref_input, input_fname)
         torch.save(model.state_dict(), sd_in_fname)
 
     elif mode == 'single_gpu':
         ref_input = torch.load(input_fname)
-        model = get_model(K, N)
+        model = get_model(K, N, is_fp8=is_fp8)
         model.load_state_dict(torch.load(sd_in_fname))
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
         optimizer.zero_grad()
@@ -124,7 +136,7 @@ def run(mode):
 
     elif mode == 'fsdp':
         WORLD_SIZE = torch.cuda.device_count()
-        args = ()
+        args = (is_fp8,)
         mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
 
     elif mode == 'analyze':
@@ -137,8 +149,32 @@ def run(mode):
         sd_out_single_gpu = torch.load(sd_out_single_gpu_fname)
         sd_out_fsdp = torch.load(sd_out_fsdp_fname)
         for k, v1 in sd_out_single_gpu.items():
-
             v2 = sd_out_fsdp[k]
-            torch.testing.assert_close(v1, v2)
+            if is_fp8:
+                # Note: for fp8 single-node vs FSDP, we are not expected
+                # to match the scale of the weight gradient. Because of this,
+                # we also cannot match the weight gradient, and therefore
+                # after the first optimizer update we won't be able to match
+                # anything.
+                #
+                # Reasoning is the order of operations of calculating dL/dW:
+                # a. single node:
+                #    1. calculate dL_dW and s_dL_dW
+                #    2. you're done
+                # b. FSDP:
+                #    1. calculate dL_dW and s_dL_dW of each slice
+                #    2. reduce using summation
+                #
+                # a and b cannot always match because calculating the scale
+                # involves taking max(dL_dW), FSDP reduces the gradients, and
+                # max(abs(a), abs(b)) != max(abs(a + b))
+                if ('weight' in k) or ('dL_dW' in k):
+                    pass
+                else:
+                    torch.testing.assert_close(v1, v2)
+            else:
+                torch.testing.assert_close(v1, v2)
         print('state dict testing single_gpu vs FSDP success')


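The inequality in the comment above is easy to verify numerically. A small self-contained sketch, with gradient values invented for illustration, of why the two orders of operations disagree:

import torch

# made-up contributions of two ranks to the same weight gradient
grad_rank0 = torch.tensor([3.0, -1.0])
grad_rank1 = torch.tensor([-3.0, 2.0])

# single node: the full gradient is materialized, then amax is taken
amax_single = torch.max(torch.abs(grad_rank0 + grad_rank1))  # |[0., 1.]| -> 1.0

# FSDP-style ordering: per-slice amax values are combined before summation
amax_fsdp = torch.maximum(torch.max(torch.abs(grad_rank0)),
                          torch.max(torch.abs(grad_rank1)))  # max(3, 3) -> 3.0

# max(abs(a), abs(b)) != max(abs(a + b)), so the derived scales (and
# therefore the quantized weight gradients) cannot match in general
assert amax_single != amax_fsdp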
tests/test_fsdp.sh (35 changes: 25 additions & 10 deletions)
@@ -1,15 +1,30 @@
 #!/bin/bash
 
-# generate the test data
-python tests/test_fsdp.py --mode generate
+# terminate script on first error
+set -e
 
-# generate single GPU model output and updated state dict
-python tests/test_fsdp.py --mode single_gpu
+launch() {
+  echo "launching IS_FP8 $IS_FP8"
 
-# generate FSDP model output and updated state dict
-# the NCCL_DEBUG setting is to avoid log spew
-# the CUDA_VISIBLE_DEVICES setting is for easy debugging
-NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python tests/test_fsdp.py --mode fsdp
+  # generate the test data
+  python tests/test_fsdp.py --mode generate --is_fp8 $IS_FP8
 
-# compare the outputs and state dicts and verify equivalence
-python tests/test_fsdp.py --mode analyze
+  # generate single GPU model output and updated state dict
+  python tests/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8
+
+  # generate FSDP model output and updated state dict
+  # the NCCL_DEBUG setting is to avoid log spew
+  # the CUDA_VISIBLE_DEVICES setting is for easy debugging
+  NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python tests/test_fsdp.py \
+    --mode fsdp --is_fp8 $IS_FP8
+
+  # compare the outputs and state dicts and verify equivalence
+  python tests/test_fsdp.py --mode analyze --is_fp8 $IS_FP8
+
+  echo "done"
+}
+
+for IS_FP8 in False True
+do
+  launch
+done
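Note that the new --is_fp8 flag receives the literal strings False and True. The argument-parsing side of tests/test_fsdp.py is outside the visible hunks; below is a minimal sketch of a parser compatible with these invocations, assuming argparse (the str_to_bool helper is hypothetical, and the actual repo may parse the flag differently):

import argparse

def str_to_bool(value: str) -> bool:
    # argparse's plain type=bool would treat any non-empty string,
    # including 'False', as truthy, so parse the literals explicitly
    if value == 'True':
        return True
    if value == 'False':
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', required=True,
                        choices=['generate', 'single_gpu', 'fsdp', 'analyze'])
    parser.add_argument('--is_fp8', type=str_to_bool, default=False)
    args = parser.parse_args()
    # run(args.mode, args.is_fp8)  # run() as shown in the diff above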