4
4
# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
5
5
# - DeepSpeed (ZeRO-Offload):
6
6
# sudo apt install libopenmpi-dev
7
- # LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
7
+ # LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
8
8
# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
9
9
#
10
10
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
@@ -98,6 +98,7 @@ def get_parser():
98
98
parser .add_argument ("--run_name" , default = "debug" )
99
99
parser .add_argument ("--profile" , action = "store_true" )
100
100
parser .add_argument ("--seed" , type = int )
101
+ parser .add_argument ("--device" , type = str , choices = ["cuda" , "xpu" ], default = "cuda" )
101
102
return parser
102
103
103
104
@@ -128,9 +129,9 @@ def get_dloader(args, training: bool):
128
129
)
129
130
130
131
131
- def get_amp_ctx (amp ):
132
+ def get_amp_ctx (amp , device ):
132
133
dtype = dict (bf16 = torch .bfloat16 , fp16 = torch .float16 , none = None )[amp ]
133
- return torch .autocast ("cuda" , dtype = dtype , enabled = amp != "none" )
134
+ return torch .autocast (device , dtype = dtype , enabled = amp != "none" )
134
135
135
136
136
137
@torch .no_grad ()
@@ -148,8 +149,8 @@ def evaluate_model(model, args):
148
149
if args .channels_last :
149
150
batch ["image" ] = batch ["image" ].to (memory_format = torch .channels_last )
150
151
151
- with get_amp_ctx (args .amp ):
152
- all_preds .append (model (batch ["image" ].cuda ( )).argmax (1 ).cpu ())
152
+ with get_amp_ctx (args .amp , args . device ):
153
+ all_preds .append (model (batch ["image" ].to ( args . device )).argmax (1 ).cpu ())
153
154
154
155
all_labels = torch .cat (all_labels , dim = 0 )
155
156
all_preds = torch .cat (all_preds , dim = 0 )
@@ -192,7 +193,7 @@ def evaluate_model(model, args):
192
193
model .bfloat16 ()
193
194
if args .channels_last :
194
195
model .to (memory_format = torch .channels_last )
195
- model .cuda ( ) # move model to CUDA after optionally convert it to BF16
196
+ model .to ( args . device ) # move model to DEVICE after optionally convert it to BF16
196
197
if args .compile :
197
198
model .compile (fullgraph = True )
198
199
print (f"Model parameters: { sum (p .numel () for p in model .parameters ()):,} " )
@@ -227,9 +228,9 @@ def evaluate_model(model, args):
227
228
optim_cls = OPTIM_MAP [args .optim ]
228
229
229
230
if args .optim_cpu_offload == "ao" :
230
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls )
231
+ optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls , device = args . device )
231
232
elif args .optim_cpu_offload == "ao_offload_grads" :
232
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls , offload_gradients = True )
233
+ optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls , offload_gradients = True , device = args . device )
233
234
234
235
optim = optim_cls (
235
236
model .parameters (),
@@ -239,7 +240,7 @@ def evaluate_model(model, args):
239
240
)
240
241
241
242
lr_schedule = CosineSchedule (args .lr , len (dloader ) * args .n_epochs )
242
- grad_scaler = torch .amp .GradScaler ("cuda" , enabled = args .amp == "fp16" )
243
+ grad_scaler = torch .amp .GradScaler (args . device , enabled = args .amp == "fp16" )
243
244
log_interval = 10
244
245
t0 = time .perf_counter ()
245
246
@@ -248,15 +249,26 @@ def evaluate_model(model, args):
248
249
model .train ()
249
250
pbar = tqdm (dloader , dynamic_ncols = True , desc = f"Epoch { epoch_idx + 1 } /{ args .n_epochs } " )
250
251
251
- with torch .profiler .profile () if args .profile else nullcontext () as prof :
252
+ if args .profile :
253
+ activities = [torch .profiler .ProfilerActivity .CPU ]
254
+ if args .device == "cuda" :
255
+ activities .append (torch .profiler .ProfilerActivity .CUDA )
256
+ elif args .device == "xpu" :
257
+ activities .append (torch .profiler .ProfilerActivity .XPU )
258
+
259
+ prof = torch .profiler .profile (activities = activities )
260
+ else :
261
+ prof = nullcontext ()
262
+
263
+ with prof :
252
264
for batch in pbar :
253
265
if args .full_bf16 :
254
266
batch ["image" ] = batch ["image" ].bfloat16 ()
255
267
if args .channels_last :
256
268
batch ["image" ] = batch ["image" ].to (memory_format = torch .channels_last )
257
269
258
- with get_amp_ctx (args .amp ):
259
- loss = F .cross_entropy (model (batch ["image" ].cuda ( )), batch ["label" ].cuda ( ))
270
+ with get_amp_ctx (args .amp , args . device ):
271
+ loss = F .cross_entropy (model (batch ["image" ].to ( args . device )), batch ["label" ].to ( args . device ))
260
272
261
273
if args .optim_cpu_offload == "deepspeed" :
262
274
model .backward (loss )
@@ -299,6 +311,9 @@ def evaluate_model(model, args):
299
311
print (f"Epoch { epoch_idx + 1 } /{ args .n_epochs } : val_acc={ val_acc .item () * 100 :.2f} " )
300
312
logger .log (dict (val_acc = val_acc ), step = step )
301
313
302
- peak_mem = torch .cuda .max_memory_allocated () / 1e9
314
+ if args .device == "cuda" :
315
+ peak_mem = torch .cuda .max_memory_allocated () / 1e9
316
+ elif args .device == "xpu" :
317
+ peak_mem = torch .xpu .max_memory_allocated () / 1e9
303
318
print (f"Max memory used: { peak_mem :.02f} GB" )
304
319
logger .log (dict (max_memory_allocated = peak_mem ))
0 commit comments