Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c3f5c9a

Browse files
committed
add test to fsdp
1 parent 177173a commit c3f5c9a

File tree

3 files changed

+72
-20
lines changed

3 files changed

+72
-20
lines changed

float8_experimental/float8_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def forward(
189189
emulate=emulate,
190190
)
191191
if recompute_float8_weight:
192-
# This should be set to True when using traditional fsdp to avoid saving
192+
# This should be set to True when using traditional fsdp to avoid
193193
# saving the unsharded weight for backwards
194194
ctx.save_for_backward(
195195
x_fp8, original_weight, weight_scale, weight_amax_buffer

test/test_fsdp.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,22 @@ def cleanup():
6161
dist.destroy_process_group()
6262

6363

64-
def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
64+
def get_model(
65+
K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False
66+
):
6567
m = nn.Sequential(
6668
nn.Linear(K, N, dtype=base_dtype),
6769
nn.ReLU(),
6870
nn.Linear(N, N, dtype=base_dtype),
6971
nn.ReLU(),
7072
)
7173
if is_fp8:
72-
swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
74+
swap_linear_with_float8_linear(
75+
m,
76+
Float8Linear,
77+
emulate=emulate,
78+
recompute_weight_cast=recompute_weight_cast,
79+
)
7380
return m
7481

7582

@@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args):
8188

8289
# TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
8390
# We can investigate and fix it later.
84-
is_fp8, emulate, base_dtype, compile, fullgraph = args
85-
model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
86-
rank
87-
)
91+
is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args
92+
model = get_model(
93+
K,
94+
N,
95+
is_fp8=is_fp8,
96+
emulate=emulate,
97+
base_dtype=base_dtype,
98+
recompute_weight_cast=recompute_weight_cast,
99+
).to(rank)
88100
model.load_state_dict(torch.load(sd_in_fname))
89101
# To compile FSDP, we need use_orig_params to True
90102
model = FSDP(model, use_orig_params=True)
@@ -148,7 +160,13 @@ def forward_backward(model):
148160
cleanup()
149161

150162

151-
def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
163+
def run(
164+
mode: str,
165+
is_fp8: bool,
166+
compile_fsdp: bool = False,
167+
fullgraph: bool = False,
168+
recompute_weight_cast: bool = False,
169+
):
152170
print(f"Mode: {mode}".center(100, "-"))
153171
base_dtype = torch.bfloat16
154172
if not os.path.exists(data_dir):
@@ -169,15 +187,25 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F
169187
# generate reference input
170188
ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
171189
model = get_model(
172-
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
190+
K,
191+
N,
192+
is_fp8=is_fp8,
193+
emulate=emulate,
194+
base_dtype=base_dtype,
195+
recompute_weight_cast=recompute_weight_cast,
173196
).cuda()
174197
torch.save(ref_input, input_fname)
175198
torch.save(model.state_dict(), sd_in_fname)
176199

177200
elif mode == "single_gpu":
178201
ref_input = torch.load(input_fname).to(base_dtype)
179202
model = get_model(
180-
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
203+
K,
204+
N,
205+
is_fp8=is_fp8,
206+
emulate=emulate,
207+
base_dtype=base_dtype,
208+
recompute_weight_cast=recompute_weight_cast,
181209
).cuda()
182210
model.load_state_dict(torch.load(sd_in_fname))
183211
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -199,7 +227,14 @@ def forward_backward():
199227
elif mode == "fsdp":
200228
WORLD_SIZE = torch.cuda.device_count()
201229
# We only compile for fsdp, and compare the numerics with signle-gpu no-compile
202-
args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
230+
args = (
231+
is_fp8,
232+
emulate,
233+
base_dtype,
234+
compile_fsdp,
235+
fullgraph,
236+
recompute_weight_cast,
237+
)
203238
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
204239

205240
elif mode == "analyze":

test/test_fsdp.sh

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
set -e
55

66
launch() {
7-
echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
7+
echo "Launching test with the following configuration:"
8+
echo "IS_FP8: $IS_FP8"
9+
echo "compile_fsdp: $COMPILE"
10+
echo "fullgraph: $FULLGRAPH"
11+
echo "recompute_weight_cast: $RECOMPUTE"
812

913
# generate the test data
10-
python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
14+
python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
1115
echo "Success: ✅"
1216

1317
# generate single GPU model output and updated state dict
14-
python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
18+
python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
1519
echo "Success: ✅"
1620

1721
# generate FSDP model output and updated state dict
@@ -20,19 +24,32 @@ launch() {
2024
# the NCCL_NET setting is to work around transient issues on a
2125
# specific host (`devgpu001.nha2`)
2226
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \
23-
--mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
27+
--mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
2428

2529
# compare the outputs and state dicts and verify equivalence
26-
python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
30+
python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
2731
echo "Success: ✅"
2832

2933
echo "✅ All Tests Passed ✅"
3034
}
3135

32-
# IS_FP8, COMPILE, FULLGRAPH
33-
for i in False,False,False True,False,False True,True,False
36+
# Loop over different combinations of settings
37+
for i in False,False,False,False \
38+
True,False,False,False \
39+
True,True,False,False \
40+
True,False,False,True \
41+
True,True,False,True
3442
do
35-
IFS=","; set -- $i;
36-
IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
43+
# Split the string into variables
44+
IFS=","
45+
set -- $i
46+
47+
# Assign each variable to a more descriptive name
48+
IS_FP8=$1
49+
COMPILE=$2
50+
FULLGRAPH=$3
51+
RECOMPUTE=$4
52+
53+
# Launch the test with the current settings
3754
launch
3855
done

0 commit comments

Comments
 (0)