
Commit 6ff3904

Authored by dbyoung18 <yang5.yang@intel.com>
Enable CPU Offload for Intel GPU (#1324)
* feat(cpu-offload): enable CPU Offload for XPU
* test(cpu-offload): enable benchmark_low_bit_adam for XPU
* fix(cpu-offload): auto-detect ProfilerActivity
* fix(cpu-offload): replace if-else w/ getattr for device API calls
* fix(cpu-offload): add auto-detect available devices to utils
* fix(cpu-offload): improve auto-detect ProfilerActivity
* fix(cpu-offload): improve device assert
* fix(cpu-offload): fix auto-detect mps
* fix(cpu-offload): fix import order
* refactor(cpu-offload): use ruff format
* doc(cpu-offload): modify README to cover XPU

Signed-off-by: dbyoung18 <yang5.yang@intel.com>
1 parent 6312329 commit 6ff3904
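
The commit message names three recurring patterns: a `get_available_devices()` helper in `torchao.utils`, `getattr`-based device API calls, and ProfilerActivity auto-detection. The sketch below only illustrates those patterns; it is not the torchao implementation, and the helper body, variable names, and the `getattr(ProfilerActivity, ...)` lookup are illustrative assumptions.

```python
# Hedged sketch of the patterns named in the commit message; not the actual
# torchao.utils code. Helper body and names are illustrative only.
import torch
from torch.profiler import ProfilerActivity


def detect_devices():
    # "cpu" first, accelerators appended when present, so [-1] prefers a GPU.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    return devices


device = detect_devices()[-1]

# getattr-based device API calls: torch.cuda.* and torch.xpu.* expose the same
# method names, so getattr(torch, device) replaces per-backend if/else branches.
if device != "cpu":
    peak_gb = getattr(torch, device).max_memory_allocated() / 1e9

# ProfilerActivity auto-detection: add the activity matching the device, if any.
activities = [ProfilerActivity.CPU]
extra = getattr(ProfilerActivity, device.upper(), None)
if device != "cpu" and extra is not None:
    activities.append(extra)
```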

File tree

5 files changed (+151, -72 lines)

benchmarks/benchmark_low_bit_adam.py

Lines changed: 63 additions & 25 deletions
@@ -4,7 +4,7 @@
 # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
 # - DeepSpeed (ZeRO-Offload):
 # sudo apt install libopenmpi-dev
-# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
+# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
 # DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
 #
 # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
@@ -31,11 +31,15 @@
 import torch.nn.functional as F
 import wandb
 from torch.utils.data import DataLoader
+from torchao.utils import get_available_devices
 from torchvision.transforms import v2
 from tqdm import tqdm

 from torchao.prototype import low_bit_optim

+_DEVICE = get_available_devices()[-1]
+assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)"
+
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
@@ -49,7 +53,9 @@

     OPTIM_MAP.update(
         AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
-        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+        AdamW4bitRank1Lpmm=partial(
+            lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
+        ),
     )

 except ImportError:
@@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
         if step < self.warmup_steps:
             return self.lr * step / self.warmup_steps
         if step < self.total_steps:
-            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
-            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
+            progress = (step - self.warmup_steps) / (
+                self.total_steps - self.warmup_steps
+            )
+            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
+                1 + math.cos(progress * math.pi)
+            )
         return self.final_lr


@@ -92,7 +102,9 @@ def get_parser():
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
-    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
+    parser.add_argument(
+        "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
+    )

     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
         transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

     transforms.append(v2.ToDtype(torch.float32, scale=True))
-    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+    transforms.append(
+        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    )
     transforms = v2.Compose(transforms)

     # use dataset from HF so download is fast
-    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
+    ds = datasets.load_dataset(
+        "timm/resisc45", split="train" if training else "validation"
+    )
     ds = ds.select_columns(["image", "label"])
     ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

@@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
     )


-def get_amp_ctx(amp):
+def get_amp_ctx(amp, device):
     dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
-    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
+    return torch.autocast(device, dtype=dtype, enabled=amp != "none")


 @torch.no_grad()
@@ -148,8 +164,8 @@ def evaluate_model(model, args):
         if args.channels_last:
             batch["image"] = batch["image"].to(memory_format=torch.channels_last)

-        with get_amp_ctx(args.amp):
-            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
+        with get_amp_ctx(args.amp, _DEVICE):
+            all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu())

     all_labels = torch.cat(all_labels, dim=0)
     all_preds = torch.cat(all_preds, dim=0)
@@ -164,8 +180,12 @@ def evaluate_model(model, args):
     if args.full_bf16:
         assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
     if args.optim_cpu_offload == "deepspeed":
-        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
-        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
+        assert (
+            args.amp == "none"
+        ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert (
+            args.optim == "AdamW"
+        ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
         args.n_epochs = 1
     if args.seed is not None:
@@ -185,14 +205,16 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")

-    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    model = timm.create_model(
+        args.model, pretrained=True, num_classes=45, **args.model_kwargs
+    )
     if args.checkpoint_activations:
         model.set_grad_checkpointing()
     if args.full_bf16:
         model.bfloat16()
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
-    model.cuda()  # move model to CUDA after optionally convert it to BF16
+    model.to(_DEVICE)  # move model to DEVICE after optionally convert it to BF16
     if args.compile:
         model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -227,9 +249,15 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]

     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
+        )
     elif args.optim_cpu_offload == "ao_offload_grads":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer,
+            optimizer_class=optim_cls,
+            offload_gradients=True,
+        )

     optim = optim_cls(
         model.parameters(),
@@ -239,24 +267,30 @@ def evaluate_model(model, args):
     )

     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
-    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
+    grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16")
     log_interval = 10
     t0 = time.perf_counter()

     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+        pbar = tqdm(
+            dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
+        )

         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
-                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)
+                    batch["image"] = batch["image"].to(
+                        memory_format=torch.channels_last
+                    )

-                with get_amp_ctx(args.amp):
-                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
+                with get_amp_ctx(args.amp, _DEVICE):
+                    loss = F.cross_entropy(
+                        model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
+                    )

                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
@@ -275,7 +309,9 @@ def evaluate_model(model, args):
                     log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
                     if step > 0:
                         t1 = time.perf_counter()
-                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
+                        log_dict["imgs_per_second"] = (
+                            args.batch_size * log_interval / (t1 - t0)
+                        )
                         t0 = t1
                     logger.log(log_dict, step=step)

@@ -296,9 +332,11 @@ def evaluate_model(model, args):

         else:
             val_acc = evaluate_model(model, args)
-            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
+            print(
+                f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
+            )
             logger.log(dict(val_acc=val_acc), step=step)

-    peak_mem = torch.cuda.max_memory_allocated() / 1e9
+    peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
     print(f"Max memory used: {peak_mem:.02f} GB")
     logger.log(dict(max_memory_allocated=peak_mem))
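
Taken together, these hunks replace hard-coded "cuda" strings with a `_DEVICE` string chosen by `torchao.utils.get_available_devices()`. Below is a minimal, self-contained sketch of the resulting device-agnostic autocast pattern, assuming torchao is installed and a CUDA or XPU device is present; the tensor shapes and the standalone `get_amp_ctx` copy are illustrative, not the benchmark itself.

```python
# Minimal sketch of the device-agnostic pattern the benchmark now uses.
# Assumes a PyTorch build with either CUDA or XPU available.
import torch
from torchao.utils import get_available_devices

_DEVICE = get_available_devices()[-1]  # e.g. "cuda" or "xpu"


def get_amp_ctx(amp: str, device: str):
    # Same idea as the benchmark's helper: autocast takes a device string.
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast(device, dtype=dtype, enabled=amp != "none")


# GradScaler is constructed the same way, from the device string (unused here).
grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=False)

x = torch.randn(8, 16, device=_DEVICE)
w = torch.randn(16, 4, device=_DEVICE, requires_grad=True)
with get_amp_ctx("bf16", _DEVICE):
    loss = (x @ w).square().mean()
loss.backward()

if _DEVICE != "cpu":
    # getattr(torch, _DEVICE) works because torch.cuda and torch.xpu share APIs.
    print(f"peak memory: {getattr(torch, _DEVICE).max_memory_allocated() / 1e9:.3f} GB")
```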

test/prototype/test_low_bit_optim.py

Lines changed: 36 additions & 13 deletions
@@ -26,6 +26,7 @@
 from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
 from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
 from torchao.utils import (
+    get_available_devices,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -42,7 +43,7 @@
     lpmm = None


-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = get_available_devices()


 class TestQuantize(TestCase):
@@ -94,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)

-        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        func = torch.compile(
+            _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+        )
         x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16

@@ -169,8 +172,13 @@ def test_subclass_slice(self, subclass, shape, device):
         tensor = subclass.zeros(shape, device=device)
         offset = shape[0] // 2

-        torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
-        torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
+        torch.testing.assert_close(
+            tensor.dequantize()[:offset], tensor[:offset].dequantize()
+        )
+        torch.testing.assert_close(
+            tensor.dequantize()[offset : offset * 2],
+            tensor[offset : offset * 2].dequantize(),
+        )

     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(
@@ -188,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name):
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -244,11 +254,12 @@ def test_optim_4bit_correctness(self, optim_name):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

     @pytest.mark.skipif(
-        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+        not torch.cuda.is_available() and not torch.xpu.is_available(),
+        reason="optim CPU offload requires CUDA or XPU",
     )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
-        device = "cuda"
+        device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)

@@ -279,13 +290,16 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             torch.testing.assert_close(p2, p1)

     @pytest.mark.skipif(
-        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+        not torch.cuda.is_available() and not torch.xpu.is_available(),
+        reason="optim CPU offload requires CUDA or XPU",
     )
     def test_optim_cpu_offload_save_load(self):
-        device = "cuda"
+        device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -300,7 +314,9 @@ def test_optim_cpu_offload_save_load(self):

         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)

         for _ in range(2):
@@ -384,7 +400,11 @@ def _test_fsdp2(self, optim_cls):
         import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
         from torch.distributed.tensor import DTensor
-        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
+        from torch.testing._internal.distributed._tensor.common_dtensor import (
+            ModelArgs,
+            Transformer,
+            TransformerBlock,
+        )

         batch_size = 3
         vocab_size = 1024
@@ -457,7 +477,10 @@ def _test_fsdp2(self, optim_cls):

         subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

-        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+        for v1, v2 in zip(
+            pytree.tree_iter(resumed_fsdp_optim.state_dict()),
+            pytree.tree_iter(fsdp_optim.state_dict()),
+        ):
             assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
             if isinstance(v1, DTensor):
                 v1 = v1.to_local()

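The same auto-detection drives the test changes: the skip conditions now accept either CUDA or XPU, and the offload tests use the last entry of `get_available_devices()`. A condensed sketch of that pattern follows, assuming pytest and torchao are installed; the tiny model and the smoke-test body are illustrative, not the real test cases.

```python
# Condensed sketch of the device-agnostic test pattern above; the model and
# the test body are illustrative, not the actual torchao tests.
import pytest
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim
from torchao.utils import get_available_devices

_DEVICES = get_available_devices()  # e.g. ["cpu", "cuda"] or ["cpu", "xpu"]


@pytest.mark.skipif(
    not torch.cuda.is_available() and not torch.xpu.is_available(),
    reason="optim CPU offload requires CUDA or XPU",
)
def test_cpu_offload_smoke():
    device = _DEVICES[-1]  # prefer the accelerator when one is available
    model = nn.Linear(32, 32, device=device)
    optim = low_bit_optim.CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW)
    model(torch.randn(4, 32, device=device)).sum().backward()
    optim.step()
    optim.zero_grad()
```
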
torchao/prototype/low_bit_optim/README.md

Lines changed: 2 additions & 2 deletions
@@ -80,7 +80,7 @@ All of our low-bit optimizers mentioned above also support `bf16_stochastic_roun

 ## Optimizer CPU offload

-This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
+This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA and XPU is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.

 ```python
 import torch
@@ -97,7 +97,7 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradi

 This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer.

-For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.)
+For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize GPU and CPU params in either direction CPU->GPU and GPU->CPU, in case they are out of sync.)

 ```python
 ckpt = torch.load("checkpoint.pth")
