@@ -13,6 +13,7 @@ def __init__(
         optimizer_class: Type[Optimizer] = torch.optim.AdamW,
         *,
         offload_gradients: bool = False,
+        device: str = "cuda",
         **kwargs,
     ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
@@ -22,6 +23,7 @@ def __init__(
             params: a list of parameters or parameter groups.
             optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
+            device: GPU device type. Choose from "cuda" and "xpu". Defaults to "cuda".
         """
         # default to fused CPU AdamW
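
For reference, a minimal construction sketch using the new argument. The class name and import path do not appear in this diff and are assumed here to be torchao's CPUOffloadOptimizer:

import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # assumed import path

# parameters live on the accelerator; optimizer state lives in pinned CPU memory
model = torch.nn.Linear(1024, 1024, device="xpu")
optim = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer, constructed per parameter on CPU
    offload_gradients=True,   # free accelerator grads once the D2H copy finishes
    device="xpu",             # new argument: "cuda" (default) or "xpu"
    lr=1e-4,
)
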
@@ -38,51 +40,60 @@ def __init__(
         if not isinstance(param_groups[0], dict):
             param_groups = [{"params": param_groups}]

-        self.param_cuda2cpu_map = dict()
+        self.param_d2h_map = dict()
         self.optim_dict = dict()
-        self.stream = torch.cuda.Stream()
+        self.device = device
+        if self.device == "cuda":
+            self.stream = torch.cuda.Stream()
+        elif self.device == "xpu":
+            self.stream = torch.xpu.Stream()

         # the queue maintains the order which param we should do optim step on first.
         self.queue = dict()

-        def backward_hook(p_cuda):
-            if p_cuda.grad is not None:
-                p_cpu = self.param_cuda2cpu_map[p_cuda]
+        def backward_hook(p_device):
+            if p_device.grad is not None:
+                p_host = self.param_d2h_map[p_device]

                 # make sure backward for this param finishes
-                self.stream.wait_stream(torch.cuda.current_stream())
-                with torch.cuda.stream(self.stream):
-                    p_cpu.grad.copy_(p_cuda.grad, non_blocking=True)
+                if self.device == "cuda":
+                    self.stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(self.stream):
+                        p_host.grad.copy_(p_device.grad, non_blocking=True)
+                elif self.device == "xpu":
+                    self.stream.wait_stream(torch.xpu.current_stream())
+                    with torch.xpu.stream(self.stream):
+                        p_host.grad.copy_(p_device.grad, non_blocking=True)

                 # we rely on CPython implementation of dictionary, which preserves insertion order.
                 # if a param is added again (e.g. due to gradient accumulation), it is moved to the
                 # end of the queue by removing and inserting it again.
-                if p_cuda in self.queue:
-                    del self.queue[p_cuda]
-                self.queue[p_cuda] = self.stream.record_event()
+                if p_device in self.queue:
+                    del self.queue[p_device]
+                self.queue[p_device] = self.stream.record_event()

-                # deallocate CUDA gradients once D2H transfer finishes.
+                # deallocate DEVICE gradients once D2H transfer finishes.
                 if offload_gradients:
-                    p_cuda.grad.record_stream(self.stream)
-                    p_cuda.grad = None
+                    p_device.grad.record_stream(self.stream)
+                    p_device.grad = None

         for param_group in param_groups:
             params = param_group.pop("params")

-            for p_cuda in params:
-                if not p_cuda.requires_grad:
+            for p_device in params:
+                if not p_device.requires_grad:
                     continue

                 # pre-allocate CPU params and grads
-                p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True)
-                p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)
+                p_host = torch.empty_like(p_device, device="cpu", pin_memory=True)
+                p_host.grad = torch.empty_like(p_host, pin_memory=True)

-                p_cpu.copy_(p_cuda.detach(), non_blocking=True)
-                self.param_cuda2cpu_map[p_cuda] = p_cpu
+                p_host.copy_(p_device.detach(), non_blocking=True)
+                self.param_d2h_map[p_device] = p_host

-                p_cuda.register_post_accumulate_grad_hook(backward_hook)
-                self.optim_dict[p_cuda] = optimizer_class(
-                    [{"params": p_cpu, **param_group}], **kwargs
+                p_device.register_post_accumulate_grad_hook(backward_hook)
+                self.optim_dict[p_device] = optimizer_class(
+                    [{"params": p_host, **param_group}], **kwargs
                 )

     @torch.no_grad()
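
Editorial note: the cuda/xpu branches above repeat the same stream logic. A possible simplification, sketched under the assumption that torch.xpu mirrors the torch.cuda stream API (Stream, current_stream, stream), would resolve the device module once:

import torch

def _get_device_module(device: str):
    # hypothetical helper, not part of this PR: returns torch.cuda or torch.xpu,
    # both of which expose Stream(), current_stream(), and stream()
    if device not in ("cuda", "xpu"):
        raise ValueError(f"unsupported device: {device}")
    return getattr(torch, device)

# inside __init__ / backward_hook / step the branches would then collapse to:
#     mod = _get_device_module(self.device)
#     self.stream = mod.Stream()
#     self.stream.wait_stream(mod.current_stream())
#     with mod.stream(self.stream):
#         p_host.grad.copy_(p_device.grad, non_blocking=True)

Newer PyTorch versions also expose torch.get_device_module, which could serve the same purpose if the minimum supported version allows it.
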
@@ -91,26 +102,30 @@ def step(self, closure=None):
         if closure is not None:
             loss = closure()

-        for p_cuda, grad_d2h_event in self.queue.items():
+        for p_device, grad_d2h_event in self.queue.items():
             grad_d2h_event.synchronize()
-            self.optim_dict[p_cuda].step()
+            self.optim_dict[p_device].step()

             # submit more job to self.stream. it guarantees that we only start
             # moving param H2D once all backwards finish, since self.stream
             # will wait for current_stream when moving grad D2H.
-            p_cpu = self.param_cuda2cpu_map[p_cuda]
-            with torch.cuda.stream(self.stream):
-                p_cuda.copy_(p_cpu, non_blocking=True)
+            p_host = self.param_d2h_map[p_device]
+            if self.device == "cuda":
+                with torch.cuda.stream(self.stream):
+                    p_device.copy_(p_host, non_blocking=True)
+            elif self.device == "xpu":
+                with torch.xpu.stream(self.stream):
+                    p_device.copy_(p_host, non_blocking=True)

         self.queue.clear()
         return loss

     def zero_grad(self, set_to_none=True):
         assert set_to_none

-        # only clear CUDA grad. CPU grad will always be overwritten by CUDA grad.
-        for p_cuda in self.param_cuda2cpu_map.keys():
-            p_cuda.grad = None
+        # only clear DEVICE grad. CPU grad will always be overwritten by DEVICE grad.
+        for p_device in self.param_d2h_map.keys():
+            p_device.grad = None

     @property
     def param_groups(self):
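
To tie the pieces together, a short training-loop sketch (continuing the construction example above; names and shapes are illustrative) showing how step() and zero_grad() interact with the per-parameter queue:

for _ in range(10):
    x = torch.randn(16, 1024, device="xpu")
    loss = model(x).square().mean()
    loss.backward()    # post-accumulate hooks copy grads D2H on self.stream and enqueue events
    optim.step()       # per param: wait for its D2H event, run the CPU optimizer, copy the param back H2D
    optim.zero_grad()  # clears accelerator grads only; pinned CPU grad buffers are overwritten next step
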