Skip to content

Commit

Permalink
[CUDA] baddmm should fall back to addmm for batch=1 (#114992) (#116518)
Browse files Browse the repository at this point in the history
I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1
After the change, benchmarking torch built with CUDA-12 using  [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on A100  are as follows:
|      Shape     |  bmm_time |  mm_time  | slow down (%) |
| -------------- | --------- | --------- | ------------- |
|    1x1x4096    |   14.18   |   14.31   |     -0.89     |
|    1x1x8192    |   14.37   |   14.37   |     -0.05     |
|   1x1x16384    |   14.03   |   14.12   |     -0.68     |
|   1x1x32768    |   14.19   |   14.24   |     -0.35     |
|   1x1x65536    |   14.85   |   14.52   |     2.30      |
|   1x1x131072   |   14.03   |   14.07   |     -0.33     |
|  128x128x128   |   11.34   |   11.06   |     2.56      |
|  256x256x256   |   14.85   |   14.40   |     3.15      |
|  512x512x512   |   27.22   |   27.22   |     -0.01     |
| 1024x1024x1024 |  129.66   |  129.50   |     0.12      |
| 2048x2048x2048 |  972.18   |  973.24   |     -0.11     |
|  129x127x129   |   11.21   |   11.25   |     -0.39     |
|  257x255x257   |   14.50   |   14.43   |     0.44      |
|  513x511x513   |   29.01   |   29.01   |     0.01      |
| 1025x1023x1025 |  137.65   |  137.64   |     0.01      |
| 2049x2047x2049 |  982.58   |  982.65   |     -0.01     |
|  4097x3x4097   |   86.65   |   86.64   |     0.01      |
|  8193x3x8193   |  384.02   |  383.96   |     0.02      |
| 16385x3x16385  |  1106.73  |  1107.32  |     -0.05     |
| 32769x3x32769  |  4739.49  |  4739.48  |     0.00      |
| 65537x3x65537  | 17377.78  | 17378.74  |     -0.01     |
|  4097x5x4097   |   87.09   |   87.12   |     -0.03     |
|  8193x5x8193   |  301.38   |  301.36   |     0.01      |
| 16385x5x16385  |  1107.38  |  1108.04  |     -0.06     |
| 32769x5x32769  |  4743.73  |  4744.07  |     -0.01     |
| 65537x5x65537  | 17392.32  | 17395.42  |     -0.02     |
|  4097x7x4097   |   87.17   |   87.19   |     -0.02     |
|  8193x7x8193   |  301.94   |  302.00   |     -0.02     |
| 16385x7x16385  |  1107.17  |  1106.79  |     0.03      |
| 32769x7x32769  |  4747.15  |  4747.13  |     0.00      |
| 65537x7x65537  | 17403.85  | 17405.02  |     -0.01     |

Fixes perf problem reported in #114911
Pull Request resolved: #114992
Approved by: https://github.com/Skylion007, https://github.com/eqy

Co-authored-by: Nikita Shulga <nshulga@meta.com>
  • Loading branch information
atalman and malfet authored Jan 2, 2024
1 parent ab7505f commit 1a3e3c7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
40 changes: 26 additions & 14 deletions aten/src/ATen/native/cuda/Blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
Expand Down Expand Up @@ -369,12 +370,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}

const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
IntArrayRef batch1_sizes = batch1.sizes();

// handle pathological cases that blas may not like
if (result.numel() == 0) {
return result;
} else if (batch1_sizes[2] == 0) {
} else if (batch1.size(2) == 0) {
if (beta.to<c10::complex<double>>() == 0.0) {
return result.zero_();
} else {
Expand Down Expand Up @@ -421,17 +420,30 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
at::cuda::blas::bgemm<scalar_t>(
transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
m, n, k,
alpha_val,
batch1_ptr, lda, batch1_->strides()[0],
batch2_ptr, ldb, batch2_->strides()[0],
beta_val,
result_ptr, ldc, result_->strides()[0],
num_batches
);
const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
// If batch is 1 call gemm rather than bgemm
if (num_batches == 1) {
at::cuda::blas::gemm<scalar_t>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda,
batch2_ptr, ldb,
beta_val,
result_ptr, ldc);
} else {
at::cuda::blas::bgemm<scalar_t>(
transa, transb,
m, n, k,
alpha_val,
batch1_ptr, lda, batch1_->strides()[0],
batch2_ptr, ldb, batch2_->strides()[0],
beta_val,
result_ptr, ldc, result_->strides()[0],
num_batches
);
}
});
if (!result.is_same(*result_)) {
result.copy_(*result_);
Expand Down
6 changes: 2 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
skipCPUIfNoMklSparse,
toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
_get_torch_cuda_version, _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -15937,9 +15937,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
op=lambda tensors, equation: torch.einsum(equation, tensors),
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
if (SM60OrLater or
TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
Expand Down

0 comments on commit 1a3e3c7

Please sign in to comment.