Skip to content

Commit

Permalink
Add fp16/fp32 autocasting to JIT/TorchScript (pytorch#63939)
Browse files Browse the repository at this point in the history
Summary:
Adds mixed precision autocasting support between fp32/fp16 to torchscript/JIT. More in depth descriptoin can be found at [torch/csrc/jit/JIT-AUTOCAST.md](https://github.com/pytorch/pytorch/pull/63939/files#diff-1f1772aaa508841c5bb58b74ab98f49a1e577612cd9ea5c386c8714a75db830b)

This PR implemented an autocast optimization pass that inserts casting ops per AMP rule (torch/csrc/jit/passes/autocast.cpp), that mimics the behavior of eager autocast. The pass also takes into consideration the context of `torch.cuda.amp.autocast` and only inserts casting ops within the enabled context manager, giving feature parity as with eager amp autocast.

We currently provide JIT AMP autocast as a prototyping feature, so it is default off and could be turned on via `torch._C._jit_set_autocast_mode(True)`

The JIT support for autocast is subject to different constraints compared to the eager mode implementation (mostly related to the fact that TorchScript is statically typed), restriction on the user facing python code is described in doc torch/csrc/jit/JIT-AUTOCAST.md

This is a prototype, there are also implementation limitation that's necessary to keep this PR small and get something functioning quickly on upstream, so we can iterate on designs.

Few limitation/challenge that is not properly resolved in this PR:
1. Autocast inserts cast operation, which would have impact on scalar type of output tensor feeding downstream operations. We are not currently propagating the updated scalar types, this would give issues/wrong results on operations in promotion rules.

2. Backward for autodiff in JIT misses the casting of dgrad to input scalar type, as what autograd does in eager. This forces us to explicitly mark the casting operation for certain operations (e.g. binary ops), otherwise, we might be feeding dgrad with mismatch scalar type to input. This could potentially break gradient function consuming dgrad. (e.g. gemm backwards, which assumes grad_output to be of same scalar type as input')

3. `torch.autocast` api has an optional argument `dtype` which is not currently supported in the JIT autocast and we require a static value.

Credit goes mostly to:
tlemo
kevinstephano

Pull Request resolved: pytorch#63939

Reviewed By: navahgar

Differential Revision: D31093381

Pulled By: eellison

fbshipit-source-id: da6e26c668c38b01e296f304507048d6c1794314
  • Loading branch information
jjsjann123 authored and facebook-github-bot committed Oct 27, 2021
1 parent 0101b1e commit 1ec732b
Show file tree
Hide file tree
Showing 19 changed files with 1,474 additions and 39 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ _(aten, atan2) \
_(aten, atleast_1d) \
_(aten, atleast_2d) \
_(aten, atleast_3d) \
_(aten, _autocast_to_reduced_precision) \
_(aten, _autocast_to_full_precision) \
_(aten, avg_pool1d) \
_(aten, avg_pool2d) \
_(aten, avg_pool2d_backward) \
Expand Down
37 changes: 37 additions & 0 deletions aten/src/ATen/native/TensorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,43 @@ static inline Tensor to_impl(
self, dtype, layout, device, pin_memory, non_blocking, optional_memory_format);
}

// If input tensor is fp32, cast it to fp16, otherwise leave it alone.
// (this is intended to be used internally by the JIT autocast implementation)
Tensor _autocast_to_reduced_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) {
if (self.dtype() == at::ScalarType::Float &&
((self.device().is_cuda() && cuda_enabled) ||
(self.device().is_cpu() && cpu_enabled))
) {
at::ScalarType target = at::ScalarType::Undefined;
if (self.device().is_cuda()) {
target = cuda_dtype;
} else if (self.device().is_cpu()) {
target = cpu_dtype;
}

TORCH_INTERNAL_ASSERT(target != at::ScalarType::Undefined, "_autocast_to_reduced_precision requires legit ScalarType argument for given device");

return to_impl(
self, target, c10::nullopt, c10::nullopt, c10::nullopt, false, false, c10::nullopt);
} else {
return self;
}
}

// If input tensor is fp16, cast it to fp32, otherwise leave it alone.
// (this is intended to be used internally by the JIT autocast implementation)
Tensor _autocast_to_full_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled) {
if (self.dtype() == at::ScalarType::Half &&
((self.device().is_cuda() && cuda_enabled) ||
(self.device().is_cpu() && cpu_enabled))
) {
return to_impl(
self, at::ScalarType::Float, c10::nullopt, c10::nullopt, c10::nullopt, false, false, c10::nullopt);
} else {
return self;
}
}

Tensor to(
const Tensor& self,
c10::optional<ScalarType> dtype,
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5425,6 +5425,14 @@
- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)
variants: function

- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
variants: method
device_guard: False

- func: _autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)
variants: method
device_guard: False

- func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
device_check: NoCheck
device_guard: False
Expand Down
Loading

0 comments on commit 1ec732b

Please sign in to comment.