 from torch._six import container_abcs
 import warnings
 from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple


 class _MultiDeviceReplicator(object):
     """
     Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
     """
-    def __init__(self, master_tensor):
+    def __init__(self, master_tensor: torch.Tensor) -> None:
         assert master_tensor.is_cuda
         self.master = master_tensor
-        self._per_device_tensors = {}
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

-    def get(self, device):
+    def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
             retval = self.master.to(device=device, non_blocking=True, copy=True)
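As context for these annotations, a minimal sketch of how the cache behaves (names and devices are illustrative, not part of the patch):

    # The master tensor lives on one device; get() copies it to each newly
    # requested device once, then serves the cached copy on later calls.
    scale = torch.full((1,), 65536.0, device="cuda:0")
    rep = _MultiDeviceReplicator(scale)
    a = rep.get(torch.device("cuda:0"))   # copy cached for cuda:0
    b = rep.get(torch.device("cuda:1"))   # one non_blocking transfer to cuda:1
    b2 = rep.get(torch.device("cuda:1"))  # cache hit: no new transfer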
@@ -38,6 +39,9 @@ def _refresh_per_optimizer_state():


 class GradScaler(object):
+    _scale: Optional[torch.Tensor]
+    _growth_tracker: Optional[torch.Tensor]
+    _per_optimizer_states: Dict[int, Dict[str, Any]]
     """
     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
     conveniently.
@@ -128,10 +132,11 @@ def __init__(self,
         self._growth_tracker = None
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

-    def _check_scale_growth_tracker(self, funcname):
+    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
         assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
         assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
+        return (self._scale, self._growth_tracker)

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
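The new return value lets callers bind non-Optional locals once instead of re-asserting before every use of `self._scale`; a sketch of the calling pattern (used by `update()` and `_check_inf_per_device()` below):

    # After this call, mypy sees plain torch.Tensor locals, not Optional[...].
    _scale, _growth_tracker = self._check_scale_growth_tracker("update")
    inv_scale = _scale.double().reciprocal().float()  # no Optional complaints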
@@ -156,21 +161,27 @@ def scale(self, outputs):
             assert outputs.is_cuda
             if self._scale is None:
                 self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
             return outputs * self._scale.to(device=outputs.device, non_blocking=True)

         # Invoke the more complex machinery only if we're treating multiple outputs.
-        stash = [None]  # trick to hold a reference that can be overwritten at any level of the recursion below.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

         def apply_scale(val):
             if isinstance(val, torch.Tensor):
                 assert val.is_cuda
-                if self._scale is None:
-                    self._lazy_init_scale_growth_tracker(val.device)
-                if stash[0] is None:
-                    stash[0] = _MultiDeviceReplicator(self._scale)
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
                 return val * stash[0].get(val.device)
             elif isinstance(val, container_abcs.Iterable):
-                return type(val)(apply_scale(v) for v in val)
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
             else:
                 raise ValueError("outputs must be a Tensor or an iterable of Tensors")
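For reference, a hedged sketch of what this recursion handles (model and loss names are made up):

    # scale() walks nested containers, scaling each CUDA tensor against a
    # per-device cached copy of the scale via _MultiDeviceReplicator.
    loss0 = model0(inputs).sum()
    loss1 = model1(inputs).sum()
    scaled0, scaled1 = scaler.scale((loss0, loss1))  # tuple in, tuple out
    # With this change, a generator or other non-list/tuple iterable comes
    # back as a lazy map object instead of being rebuilt via type(val)(...).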
@@ -182,25 +193,25 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):

         for group in optimizer.param_groups:
             for param in group["params"]:
-                if param.grad is not None:
-                    if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError("Attempting to unscale FP16 gradients.")
+                if param.grad is None:
+                    continue
+                if (not allow_fp16) and param.grad.dtype == torch.float16:
+                    raise ValueError("Attempting to unscale FP16 gradients.")
+                with torch.no_grad():
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
                     else:
-                        with torch.no_grad():
-                            if param.grad.is_sparse:
-                                # is_coalesced() == False means the sparse grad has values with duplicate indices.
-                                # coalesce() deduplicates indices and adds all values that have the same index.
-                                # For scaled fp16 values, there's a good chance coalescing will cause overflow,
-                                # so we should check the coalesced _values().
-                                if param.grad.dtype is torch.float16:
-                                    param.grad = param.grad.coalesce()
-                                to_unscale = param.grad._values()
-                            else:
-                                to_unscale = param.grad
-
-                            torch._amp_non_finite_check_and_unscale_(to_unscale,
-                                                                     per_device_found_inf.get(param.grad.device),
-                                                                     per_device_inv_scale.get(param.grad.device))
+                        to_unscale = param.grad
+
+                    torch._amp_non_finite_check_and_unscale_(to_unscale,
+                                                             per_device_found_inf.get(param.grad.device),
+                                                             per_device_inv_scale.get(param.grad.device))

         return per_device_found_inf._per_device_tensors
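The overflow the comment warns about is easy to reproduce in isolation; a standalone sketch, with values chosen only to force fp16 overflow:

    # Two scaled fp16 values share index 0; coalesce() sums them past fp16's
    # largest finite value (65504), yielding inf for the non-finite check.
    import torch
    i = torch.tensor([[0, 0]])
    v = torch.tensor([40000.0, 40000.0], dtype=torch.float16)
    g = torch.sparse_coo_tensor(i, v, (4,))
    print(g.coalesce()._values())  # tensor([inf], dtype=torch.float16)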
@@ -249,6 +260,7 @@ def unscale_(self, optimizer):
             raise RuntimeError("unscale_() is being called after step().")

         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
         found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

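The comment's rationale can be illustrated in isolation (a hedged numeric sketch, not part of the patch):

    # Under some compile options, a plain FP32 division can land an ulp or so
    # away from the correctly rounded reciprocal. Computing in FP64 and
    # rounding once back to FP32 sidesteps that.
    scale = torch.full((1,), 65536.0)
    inv_scale = scale.double().reciprocal().float()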
@@ -332,22 +344,22 @@ def update(self, new_scale=None):
         if not self._enabled:
             return

-        self._check_scale_growth_tracker("update")
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
+                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=_scale.device)
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
-                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
+                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
                 self._scale = new_scale
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
+            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
                           for state in self._per_optimizer_states.values()
                           for found_inf in state["found_inf_per_device"].values()]

@@ -358,8 +370,8 @@ def update(self, new_scale=None):
             for i in range(1, len(found_infs)):
                 found_inf_combined += found_infs[i]

-            self._scale = torch._amp_update_scale(self._growth_tracker,
-                                                  self._scale,
+            self._scale = torch._amp_update_scale(_growth_tracker,
+                                                  _scale,
                                                   found_inf_combined,
                                                   self._growth_factor,
                                                   self._backoff_factor,
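A hedged sketch of the two paths through update() in a training loop (loop names are illustrative):

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()          # usual path: grow/backoff from collected found_inf flags
    # scaler.update(2.**16)  # or: force a specific scale for the next iteration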
@@ -498,10 +510,10 @@ def __setstate__(self, state):
         self.__dict__.update(state)

     def _check_inf_per_device(self, optimizer):
-        self._check_scale_growth_tracker("_check_inf_per_device")
+        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)