Add an option to disable reduced precision reductions for FP16 GEMM (pytorch#67946)

Summary:
pytorch#67578 disabled reduced precision reductions for FP16 GEMMs. After benchmarking, we've found that this has substantial performance impacts for common GEMM shapes (e.g., those found in popular instantiations of multi-headed attention) on architectures such as Volta. As these performance regressions may come as a surprise to current users, this PR adds a toggle for disabling reduced precision reductions,
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction`,
rather than making that the default behavior.

CC ngimel ptrblck stas00

Note that the behavior after the previous PR can be replicated with
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False`
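
For illustration, a minimal usage sketch of the new toggle (not part of this commit; assumes a CUDA-capable build):

```python
# Illustrative only -- a hedged sketch of the toggle added by this PR; assumes a CUDA device.
import torch

# The flag defaults to True (see Context.h below), i.e. reduced precision reductions stay enabled.
prev = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction

# Request full-precision (fp32) reductions for fp16 GEMMs, replicating the behavior of pytorch#67578.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

a = torch.randn(4096, 5120, device="cuda", dtype=torch.half)
b = torch.randn(5120, 4096, device="cuda", dtype=torch.half)
c = a @ b  # this GEMM now runs with reduced precision reductions disallowed in cuBLAS

# Restore the previous setting.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = prev
```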

Pull Request resolved: pytorch#67946

Reviewed By: zou3519

Differential Revision: D32289896

Pulled By: ngimel

fbshipit-source-id: a1ea2918b77e27a7d9b391e030417802a0174abe
eqy authored and facebook-github-bot committed Nov 10, 2021
1 parent 078c655 commit 790763b
Showing 10 changed files with 106 additions and 5 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -147,6 +147,14 @@ void Context::setAllowTF32CuBLAS(bool b) {
  allow_tf32_cublas = b;
}

bool Context::allowFP16ReductionCuBLAS() const {
  return allow_fp16_reduction_cublas;
}

void Context::setAllowFP16ReductionCuBLAS(bool b) {
  allow_fp16_reduction_cublas = b;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
  return true;
3 changes: 3 additions & 0 deletions aten/src/ATen/Context.h
@@ -202,6 +202,8 @@ class TORCH_API Context {
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
@@ -239,6 +241,7 @@ class TORCH_API Context {
  bool benchmark_cudnn = false;
  bool allow_tf32_cudnn = true;
  bool allow_tf32_cublas = true;
  bool allow_fp16_reduction_cublas = true;
  bool enabled_mkldnn = true;
#ifdef C10_MOBILE
  bool release_original_weights = true;
6 changes: 5 additions & 1 deletion aten/src/ATen/cuda/CUDABlas.cpp
@@ -453,8 +453,12 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
+  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+  if (!at::globalContext().allowFP16ReductionCuBLAS()) {
+    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+  }
  // Disallow fp16 reductions that could lead to unexpected overflow issues.
-  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, static_cast<cublasMath_t>(CUBLAS_DEFAULT_MATH | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
4 changes: 4 additions & 0 deletions docs/source/backends.rst
@@ -25,6 +25,10 @@ torch.backends.cuda
    A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix
    multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`.

.. attribute:: torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction

    A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.

.. attribute:: torch.backends.cuda.cufft_plan_cache

    ``cufft_plan_cache`` caches the cuFFT plans
41 changes: 41 additions & 0 deletions docs/source/notes/cuda.rst
@@ -131,6 +131,47 @@ For more information about TF32, see:
.. _CUDA 11: https://devblogs.nvidia.com/cuda-11-features-revealed/
.. _Ampere architecture: https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/

.. _fp16reducedprecision:

Reduced Precision Reduction in FP16 GEMMs
-----------------------------------------

fp16 GEMMs may be performed with intermediate reductions carried out in reduced precision (e.g., in fp16 rather than fp32). These selective reductions in precision allow higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures, at the cost of numerical precision and a potential for overflow.

Some example benchmark data on V100:

.. code::

    [--------------------------- bench_gemm_transformer --------------------------]
          [ m , k , n ]         |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
    1 threads: --------------------------------------------------------------------
          [4096,  4048, 4096]   |         1634.6          |         1639.8
          [4096,  4056, 4096]   |         1670.8          |         1661.9
          [4096,  4080, 4096]   |         1664.2          |         1658.3
          [4096,  4096, 4096]   |         1639.4          |         1651.0
          [4096,  4104, 4096]   |         1677.4          |         1674.9
          [4096,  4128, 4096]   |         1655.7          |         1646.0
          [4096,  4144, 4096]   |         1796.8          |         2519.6
          [4096,  5096, 4096]   |         2094.6          |         3190.0
          [4096,  5104, 4096]   |         2144.0          |         2663.5
          [4096,  5112, 4096]   |         2149.1          |         2766.9
          [4096,  5120, 4096]   |         2142.8          |         2631.0
          [4096,  9728, 4096]   |         3875.1          |         5779.8
          [4096, 16384, 4096]   |         6182.9          |         9656.5

    (times in microseconds)

If full precision reductions are needed, users can disable reduced precision reductions in fp16 GEMMs with:

.. code:: python

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

To toggle the reduced precision reduction flag in C++, you can use:

.. code:: C++

    at::globalContext().setAllowFP16ReductionCuBLAS(false);

Asynchronous execution
----------------------

8 changes: 8 additions & 0 deletions docs/source/notes/numerical_accuracy.rst
@@ -64,3 +64,11 @@ with fp32, however, if better accuracy is desired, TF32 can be turned off with
``torch.backends.cuda.matmul.allow_tf32 = False``

For more information see :ref:`TensorFloat32<tf32_on_ampere>`

Reduced Precision Reduction for FP16 GEMMs
------------------------------------------

Half-precision GEMM operations are typically done with intermediate accumulations (reductions) in single precision for numerical accuracy and improved resilience to overflow. For performance, certain GPU architectures, especially more recent ones, allow a few truncations of the intermediate accumulation results to the reduced precision (e.g., half precision). This change is often benign from the perspective of model convergence, though it may lead to unexpected results (e.g., ``inf`` values when the final result should be representable in half precision).

If reduced-precision reductions are problematic, they can be turned off with
``torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False``

For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>`
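
As an illustrative sketch (not part of this diff), one way to observe the effect described above is to compare both reduction modes against an fp32 reference; the magnitude of any difference depends on GPU architecture, cuBLAS version, and matrix shapes:

```python
# Illustrative comparison sketch; requires a CUDA device. Results vary by GPU and cuBLAS version.
import torch

torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = False  # keep the reference GEMM in true fp32 on Ampere+
a = torch.randn(4096, 16384, device="cuda", dtype=torch.half)
b = torch.randn(16384, 4096, device="cuda", dtype=torch.half)

# fp32 reference GEMM (no fp16 reduction path involved).
ref = a.float() @ b.float()

outputs = {}
for allow in (True, False):
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow
    outputs[allow] = (a @ b).float()

for allow, out in outputs.items():
    max_err = (out - ref).abs().max().item()
    print(f"allow_fp16_reduced_precision_reduction={allow}: max abs error vs fp32 = {max_err:.3f}")

# Restore the default (True).
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
```
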
7 changes: 7 additions & 0 deletions test/test_cuda.py
@@ -584,6 +584,13 @@ def test_cublas_allow_tf32_get_set(self):
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
        torch.backends.cuda.matmul.allow_tf32 = orig

    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        self.assertEqual(torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig)
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
        self.assertEqual(torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig)
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    def test_cudnn_allow_tf32_get_set(self):
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            self.assertFalse(torch.backends.cudnn.allow_tf32)
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -597,6 +597,8 @@ def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS
def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
14 changes: 10 additions & 4 deletions torch/backends/cuda/__init__.py
@@ -85,12 +85,18 @@ def __setattr__(self, name, value):

class cuBLASModule:
    def __getattr__(self, name):
-        assert name == "allow_tf32", "Unknown attribute " + name
-        return torch._C._get_cublas_allow_tf32()
+        if name == "allow_tf32":
+            return torch._C._get_cublas_allow_tf32()
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+        raise AssertionError("Unknown attribute " + name)

    def __setattr__(self, name, value):
-        assert name == "allow_tf32", "Unknown attribute " + name
-        return torch._C._set_cublas_allow_tf32(value)
+        if name == "allow_tf32":
+            return torch._C._set_cublas_allow_tf32(value)
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+        raise AssertionError("Unknown attribute " + name)


cufft_plan_cache = cuFFTPlanCacheManager()
18 changes: 18 additions & 0 deletions torch/csrc/Module.cpp
@@ -522,6 +522,22 @@ PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs)
  Py_RETURN_FALSE;
}

PyObject *THPModule_setAllowFP16ReductionCuBLAS(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_allow_fp16_reduction_cublas expects a bool, "
                  "but got %s", THPUtils_typename(arg));
  at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
}

PyObject *THPModule_allowFP16ReductionCuBLAS(PyObject *_unused, PyObject *noargs)
{
  if (at::globalContext().allowFP16ReductionCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) {
  THPUtils_assert(PyBool_Check(arg), "flush_denormal expects a bool, "
                  "but got %s", THPUtils_typename(arg));
@@ -676,6 +692,8 @@ static PyMethodDef TorchMethods[] = {
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
{"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
{"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
{"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr},
{"_set_cublas_allow_fp16_reduced_precision_reduction", THPModule_setAllowFP16ReductionCuBLAS, METH_O, nullptr},
{"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr},
{"_vmapmode_decrement_nesting", THPModule_vmapmode_decrement_nesting, METH_NOARGS, nullptr},
{"_debug_only_display_vmap_fallback_warnings", THPModule_set_display_vmap_fallback_warnings_mode, METH_O, nullptr},
