@@ -61,15 +61,22 @@ def cleanup():
     dist.destroy_process_group()


-def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+def get_model(
+    K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False
+):
     m = nn.Sequential(
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
         nn.Linear(N, N, dtype=base_dtype),
         nn.ReLU(),
     )
     if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+        swap_linear_with_float8_linear(
+            m,
+            Float8Linear,
+            emulate=emulate,
+            recompute_weight_cast=recompute_weight_cast,
+        )
     return m

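For reviewers who want to poke at the reworked helper in isolation, a minimal construction-only smoke check along these lines could be appended to the test file. It relies only on names visible in this diff (`get_model`, `Float8Linear`); the assumption that `swap_linear_with_float8_linear` replaces the `nn.Linear` children of the `Sequential` in place is inferred from how `get_model` uses it, and the check itself is not part of this PR.

```python
# Hypothetical smoke check, not part of this PR: build the test model with and
# without the new flag and confirm the linear layers were swapped for fp8.
import torch


def _smoke_check_get_model():
    for recompute in (False, True):
        m = get_model(
            K=16,
            N=32,
            is_fp8=True,
            emulate=True,
            base_dtype=torch.bfloat16,
            recompute_weight_cast=recompute,
        )
        # Assumes the swap happens in place on the Sequential's children.
        assert isinstance(m[0], Float8Linear)
        assert isinstance(m[2], Float8Linear)
```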
@@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args):

     # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
     # We can investigate and fix it later.
-    is_fp8, emulate, base_dtype, compile, fullgraph = args
-    model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
-        rank
-    )
+    is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args
+    model = get_model(
+        K,
+        N,
+        is_fp8=is_fp8,
+        emulate=emulate,
+        base_dtype=base_dtype,
+        recompute_weight_cast=recompute_weight_cast,
+    ).to(rank)
     model.load_state_dict(torch.load(sd_in_fname))
     # To compile FSDP, we need use_orig_params to True
     model = FSDP(model, use_orig_params=True)
@@ -148,7 +160,13 @@ def forward_backward(model):
     cleanup()


-def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
+def run(
+    mode: str,
+    is_fp8: bool,
+    compile_fsdp: bool = False,
+    fullgraph: bool = False,
+    recompute_weight_cast: bool = False,
+):
     print(f"Mode: {mode}".center(100, "-"))
     base_dtype = torch.bfloat16
     if not os.path.exists(data_dir):
@@ -169,15 +187,25 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F
         # generate reference input
         ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
         model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
+            K,
+            N,
+            is_fp8=is_fp8,
+            emulate=emulate,
+            base_dtype=base_dtype,
+            recompute_weight_cast=recompute_weight_cast,
         ).cuda()
         torch.save(ref_input, input_fname)
         torch.save(model.state_dict(), sd_in_fname)

     elif mode == "single_gpu":
         ref_input = torch.load(input_fname).to(base_dtype)
         model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
+            K,
+            N,
+            is_fp8=is_fp8,
+            emulate=emulate,
+            base_dtype=base_dtype,
+            recompute_weight_cast=recompute_weight_cast,
         ).cuda()
         model.load_state_dict(torch.load(sd_in_fname))
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -199,7 +227,14 @@ def forward_backward():
     elif mode == "fsdp":
         WORLD_SIZE = torch.cuda.device_count()
         # We only compile for fsdp, and compare the numerics with signle-gpu no-compile
-        args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
+        args = (
+            is_fp8,
+            emulate,
+            base_dtype,
+            compile_fsdp,
+            fullgraph,
+            recompute_weight_cast,
+        )
         mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

     elif mode == "analyze":