[ROCm] Optimize layer norm backward kernel for ROCm (pytorch#87635)
We observed that the native PyTorch `LayerNormBackwardKernelImplInternal` has suboptimal performance on AMD GPUs for certain input sizes, especially when `fs` (= `config_m` in our benchmark script) is large and `bs` (= `config_n` in our benchmark script) is small, a combination commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808). The benchmark script is based on the one in [PR pytorch#68238](pytorch#68238 (comment)).
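
As a quick sanity check, the problematic regime can be reproduced in isolation. This is a minimal sketch, not part of the PR: it assumes a GPU build of PyTorch and picks one CvT-like shape; all numbers in the tables below come from the full benchmark script at the end of this description, not from this sketch.

```python
import torch
from torch.nn import LayerNorm

M, N = 802816, 64  # large M, small N (CvT-like)
ln = LayerNorm(N).half().cuda()
x = torch.randn(M, N, device='cuda', dtype=torch.half, requires_grad=True)
gO = torch.rand_like(x)

# warm-up so kernel launches and allocator are steady-state
for _ in range(10):
    x.grad = None
    ln.zero_grad(set_to_none=True)
    ln(x).backward(gO)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    x.grad = None
    ln.zero_grad(set_to_none=True)
    ln(x).backward(gO)
end.record()
torch.cuda.synchronize()
print(f"fwd+bwd: {start.elapsed_time(end) / 100:.3f} ms per iteration")
```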

This PR replaces `GammaBetaBackwardCUDAKernel` with the Apex layer norm backward kernel, plus some ROCm-specific parameter tuning, when `fs` (= `config_m`) is larger than 512 on AMD GPUs.
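
For reference, the reductions this kernel performs are the column-wise weight gradients, where $\mu_i$ and $\mathrm{rstd}_i$ are the per-row mean and inverse standard deviation saved by the forward pass:

```math
\frac{\partial L}{\partial \gamma_j} = \sum_{i=1}^{M} dY_{i,j}\,(X_{i,j}-\mu_i)\,\mathrm{rstd}_i,
\qquad
\frac{\partial L}{\partial \beta_j} = \sum_{i=1}^{M} dY_{i,j}
```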

There have been a few prior PRs for the LayerNorm kernels:
- pytorch#26201
- pytorch#27634
- pytorch#68238

We therefore benchmarked the kernel before and at this PR on an AMD MI100, using the input shapes from the last two PRs along with shapes commonly used in the CvT model.

---
**Current** (times in ms)

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271


---------
**At this PR** (times in ms)


M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514


---------

**Performance Improvement (%)**

Improvement is computed from the fwdbwd columns above as (t_current − t_PR) / t_current × 100; positive values mean this PR is faster.

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491


---------
**Benchmark script of this PR**
```python
# Ref:
#       1. pytorch#26201
#       2. pytorch#68238

import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:9d} | {:9d} | {:9.5f} | {:9.5f} | {:9.5f} | {:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# pytorch#68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# pytorch#27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
hubertlu-tw authored and pytorchmergebot committed Nov 22, 2022
1 parent 00b7d8e, commit 74e62a1

Changed file: aten/src/ATen/native/cuda/layer_norm_kernel.cu (233 additions, 1 deletion)
@@ -25,6 +25,8 @@
#endif

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/env.h>


namespace at {
namespace native {
@@ -832,6 +834,201 @@ void LayerNormKernelImpl(
});
}

template<typename T, typename T_ACC> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
T_ACC* warp_buf1,
T_ACC* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int64_t N,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd)
{
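// First pass over a tile of rows: write dY into warp_buf1 and
// dY * (x - mean) * rstd into warp_buf2, zero-filling out-of-range slots so
// that later tiles can accumulate on top.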
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
T curr_mean = mean[i1];
T curr_rstd = rstd[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*N+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<N) {
T curr_input = static_cast<T>(input[load_idx]);
T curr_dout = static_cast<T>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd;
} else {
warp_buf1[write_idx] = T(0);
warp_buf2[write_idx] = T(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = T(0);
warp_buf2[write_idx] = T(0);
}
}
}

template<typename T, typename T_ACC> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
T_ACC* warp_buf1,
T_ACC* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int64_t N,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd)
{
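// Later passes: same addressing as cuLoadWriteStridedInputs, but accumulate
// into the shared-memory buffers instead of overwriting (no zero-fill needed).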
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
T_ACC curr_mean = mean[i1];
T_ACC curr_rstd = rstd[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*N+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<N) {
T_ACC curr_input = static_cast<T_ACC>(input[load_idx]);
T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd;
}
}
}
}

template<typename T, typename T_ACC> __global__
void cuComputePartGradGammaBeta(
const T* __restrict__ dout,
const T* __restrict__ input,
const int64_t M,
const int64_t N,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T_ACC* part_grad_gamma,
T_ACC* part_grad_beta)
{
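// Each block reduces a blockDim.x-wide slab of columns over one segment of
// rows and writes a single row of part_grad_gamma/part_grad_beta
// (shape [gridDim.y, N]); cuComputeGradGammaBeta folds the partial rows.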
const int numsegs_M = (M+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
alignas(sizeof(double)) extern __shared__ char shared[];
T_ACC * buf = reinterpret_cast<T_ACC*>(&shared); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
T_ACC* warp_buf1 = (T_ACC*)buf;
T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
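// warp_buf1 accumulates dY (feeding grad_beta); warp_buf2 accumulates
// dY * (x - mean) * rstd (feeding grad_gamma).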
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
T_ACC acc1 = T_ACC(0);
T_ACC acc2 = T_ACC(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
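// The loop stops at offset == 2, so rows 0 and 1 still hold separate partial
// sums; they are combined in the write-out below.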
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < N) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*N+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*N+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}

template<typename T, typename T_ACC> __global__
void cuComputeGradGammaBeta(
const T_ACC* part_grad_gamma,
const T_ACC* part_grad_beta,
const int part_size,
const int64_t M,
const int64_t N,
T* grad_gamma,
T* grad_beta)
{
// sum partial gradients for gamma and beta
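// Each of the blockDim.y warps serially sums its share of the part_size
// partial rows for column i2; the per-warp sums are then tree-reduced
// through shared memory.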
alignas(sizeof(double)) extern __shared__ char shared[];
T_ACC * buf = reinterpret_cast<T_ACC*>(&shared);
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < N) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
T_ACC sum_gamma = T_ACC(0);
T_ACC sum_beta = T_ACC(0);
const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2;
const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*N];
sum_beta += part_grad_beta_ptr[warp_offset*N];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}

template <typename T>
void LayerNormBackwardKernelImplInternal(
const Tensor& dY,
@@ -860,8 +1057,8 @@ void LayerNormBackwardKernelImplInternal(
gamma.defined() ? gamma.template data_ptr<T>() : nullptr;
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
+  const int warp_size = at::cuda::warp_size();
if (dX_data != nullptr) {
-    const int warp_size = at::cuda::warp_size();
const dim3 blocks(M);
int nshared = (num_threads()/warp_size) * sizeof(T_ACC);
layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
@@ -889,6 +1086,40 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
#if defined(USE_ROCM)
// For small batch size, do colwise reduce directly.
const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1);
const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
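// nshared2_a sizes the two warp buffers: 2 x (blockDim.y * blockDim.y) rows of
// row_stride = blockDim.x + 1 elements each (the +1 padding avoids
// shared-memory bank conflicts).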

const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true);
Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype));
Tensor part_grad_beta = at::native::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, cuda_stream>>>(
dY_data,
X_data,
M,N,
mean_data,
rstd_data,
part_grad_gamma.template data_ptr<T_ACC>(),
part_grad_beta.template data_ptr<T_ACC>());
C10_CUDA_KERNEL_LAUNCH_CHECK();

const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm
const dim3 blocks3((N + threads2.x - 1) / threads2.x, 1, 1);
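// threads2.x == threads3.x (both warp_size), so this grid covers all N columns.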
const int nshared3 = threads3.x * threads3.y * sizeof(T);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, cuda_stream>>>(
part_grad_gamma.template data_ptr<T_ACC>(),
part_grad_beta.template data_ptr<T_ACC>(),
part_size,
M,N,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
// This implementation relies on warp primitives and requires that M and N divide
// exactly to warp size.
@@ -925,6 +1156,7 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}
}
}
