
Commit 755414f

test(cpu-offload): enable benchmark_low_bit_adam for XPU
Signed-off-by: dbyoung18 <yang5.yang@intel.com>
1 parent 042ca1f commit 755414f

File tree: 1 file changed (+28 −13 lines)

benchmarks/benchmark_low_bit_adam.py

Lines changed: 28 additions & 13 deletions
@@ -4,7 +4,7 @@
 # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
 # - DeepSpeed (ZeRO-Offload):
 #   sudo apt install libopenmpi-dev
-#   LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
+#   LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
 #   DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
 #
 # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core

@@ -98,6 +98,7 @@ def get_parser():
     parser.add_argument("--run_name", default="debug")
     parser.add_argument("--profile", action="store_true")
     parser.add_argument("--seed", type=int)
+    parser.add_argument("--device", type=str, choices=["cuda", "xpu"], default="cuda")
     return parser
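For reference, a minimal standalone sketch (hypothetical, not part of this file) of what the new flag gives downstream code: argparse rejects anything outside choices, and the accepted string can be passed directly to .to(), so the benchmark can be launched with e.g. --device xpu.

```python
# Hypothetical sketch of the new --device flag in isolation; the last two
# lines require a PyTorch build with the chosen backend available.
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, choices=["cuda", "xpu"], default="cuda")
args = parser.parse_args(["--device", "xpu"])  # anything else, e.g. "mps", is rejected

x = torch.randn(4).to(args.device)  # device strings work anywhere .to() is accepted
print(x.device)  # xpu:0
```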

@@ -128,9 +129,9 @@ def get_dloader(args, training: bool):
     )


-def get_amp_ctx(amp):
+def get_amp_ctx(amp, device):
     dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
-    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
+    return torch.autocast(device, dtype=dtype, enabled=amp != "none")


 @torch.no_grad()
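This helper change is the core of the port: torch.autocast takes the device type as its first argument, so one context manager serves both backends. A minimal sketch of the resulting behavior (CUDA shown; substitute "xpu" on an XPU-enabled build):

```python
# Minimal sketch of the updated helper; torch.autocast's first argument is
# the device type string, so the same code path serves CUDA and XPU.
import torch

def get_amp_ctx(amp: str, device: str):
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast(device, dtype=dtype, enabled=amp != "none")

model = torch.nn.Linear(8, 8).to("cuda")
x = torch.randn(2, 8, device="cuda")
with get_amp_ctx("bf16", "cuda"):
    y = model(x)
print(y.dtype)  # torch.bfloat16 inside the autocast region
```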
@@ -148,8 +149,8 @@ def evaluate_model(model, args):
         if args.channels_last:
             batch["image"] = batch["image"].to(memory_format=torch.channels_last)

-        with get_amp_ctx(args.amp):
-            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
+        with get_amp_ctx(args.amp, args.device):
+            all_preds.append(model(batch["image"].to(args.device)).argmax(1).cpu())

     all_labels = torch.cat(all_labels, dim=0)
     all_preds = torch.cat(all_preds, dim=0)

@@ -192,7 +193,7 @@ def evaluate_model(model, args):
         model.bfloat16()
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
-    model.cuda()  # move model to CUDA after optionally convert it to BF16
+    model.to(args.device)  # move model to DEVICE after optionally convert it to BF16
     if args.compile:
         model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

@@ -227,9 +228,9 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]

     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, device=args.device)
     elif args.optim_cpu_offload == "ao_offload_grads":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True, device=args.device)

     optim = optim_cls(
         model.parameters(),
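How the partial() wiring resolves at the optim = optim_cls(...) call site, as a hedged sketch: the device keyword assumes the companion torchao change that lets CPUOffloadOptimizer target a non-CUDA accelerator, and the remaining kwargs are forwarded to the wrapped optimizer_class.

```python
# Hedged sketch, not a definitive API reference: the device keyword relies
# on the companion torchao change; extra kwargs (lr, ...) are forwarded to
# the wrapped optimizer_class.
from functools import partial

import torch
from torchao.prototype import low_bit_optim  # import path used by this benchmark

optim_cls = partial(
    low_bit_optim.CPUOffloadOptimizer,
    optimizer_class=torch.optim.AdamW,
    offload_gradients=True,
    device="cuda",  # or "xpu", matching --device
)

model = torch.nn.Linear(64, 64).to("cuda")
optim = optim_cls(model.parameters(), lr=1e-4)  # same shape as the call site above
```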
@@ -239,7 +240,7 @@ def evaluate_model(model, args):
     )

     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
-    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
+    grad_scaler = torch.amp.GradScaler(args.device, enabled=args.amp == "fp16")
     log_interval = 10
     t0 = time.perf_counter()
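torch.amp.GradScaler likewise takes the device type as its first argument. A sketch of the fp16 step it guards (CUDA shown; the same code should run with device = "xpu" on a sufficiently recent PyTorch):

```python
# Sketch of the device-parameterized GradScaler in a standard fp16 step.
import torch

device = "cuda"  # or "xpu"
model = torch.nn.Linear(16, 16).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_scaler = torch.amp.GradScaler(device, enabled=True)  # a no-op when enabled=False

x = torch.randn(4, 16, device=device)
with torch.autocast(device, dtype=torch.float16):
    loss = model(x).square().mean()

grad_scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
grad_scaler.step(optim)             # unscales grads, skips the step on inf/nan
grad_scaler.update()
optim.zero_grad()
```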

@@ -248,15 +249,26 @@ def evaluate_model(model, args):
         model.train()
         pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")

-        with torch.profiler.profile() if args.profile else nullcontext() as prof:
+        if args.profile:
+            activities = [torch.profiler.ProfilerActivity.CPU]
+            if args.device == "cuda":
+                activities.append(torch.profiler.ProfilerActivity.CUDA)
+            elif args.device == "xpu":
+                activities.append(torch.profiler.ProfilerActivity.XPU)
+
+            prof = torch.profiler.profile(activities=activities)
+        else:
+            prof = nullcontext()
+
+        with prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
                     batch["image"] = batch["image"].to(memory_format=torch.channels_last)

-                with get_amp_ctx(args.amp):
-                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
+                with get_amp_ctx(args.amp, args.device):
+                    loss = F.cross_entropy(model(batch["image"].to(args.device)), batch["label"].to(args.device))

                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
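The profiler can no longer be constructed inline because the activity list now depends on the device. A sketch of the selection in isolation (ProfilerActivity.XPU is present in recent PyTorch releases):

```python
# Sketch of the device-dependent activity selection: CPU events are always
# collected, and the accelerator activity is picked from the device flag.
import torch
from torch.profiler import ProfilerActivity, profile

device = "cuda"  # or "xpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)
elif device == "xpu":
    activities.append(ProfilerActivity.XPU)

x = torch.randn(1024, 1024, device=device)
with profile(activities=activities) as prof:
    y = (x @ x).sum()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```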
@@ -299,6 +311,9 @@ def evaluate_model(model, args):
         print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
         logger.log(dict(val_acc=val_acc), step=step)

-    peak_mem = torch.cuda.max_memory_allocated() / 1e9
+    if args.device == "cuda":
+        peak_mem = torch.cuda.max_memory_allocated() / 1e9
+    elif args.device == "xpu":
+        peak_mem = torch.xpu.max_memory_allocated() / 1e9
     print(f"Max memory used: {peak_mem:.02f} GB")
     logger.log(dict(max_memory_allocated=peak_mem))
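Peak-memory accounting is the one remaining backend-specific call. torch.cuda and torch.xpu expose the same max_memory_allocated() accounting API, so a getattr dispatch is a possible alternative to the explicit if/elif; a sketch:

```python
# Sketch of a per-backend peak-memory query equivalent to the if/elif above;
# assumes device is one of the module names torch exposes ("cuda" or "xpu").
import torch

def peak_mem_gb(device: str) -> float:
    backend = getattr(torch, device)  # torch.cuda or torch.xpu
    return backend.max_memory_allocated() / 1e9

# After a run: print(f"Max memory used: {peak_mem_gb('cuda'):.02f} GB")
```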
