This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[wip] make Float8Linear amax init more FSDP+compile friendly #171

Closed
wants to merge 1 commit into from
8 changes: 7 additions & 1 deletion float8_experimental/float8_utils.py
@@ -7,6 +7,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._functional_collectives as _functional_collectives
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -60,7 +61,12 @@ def tensor_to_amax(x, distributed_reduction=False):
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if distributed_reduction and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        # TODO(future): support process groups
+        ranks = list(range(dist.get_world_size()))
+        # print('ranks', ranks)
+        # print('old amax', amax)
+        amax = _functional_collectives.all_reduce(amax, "max", group=ranks)
+        # print('new amax', amax)
 
     return amax
 
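For context on the change above: dist.all_reduce mutates its input tensor in place, which is awkward for torch.compile to trace, while the functional collective returns a new tensor. Below is a minimal sketch (not part of the diff) contrasting the two flavors; it assumes a default process group has already been initialized, e.g. under torchrun, and the helper names are illustrative only.

# Sketch only (not in the PR): in-place vs functional all-reduce of a local amax.
# Assumes dist.init_process_group() has already run, e.g. under torchrun.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives

def global_amax_inplace(x: torch.Tensor) -> torch.Tensor:
    # In-place collective: mutates `amax` and returns it, harder for torch.compile to trace.
    amax = torch.max(torch.abs(x))
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax

def global_amax_functional(x: torch.Tensor) -> torch.Tensor:
    # Functional collective: returns a new tensor, mirroring the diff above.
    amax = torch.max(torch.abs(x))
    ranks = list(range(dist.get_world_size()))
    return _functional_collectives.all_reduce(amax, "max", group=ranks)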
22 changes: 15 additions & 7 deletions test/test_fsdp.py
@@ -104,10 +104,13 @@ def fsdp_main(rank, world_size, args):
     ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
 
     sync_float8_func = sync_float8_amax_and_scale_history
-    if compile:
+    if compile and False:
         sync_float8_func = torch.compile(
             sync_float8_amax_and_scale_history, fullgraph=fullgraph
         )
+        model = torch.compile(model)
+
+    print(model)
 
     def forward_backward(model):
         optimizer.zero_grad()
@@ -118,14 +121,19 @@ def forward_backward(model):
         return y_local
 
     for iter in range(N_ITER):
-        # We first run one iteration without compile, as a workaround to compile float8 layer.
-        # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False"
-        # After that, float8 layers go the the branches of "self.is_amax_initialized == True"
-        # TODO: Need to fix compile to run wihtout this workaround.
-        if iter == 1 and compile:
-            model = torch.compile(model, fullgraph=fullgraph)
         y_local = forward_backward(model)
 
+        if compile and False:
+            base = model._orig_mod._fsdp_wrapped_module
+        else:
+            base = model._fsdp_wrapped_module
+        print('0_x', base[0].fp8_amax_history_x)
+        print('2_x', base[2].fp8_amax_history_x)
+        print('0_w', base[0].fp8_amax_history_w)
+        print('2_w', base[2].fp8_amax_history_w)
+        print('0_g', base[0].fp8_amax_history_dL_dY)
+        print('2_g', base[2].fp8_amax_history_dL_dY)
+
     # get global y
     y_global = [
         torch.zeros(*y_local.shape, dtype=base_dtype).to(rank)
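A note on the debug prints added above: torch.compile wraps a module in an OptimizedModule that keeps the original module under _orig_mod, while FSDP keeps its inner module under _fsdp_wrapped_module, which is why the two access paths differ. A hedged sketch of that unwrapping follows; the helper name is made up for illustration and is not part of the PR.

# Sketch only: how the debug prints above reach the Float8Linear amax buffers.
# `_orig_mod` is the attribute used by torch.compile's OptimizedModule wrapper;
# `_fsdp_wrapped_module` is FSDP's handle to its inner module.
def unwrap_for_inspection(model, compiled: bool):
    # hypothetical helper, not in the PR
    if compiled:
        return model._orig_mod._fsdp_wrapped_module
    return model._fsdp_wrapped_module

# Usage mirrors the test: the unwrapped module indexes like a Sequential whose
# layers 0 and 2 are Float8Linear, each carrying fp8_amax_history_* buffers.
# base = unwrap_for_inspection(model, compiled=False)
# print('0_x', base[0].fp8_amax_history_x)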
11 changes: 6 additions & 5 deletions test/test_fsdp.sh
@@ -7,12 +7,12 @@ launch() {
 echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
 
 # generate the test data
-python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-echo "Success: ✅"
+# python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
+# echo "Success: ✅"
 
 # generate single GPU model output and updated state dict
-python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-echo "Success: ✅"
+# python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
+# echo "Success: ✅"
 
 # generate FSDP model output and updated state dict
 # the NCCL_DEBUG setting is to avoid log spew
@@ -30,7 +30,8 @@ launch() {
 }
 
 # IS_FP8, COMPILE, FULLGRAPH
-for i in False,False,False True,False,False True,True,False
+# for i in False,False,False True,False,False True,True,False
+for i in True,True,False
 do
 IFS=","; set -- $i;
 IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3