53
53
54
54
OPTIM_MAP .update (
55
55
AdamW4bitLpmm = partial (lpmm .optim .AdamW , fused = True ),
56
- AdamW4bitRank1Lpmm = partial (lpmm .optim .AdamW , qconfig = argparse .Namespace (scale_type = "rank1" )),
56
+ AdamW4bitRank1Lpmm = partial (
57
+ lpmm .optim .AdamW , qconfig = argparse .Namespace (scale_type = "rank1" )
58
+ ),
57
59
)
58
60
59
61
except ImportError :
@@ -71,8 +73,12 @@ def get_lr(self, step: int) -> float:
71
73
if step < self .warmup_steps :
72
74
return self .lr * step / self .warmup_steps
73
75
if step < self .total_steps :
74
- progress = (step - self .warmup_steps ) / (self .total_steps - self .warmup_steps )
75
- return self .final_lr + 0.5 * (self .lr - self .final_lr ) * (1 + math .cos (progress * math .pi ))
76
+ progress = (step - self .warmup_steps ) / (
77
+ self .total_steps - self .warmup_steps
78
+ )
79
+ return self .final_lr + 0.5 * (self .lr - self .final_lr ) * (
80
+ 1 + math .cos (progress * math .pi )
81
+ )
76
82
return self .final_lr
77
83
78
84
@@ -96,7 +102,9 @@ def get_parser():
96
102
parser .add_argument ("--weight_decay" , type = float , default = 0 )
97
103
parser .add_argument ("--optim_kwargs" , type = json .loads , default = dict ())
98
104
parser .add_argument ("--cosine_lr_scheduler" , action = "store_true" )
99
- parser .add_argument ("--optim_cpu_offload" , choices = ["ao" , "ao_offload_grads" , "deepspeed" ])
105
+ parser .add_argument (
106
+ "--optim_cpu_offload" , choices = ["ao" , "ao_offload_grads" , "deepspeed" ]
107
+ )
100
108
101
109
parser .add_argument ("--project" )
102
110
parser .add_argument ("--run_name" , default = "debug" )
@@ -114,11 +122,15 @@ def get_dloader(args, training: bool):
114
122
transforms .extend ([v2 .Resize (256 ), v2 .CenterCrop (224 )])
115
123
116
124
transforms .append (v2 .ToDtype (torch .float32 , scale = True ))
117
- transforms .append (v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]))
125
+ transforms .append (
126
+ v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
127
+ )
118
128
transforms = v2 .Compose (transforms )
119
129
120
130
# use dataset from HF so download is fast
121
- ds = datasets .load_dataset ("timm/resisc45" , split = "train" if training else "validation" )
131
+ ds = datasets .load_dataset (
132
+ "timm/resisc45" , split = "train" if training else "validation"
133
+ )
122
134
ds = ds .select_columns (["image" , "label" ])
123
135
ds .set_transform (lambda x : dict (image = transforms (x ["image" ]), label = x ["label" ]))
124
136
@@ -168,8 +180,12 @@ def evaluate_model(model, args):
168
180
if args .full_bf16 :
169
181
assert args .amp == "none" , "When --full_bf16 is set, --amp must be none"
170
182
if args .optim_cpu_offload == "deepspeed" :
171
- assert args .amp == "none" , "When using DeepSpeed ZeRO-Offload, --amp must be none"
172
- assert args .optim == "AdamW" , "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
183
+ assert (
184
+ args .amp == "none"
185
+ ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
186
+ assert (
187
+ args .optim == "AdamW"
188
+ ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
173
189
if args .profile :
174
190
args .n_epochs = 1
175
191
if args .seed is not None :
@@ -189,7 +205,9 @@ def evaluate_model(model, args):
189
205
dloader = get_dloader (args , True )
190
206
print (f"Train dataset: { len (dloader .dataset ):,} images" )
191
207
192
- model = timm .create_model (args .model , pretrained = True , num_classes = 45 , ** args .model_kwargs )
208
+ model = timm .create_model (
209
+ args .model , pretrained = True , num_classes = 45 , ** args .model_kwargs
210
+ )
193
211
if args .checkpoint_activations :
194
212
model .set_grad_checkpointing ()
195
213
if args .full_bf16 :
@@ -231,9 +249,15 @@ def evaluate_model(model, args):
231
249
optim_cls = OPTIM_MAP [args .optim ]
232
250
233
251
if args .optim_cpu_offload == "ao" :
234
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls )
252
+ optim_cls = partial (
253
+ low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls
254
+ )
235
255
elif args .optim_cpu_offload == "ao_offload_grads" :
236
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls , offload_gradients = True )
256
+ optim_cls = partial (
257
+ low_bit_optim .CPUOffloadOptimizer ,
258
+ optimizer_class = optim_cls ,
259
+ offload_gradients = True ,
260
+ )
237
261
238
262
optim = optim_cls (
239
263
model .parameters (),
@@ -250,17 +274,23 @@ def evaluate_model(model, args):
250
274
step = 0
251
275
for epoch_idx in range (args .n_epochs ):
252
276
model .train ()
253
- pbar = tqdm (dloader , dynamic_ncols = True , desc = f"Epoch { epoch_idx + 1 } /{ args .n_epochs } " )
277
+ pbar = tqdm (
278
+ dloader , dynamic_ncols = True , desc = f"Epoch { epoch_idx + 1 } /{ args .n_epochs } "
279
+ )
254
280
255
281
with torch .profiler .profile () if args .profile else nullcontext () as prof :
256
282
for batch in pbar :
257
283
if args .full_bf16 :
258
284
batch ["image" ] = batch ["image" ].bfloat16 ()
259
285
if args .channels_last :
260
- batch ["image" ] = batch ["image" ].to (memory_format = torch .channels_last )
286
+ batch ["image" ] = batch ["image" ].to (
287
+ memory_format = torch .channels_last
288
+ )
261
289
262
290
with get_amp_ctx (args .amp , _DEVICE ):
263
- loss = F .cross_entropy (model (batch ["image" ].to (_DEVICE )), batch ["label" ].to (_DEVICE ))
291
+ loss = F .cross_entropy (
292
+ model (batch ["image" ].to (_DEVICE )), batch ["label" ].to (_DEVICE )
293
+ )
264
294
265
295
if args .optim_cpu_offload == "deepspeed" :
266
296
model .backward (loss )
@@ -279,7 +309,9 @@ def evaluate_model(model, args):
279
309
log_dict = dict (loss = loss .item (), lr = optim .param_groups [0 ]["lr" ])
280
310
if step > 0 :
281
311
t1 = time .perf_counter ()
282
- log_dict ["imgs_per_second" ] = args .batch_size * log_interval / (t1 - t0 )
312
+ log_dict ["imgs_per_second" ] = (
313
+ args .batch_size * log_interval / (t1 - t0 )
314
+ )
283
315
t0 = t1
284
316
logger .log (log_dict , step = step )
285
317
@@ -300,7 +332,9 @@ def evaluate_model(model, args):
300
332
301
333
else :
302
334
val_acc = evaluate_model (model , args )
303
- print (f"Epoch { epoch_idx + 1 } /{ args .n_epochs } : val_acc={ val_acc .item () * 100 :.2f} " )
335
+ print (
336
+ f"Epoch { epoch_idx + 1 } /{ args .n_epochs } : val_acc={ val_acc .item () * 100 :.2f} "
337
+ )
304
338
logger .log (dict (val_acc = val_acc ), step = step )
305
339
306
340
peak_mem = getattr (torch , _DEVICE ).max_memory_allocated () / 1e9
0 commit comments