Enable torch.cuda.amp typechecking (#45480)
Summary:
Fix `torch._C._autocast_*_nesting` declarations in `__init__.pyi`
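
For context, these bindings return the current nesting depth (an `_int`) rather than `None`; `torch.cuda.amp.autocast` uses the value returned by `autocast_decrement_nesting()` to decide when to clear the autocast cache. A simplified sketch of that usage (not the verbatim implementation):

```python
import torch

class _autocast_sketch:
    # Simplified sketch of the autocast context manager's enter/exit logic,
    # showing why the nesting functions are declared to return an int.
    def __enter__(self):
        self._prev = torch.is_autocast_enabled()
        torch.set_autocast_enabled(True)
        torch.autocast_increment_nesting()

    def __exit__(self, *args):
        # The integer return value is what makes this check possible:
        # the cache is cleared only when the outermost region exits.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self._prev)
        return False
```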

Fix the iterable constructor logic: not every iterable can be reconstructed with the `type(val)(val)` trick; for example, it fails for `val = range(10)` even though `isinstance(val, Iterable)` is True
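
A minimal, self-contained sketch of the failure and the workaround (the `cast_elems` helper below is hypothetical, standing in for `_cast`):

```python
from collections.abc import Iterable

def cast_elems(value, fn):
    # list/tuple constructors accept an iterable, so rebuilding the same
    # container type is safe here.
    if isinstance(value, (list, tuple)):
        return type(value)(fn(v) for v in value)
    # range(10) is an Iterable, but type(range(10))(...) raises TypeError
    # because range() does not accept an arbitrary iterable, so fall back
    # to a lazy map for other iterables.
    if isinstance(value, Iterable):
        return map(fn, value)
    return value

print(cast_elems((1, 2, 3), float))       # (1.0, 2.0, 3.0)
print(list(cast_elems(range(3), float)))  # [0.0, 1.0, 2.0]
```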
Change the `Optional` resolution logic to meet mypy's expectations
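
Concretely, attributes typed `Optional[torch.Tensor]` are asserted non-`None` (or re-bound to locals) before use, so mypy can narrow them. Roughly the pattern used in `grad_scaler.py`, shown on an illustrative stand-in class:

```python
from typing import Optional, Tuple
import torch

class _ScalerSketch:
    # Trimmed-down stand-in for GradScaler, just to show the narrowing pattern.
    _scale: Optional[torch.Tensor]
    _growth_tracker: Optional[torch.Tensor]

    def _check_scale_growth_tracker(self, funcname: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # "assert ... is not None" narrows Optional[Tensor] to Tensor for mypy;
        # returning the narrowed values lets callers work with plain Tensors
        # instead of re-reading the Optional attributes.
        assert self._scale is not None, f"Attempted {funcname} but _scale is None."
        assert self._growth_tracker is not None, f"Attempted {funcname} but _growth_tracker is None."
        return self._scale, self._growth_tracker

    def update_sketch(self) -> None:
        scale, growth_tracker = self._check_scale_growth_tracker("update")
        # scale and growth_tracker are now torch.Tensor, not Optional[torch.Tensor].
        _ = scale.device
```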

Fixes #45436

Pull Request resolved: #45480

Reviewed By: walterddr

Differential Revision: D23982822

Pulled By: malfet

fbshipit-source-id: 6418a28d04ece1b2427dcde4b71effb67856a872
malfet authored and facebook-github-bot committed Sep 29, 2020
1 parent df0de78 commit b3135c2
Showing 4 changed files with 56 additions and 43 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
@@ -174,9 +174,6 @@ ignore_errors = True
[mypy-torch.cuda]
ignore_errors = True

[mypy-torch.cuda.amp.*]
ignore_errors = True

[mypy-torch._lobpcg]
ignore_errors = True

4 changes: 2 additions & 2 deletions torch/_C/__init__.pyi.in
@@ -389,8 +389,8 @@ def is_grad_enabled() -> _bool: ...
def set_autocast_enabled(enabled: _bool) -> None: ...
def is_autocast_enabled() -> _bool: ...
def clear_autocast_cache() -> None: ...
def autocast_increment_nesting() -> None: ...
def autocast_decrement_nesting() -> None: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
def is_anomaly_enabled() -> _bool: ...

6 changes: 5 additions & 1 deletion torch/cuda/amp/autocast_mode.py
@@ -149,7 +149,11 @@ def _cast(value, dtype):
elif isinstance(value, container_abcs.Mapping):
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(_cast(v, dtype) for v in value)
iterable = map(lambda v: _cast(v, dtype), value)
if isinstance(value, list) or isinstance(value, tuple):
return type(value)(iterable)
else:
return iterable
else:
return value

86 changes: 49 additions & 37 deletions torch/cuda/amp/grad_scaler.py
@@ -3,18 +3,19 @@
from torch._six import container_abcs
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple


class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor):
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors = {}
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

def get(self, device):
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
@@ -38,6 +39,9 @@ def _refresh_per_optimizer_state():


class GradScaler(object):
_scale: Optional[torch.Tensor]
_growth_tracker: Optional[torch.Tensor]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
@@ -128,10 +132,11 @@ def __init__(self,
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

def _check_scale_growth_tracker(self, funcname):
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)

def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
@@ -156,21 +161,27 @@ def scale(self, outputs):
assert outputs.is_cuda
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)

# Invoke the more complex machinery only if we're treating multiple outputs.
stash = [None] # trick to hold a reference that can be overwritten at any level of the recursion below.
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale

def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
if stash[0] is None:
stash[0] = _MultiDeviceReplicator(self._scale)
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, container_abcs.Iterable):
return type(val)(apply_scale(v) for v in val)
iterable = map(apply_scale, val)
if isinstance(val, list) or isinstance(val, tuple):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")

@@ -182,25 +193,25 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):

for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is not None:
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
with torch.no_grad():
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
with torch.no_grad():
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad

torch._amp_non_finite_check_and_unscale_(to_unscale,
per_device_found_inf.get(param.grad.device),
per_device_inv_scale.get(param.grad.device))
to_unscale = param.grad

torch._amp_non_finite_check_and_unscale_(to_unscale,
per_device_found_inf.get(param.grad.device),
per_device_inv_scale.get(param.grad.device))

return per_device_found_inf._per_device_tensors

@@ -249,6 +260,7 @@ def unscale_(self, optimizer):
raise RuntimeError("unscale_() is being called after step().")

# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

@@ -332,22 +344,22 @@ def update(self, new_scale=None):
if not self._enabled:
return

self._check_scale_growth_tracker("update")
_scale, _growth_tracker = self._check_scale_growth_tracker("update")

if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=_scale.device)
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale = new_scale
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()]

@@ -358,8 +370,8 @@ def update(self, new_scale=None):
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]

self._scale = torch._amp_update_scale(self._growth_tracker,
self._scale,
self._scale = torch._amp_update_scale(_growth_tracker,
_scale,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
@@ -498,10 +510,10 @@ def __setstate__(self, state):
self.__dict__.update(state)

def _check_inf_per_device(self, optimizer):
self._check_scale_growth_tracker("_check_inf_per_device")
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
