
Commit 042ca1f

feat(cpu-offload): enable CPU Offload for XPU
Signed-off-by: dbyoung18 <yang5.yang@intel.com>
1 parent 7489c7d commit 042ca1f
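
For context, a minimal usage sketch of what this commit enables: CPUOffloadOptimizer can now target an Intel XPU as well as CUDA through the new device argument. The model shape mirrors the updated tests; the device pick and the rest of the setup are illustrative assumptions, not part of the diff.

import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

# Assumption: prefer XPU when present, otherwise fall back to CUDA.
device = "xpu" if torch.xpu.is_available() else "cuda"

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = low_bit_optim.CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    offload_gradients=False,
    device=device,  # new in this commit; the optimizer previously assumed CUDA
)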

File tree

2 files changed: +59 -36 lines changed


test/prototype/test_low_bit_optim.py

Lines changed: 13 additions & 5 deletions
@@ -42,7 +42,12 @@
     lpmm = None
 
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+if torch.cuda.is_available():
+    _DEVICES = ["cpu", "cuda"]
+elif torch.xpu.is_available():
+    _DEVICES = ["cpu", "xpu"]
+else:
+    _DEVICES = ["cpu"]
 
 
 class TestQuantize(TestCase):
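
A quick illustration of how the new device list resolves (my reading of the change, not shown in the diff): the accelerator, when one exists, is always the last entry, which is why the offload tests below select it with _DEVICES[-1].

# On a CUDA machine:    _DEVICES == ["cpu", "cuda"]  ->  _DEVICES[-1] == "cuda"
# On an XPU machine:    _DEVICES == ["cpu", "xpu"]   ->  _DEVICES[-1] == "xpu"
# With no accelerator:  _DEVICES == ["cpu"]          ->  the offload tests are skipped
device = _DEVICES[-1]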
@@ -244,11 +249,12 @@ def test_optim_4bit_correctness(self, optim_name):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
     @pytest.mark.skipif(
-        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+        not torch.cuda.is_available() and not torch.xpu.is_available(),
+        reason="optim CPU offload requires CUDA or XPU"
     )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
-        device = "cuda"
+        device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)
 
@@ -261,6 +267,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             model2.parameters(),
             torch.optim.AdamW,
             offload_gradients=offload_grad,
+            device=device,
         )
 
         for _ in range(2):
@@ -279,10 +286,11 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             torch.testing.assert_close(p2, p1)
 
     @pytest.mark.skipif(
-        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+        not torch.cuda.is_available() and not torch.xpu.is_available(),
+        reason="optim CPU offload requires CUDA or XPU"
     )
     def test_optim_cpu_offload_save_load(self):
-        device = "cuda"
+        device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)
         optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 46 additions & 31 deletions
@@ -13,6 +13,7 @@ def __init__(
         optimizer_class: Type[Optimizer] = torch.optim.AdamW,
         *,
         offload_gradients: bool = False,
+        device: str = "cuda",
         **kwargs,
     ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
@@ -22,6 +23,7 @@ def __init__(
             params: a list of parameters or parameter groups.
             optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
+            device: device type for GPU. Choose from "cuda" and "xpu". Defaults to "cuda".
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
         # default to fused CPU AdamW
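
A hedged example of the documented constructor arguments, including the new device flag; it assumes a model like the sketch near the top of this page, and the lr and weight_decay values are arbitrary, only illustrating the kwargs pass-through described in the docstring above.

optim = CPUOffloadOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    offload_gradients=False,  # freeing device grads is not compatible with gradient accumulation
    device="xpu",             # "cuda" (default) or "xpu"
    lr=1e-4,                  # forwarded to the base AdamW
    weight_decay=1e-2,        # forwarded to the base AdamW
)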
@@ -38,51 +40,60 @@
         if not isinstance(param_groups[0], dict):
             param_groups = [{"params": param_groups}]
 
-        self.param_cuda2cpu_map = dict()
+        self.param_d2h_map = dict()
         self.optim_dict = dict()
-        self.stream = torch.cuda.Stream()
+        self.device = device
+        if self.device == "cuda":
+            self.stream = torch.cuda.Stream()
+        elif self.device == "xpu":
+            self.stream = torch.xpu.Stream()
 
         # the queue maintains the order which param we should do optim step on first.
         self.queue = dict()
 
-        def backward_hook(p_cuda):
-            if p_cuda.grad is not None:
-                p_cpu = self.param_cuda2cpu_map[p_cuda]
+        def backward_hook(p_device):
+            if p_device.grad is not None:
+                p_host = self.param_d2h_map[p_device]
 
                 # make sure backward for this param finishes
-                self.stream.wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self.stream):
-                    p_cpu.grad.copy_(p_cuda.grad, non_blocking=True)
+                if self.device == "cuda":
+                    self.stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(self.stream):
+                        p_host.grad.copy_(p_device.grad, non_blocking=True)
+                elif self.device == "xpu":
+                    self.stream.wait_stream(torch.xpu.current_stream())
+                    with torch.xpu.stream(self.stream):
+                        p_host.grad.copy_(p_device.grad, non_blocking=True)
 
                 # we rely on CPython implementation of dictionary, which preserves insertion order.
                 # if a param is added again (e.g. due to gradient accumulation), it is moved to the
                 # end of the queue by removing and inserting it again.
-                if p_cuda in self.queue:
-                    del self.queue[p_cuda]
-                self.queue[p_cuda] = self.stream.record_event()
+                if p_device in self.queue:
+                    del self.queue[p_device]
+                self.queue[p_device] = self.stream.record_event()
 
-                # deallocate CUDA gradients once D2H transfer finishes.
+                # deallocate DEVICE gradients once D2H transfer finishes.
                 if offload_gradients:
-                    p_cuda.grad.record_stream(self.stream)
-                    p_cuda.grad = None
+                    p_device.grad.record_stream(self.stream)
+                    p_device.grad = None
 
         for param_group in param_groups:
             params = param_group.pop("params")
 
-            for p_cuda in params:
-                if not p_cuda.requires_grad:
+            for p_device in params:
+                if not p_device.requires_grad:
                     continue
 
                 # pre-allocate CPU params and grads
-                p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True)
-                p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)
+                p_host = torch.empty_like(p_device, device="cpu", pin_memory=True)
+                p_host.grad = torch.empty_like(p_host, pin_memory=True)
 
-                p_cpu.copy_(p_cuda.detach(), non_blocking=True)
-                self.param_cuda2cpu_map[p_cuda] = p_cpu
+                p_host.copy_(p_device.detach(), non_blocking=True)
+                self.param_d2h_map[p_device] = p_host
 
-                p_cuda.register_post_accumulate_grad_hook(backward_hook)
-                self.optim_dict[p_cuda] = optimizer_class(
-                    [{"params": p_cpu, **param_group}], **kwargs
+                p_device.register_post_accumulate_grad_hook(backward_hook)
+                self.optim_dict[p_device] = optimizer_class(
+                    [{"params": p_host, **param_group}], **kwargs
                 )
 
     @torch.no_grad()
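
Not part of this commit, but a note on the repeated if self.device == "cuda" / elif "xpu" branches above: both backends expose the same Stream, current_stream, and stream calls (all used in this diff), so the dispatch could in principle be resolved once through the backend module. A sketch under that assumption:

import torch

def get_backend(device: str):
    # Sketch only: return the backend module (torch.cuda or torch.xpu) so stream
    # handling can be written once instead of branching at every call site.
    assert device in ("cuda", "xpu")
    return getattr(torch, device)

# e.g. in __init__:        self.stream = get_backend(device).Stream()
# e.g. in backward_hook:   backend = get_backend(self.device)
#                          self.stream.wait_stream(backend.current_stream())
#                          with backend.stream(self.stream): ...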
@@ -91,26 +102,30 @@ def step(self, closure=None):
         if closure is not None:
             loss = closure()
 
-        for p_cuda, grad_d2h_event in self.queue.items():
+        for p_device, grad_d2h_event in self.queue.items():
             grad_d2h_event.synchronize()
-            self.optim_dict[p_cuda].step()
+            self.optim_dict[p_device].step()
 
             # submit more job to self.stream. it guarantees that we only start
             # moving param H2D once all backwards finish, since self.stream
             # will wait for current_stream when moving grad D2H.
-            p_cpu = self.param_cuda2cpu_map[p_cuda]
-            with torch.cuda.stream(self.stream):
-                p_cuda.copy_(p_cpu, non_blocking=True)
+            p_host = self.param_d2h_map[p_device]
+            if self.device == "cuda":
+                with torch.cuda.stream(self.stream):
+                    p_device.copy_(p_host, non_blocking=True)
+            elif self.device == "xpu":
+                with torch.xpu.stream(self.stream):
+                    p_device.copy_(p_host, non_blocking=True)
 
         self.queue.clear()
         return loss
 
     def zero_grad(self, set_to_none=True):
         assert set_to_none
 
-        # only clear CUDA grad. CPU grad will always be overwritten by CUDA grad.
-        for p_cuda in self.param_cuda2cpu_map.keys():
-            p_cuda.grad = None
+        # only clear DEVICE grad. CPU grad will always be overwritten by DEVICE grad.
+        for p_device in self.param_d2h_map.keys():
+            p_device.grad = None
 
     @property
     def param_groups(self):
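
To tie the pieces together, a hedged sketch of how one training step flows through the hooks and streams above, continuing the usage example near the top of this page; the input tensor is an assumed toy batch.

x = torch.randn(4, 32, device=device)

for _ in range(2):
    loss = model(x).sum()
    loss.backward()    # post-accumulate-grad hooks copy each grad D2H on self.stream
    optim.step()       # waits on each grad's D2H event, steps the CPU base optimizer,
                       # then queues the param H2D copy back onto self.stream
    optim.zero_grad()  # clears only the device grads; the pinned CPU grads are overwritten next pass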
