Commit b3135c2

malfet authored and facebook-github-bot committed
Enable torch.cuda.amp typechecking (#45480)
Summary:
Fix `torch._C._autocast_*_nesting` declarations in __init__.pyi.
Fix the iterable constructor logic: not every iterable can be constructed with the `type(val)(val)` trick; for example, it does not work for `val=range(10)`, even though `isinstance(val, Iterable)` is True.
Change the optional resolution logic to meet mypy expectations.

Fixes #45436

Pull Request resolved: #45480
Reviewed By: walterddr
Differential Revision: D23982822
Pulled By: malfet
fbshipit-source-id: 6418a28d04ece1b2427dcde4b71effb67856a872
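
For context, a small standalone sketch (not part of the commit) of why `type(val)(val)`-style reconstruction cannot be applied to every iterable, which is what the autocast and grad-scaler changes below work around:

from collections.abc import Iterable

val = range(10)
print(isinstance(val, Iterable))            # True
print(type([0, 1])(v * 2 for v in [0, 1]))  # [0, 2] -- list/tuple rebuild fine from any iterable
try:
    type(val)(v for v in val)               # range() does not accept an arbitrary iterable
except TypeError as exc:
    print("TypeError:", exc)                # this is the failure the fix avoids
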
1 parent df0de78 commit b3135c2

File tree: 4 files changed (+56, -43 lines)

mypy.ini

Lines changed: 0 additions & 3 deletions

@@ -174,9 +174,6 @@ ignore_errors = True
 [mypy-torch.cuda]
 ignore_errors = True
 
-[mypy-torch.cuda.amp.*]
-ignore_errors = True
-
 [mypy-torch._lobpcg]
 ignore_errors = True

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 2 deletions

@@ -389,8 +389,8 @@ def is_grad_enabled() -> _bool: ...
 def set_autocast_enabled(enabled: _bool) -> None: ...
 def is_autocast_enabled() -> _bool: ...
 def clear_autocast_cache() -> None: ...
-def autocast_increment_nesting() -> None: ...
-def autocast_decrement_nesting() -> None: ...
+def autocast_increment_nesting() -> _int: ...
+def autocast_decrement_nesting() -> _int: ...
 def set_anomaly_enabled(enabled: _bool) -> None: ...
 def is_anomaly_enabled() -> _bool: ...
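
The return-type change matters because these bindings report an autocast nesting depth rather than returning nothing, presumably so a caller can tell when the outermost autocast region exits. A minimal sketch of the pattern the corrected stubs let mypy check (the call site itself is not part of this diff; the names follow the declarations above):

import torch._C as _C

_C.autocast_increment_nesting()               # returns the new nesting depth as an int
# ... autocast-enabled work would run here ...
if _C.autocast_decrement_nesting() == 0:      # with `-> None` this comparison was a type error
    _C.clear_autocast_cache()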

torch/cuda/amp/autocast_mode.py

Lines changed: 5 additions & 1 deletion

@@ -149,7 +149,11 @@ def _cast(value, dtype):
     elif isinstance(value, container_abcs.Mapping):
         return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
     elif isinstance(value, container_abcs.Iterable):
-        return type(value)(_cast(v, dtype) for v in value)
+        iterable = map(lambda v: _cast(v, dtype), value)
+        if isinstance(value, list) or isinstance(value, tuple):
+            return type(value)(iterable)
+        else:
+            return iterable
     else:
         return value
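
As a quick check of the new branch, here is a simplified standalone sketch (the `_cast_like` helper and the doubling lambda are illustrative stand-ins, not code from this commit): list and tuple inputs are rebuilt with their original type, while any other iterable now comes back as a lazy `map` object instead of being fed to its own constructor.

def _cast_like(value, fn):
    # Same shape as the new _cast branch: rebuild only list/tuple; otherwise
    # return the lazy map rather than calling type(value)(...) on it.
    iterable = map(fn, value)
    if isinstance(value, list) or isinstance(value, tuple):
        return type(value)(iterable)
    else:
        return iterable

print(_cast_like([1, 2, 3], lambda v: v * 2))       # [2, 4, 6]
print(_cast_like((1, 2), lambda v: v * 2))          # (2, 4)
print(list(_cast_like(range(3), lambda v: v * 2)))  # [0, 2, 4]; the old code raised TypeError here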

torch/cuda/amp/grad_scaler.py

Lines changed: 49 additions & 37 deletions
@@ -3,18 +3,19 @@
 from torch._six import container_abcs
 import warnings
 from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple


 class _MultiDeviceReplicator(object):
     """
     Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
     """
-    def __init__(self, master_tensor):
+    def __init__(self, master_tensor: torch.Tensor) -> None:
         assert master_tensor.is_cuda
         self.master = master_tensor
-        self._per_device_tensors = {}
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

-    def get(self, device):
+    def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
             retval = self.master.to(device=device, non_blocking=True, copy=True)
@@ -38,6 +39,9 @@ def _refresh_per_optimizer_state():


 class GradScaler(object):
+    _scale: Optional[torch.Tensor]
+    _grows_tracker: Optional[torch.Tensor]
+    _per_optimizer_states: Dict[int, Dict[str, Any]]
     """
     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
     conveniently.
@@ -128,10 +132,11 @@ def __init__(self,
         self._growth_tracker = None
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

-    def _check_scale_growth_tracker(self, funcname):
+    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
         assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
         assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
+        return (self._scale, self._growth_tracker)

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
@@ -156,21 +161,27 @@ def scale(self, outputs):
             assert outputs.is_cuda
             if self._scale is None:
                 self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
             return outputs * self._scale.to(device=outputs.device, non_blocking=True)

         # Invoke the more complex machinery only if we're treating multiple outputs.
-        stash = [None]  # trick to hold a reference that can be overwritten at any level of the recursion below.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

         def apply_scale(val):
             if isinstance(val, torch.Tensor):
                 assert val.is_cuda
-                if self._scale is None:
-                    self._lazy_init_scale_growth_tracker(val.device)
-                if stash[0] is None:
-                    stash[0] = _MultiDeviceReplicator(self._scale)
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
                 return val * stash[0].get(val.device)
             elif isinstance(val, container_abcs.Iterable):
-                return type(val)(apply_scale(v) for v in val)
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
             else:
                 raise ValueError("outputs must be a Tensor or an iterable of Tensors")
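
The `stash` change in the hunk above is also a typing fix: `stash = [None]` makes mypy infer `List[None]`, so both the later element assignment and the `.get()` call on `stash[0]` are rejected, whereas an annotated empty list plus `append` type-checks cleanly. A minimal sketch of that pattern with placeholder names:

from typing import List

class Replicator:
    def get(self, device: str) -> str:
        return device

stash: List[Replicator] = []     # annotated empty list instead of the [None] trick
if len(stash) == 0:
    stash.append(Replicator())   # element type matches the annotation
print(stash[0].get("cuda:0"))    # mypy knows stash[0] is a Replicator, not None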

@@ -182,25 +193,25 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):

         for group in optimizer.param_groups:
             for param in group["params"]:
-                if param.grad is not None:
-                    if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError("Attempting to unscale FP16 gradients.")
+                if param.grad is None:
+                    continue
+                if (not allow_fp16) and param.grad.dtype == torch.float16:
+                    raise ValueError("Attempting to unscale FP16 gradients.")
+                with torch.no_grad():
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
                     else:
-                        with torch.no_grad():
-                            if param.grad.is_sparse:
-                                # is_coalesced() == False means the sparse grad has values with duplicate indices.
-                                # coalesce() deduplicates indices and adds all values that have the same index.
-                                # For scaled fp16 values, there's a good chance coalescing will cause overflow,
-                                # so we should check the coalesced _values().
-                                if param.grad.dtype is torch.float16:
-                                    param.grad = param.grad.coalesce()
-                                to_unscale = param.grad._values()
-                            else:
-                                to_unscale = param.grad
-
-                            torch._amp_non_finite_check_and_unscale_(to_unscale,
-                                                                     per_device_found_inf.get(param.grad.device),
-                                                                     per_device_inv_scale.get(param.grad.device))
+                        to_unscale = param.grad
+
+                    torch._amp_non_finite_check_and_unscale_(to_unscale,
+                                                             per_device_found_inf.get(param.grad.device),
+                                                             per_device_inv_scale.get(param.grad.device))

         return per_device_found_inf._per_device_tensors

@@ -249,6 +260,7 @@ def unscale_(self, optimizer):
             raise RuntimeError("unscale_() is being called after step().")

         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
         found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

@@ -332,22 +344,22 @@ def update(self, new_scale=None):
         if not self._enabled:
             return

-        self._check_scale_growth_tracker("update")
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
+                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=_scale.device)
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
-                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
+                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
                 self._scale = new_scale
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
+            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
                           for state in self._per_optimizer_states.values()
                           for found_inf in state["found_inf_per_device"].values()]

@@ -358,8 +370,8 @@
             for i in range(1, len(found_infs)):
                 found_inf_combined += found_infs[i]

-            self._scale = torch._amp_update_scale(self._growth_tracker,
-                                                  self._scale,
+            self._scale = torch._amp_update_scale(_growth_tracker,
+                                                  _scale,
                                                   found_inf_combined,
                                                   self._growth_factor,
                                                   self._backoff_factor,
@@ -498,10 +510,10 @@ def __setstate__(self, state):
         self.__dict__.update(state)

     def _check_inf_per_device(self, optimizer):
-        self._check_scale_growth_tracker("_check_inf_per_device")
+        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
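
Most of the remaining grad_scaler.py changes follow one pattern: mypy does not carry the narrowing of an `Optional` attribute across method calls, so `_check_scale_growth_tracker` now returns the already-checked tensors and call sites use those locals (or place an explicit `assert ... is not None` immediately before use). A minimal sketch of that pattern with hypothetical names:

from typing import Optional, Tuple

class Example:
    _scale: Optional[float]
    _growth_tracker: Optional[float]

    def __init__(self) -> None:
        self._scale = 1.0
        self._growth_tracker = 0.0

    def _check(self) -> Tuple[float, float]:
        # The asserts narrow the Optionals locally; returning the values hands
        # the caller plain (non-Optional) objects that mypy can track.
        assert self._scale is not None
        assert self._growth_tracker is not None
        return (self._scale, self._growth_tracker)

    def update(self) -> float:
        scale, growth_tracker = self._check()
        # Using the returned locals instead of self._scale avoids
        # "unsupported operand type ... None"-style mypy errors.
        return scale * 2.0 + growth_tracker

print(Example().update())  # 2.0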
