Skip to content

Commit

Permalink
[ROCm] Replace layer_norm_grad_input_kernel with cuComputeGradInput f…
Browse files Browse the repository at this point in the history
…or ROCm (pytorch#87726)

We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs, especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)), as measured with the benchmark script from pytorch#68238 (comment).

This PR replaces layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel, with some ROCm-specific parameter tuning, when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: pytorch#87635

We used the same benchmark script as in the previous PR and tested the optimized kernel with various input shapes on an AMD MI100 GPU.

**At [the previous PR](pytorch#87635):**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423

</body>

</html>

----

**At this PR:**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{background:yellow;
	mso-pattern:black none;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386

</body>

</html>

---

**Performance Improvement (%)**

<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl65
	{color:windowtext;}
.xl66
	{mso-number-format:"0\.000";}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127

</body>

</html>

CC: @jeffdaily

Pull Request resolved: pytorch#87726
Approved by: https://github.com/ngimel
  • Loading branch information
hubertlu-tw authored and pytorchmergebot committed Nov 28, 2022
1 parent 098cbe2 commit cf4969d
Showing 1 changed file with 132 additions and 0 deletions.
132 changes: 132 additions & 0 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,110 @@ void cuComputeGradGammaBeta(
}
}

/**
 * Layer-norm backward kernel: gradient with respect to the input.
 *
 * `dout`, `input` and `grad_input` are indexed as `ptr[row * N + col]`, and
 * `mean` / `rstd` are indexed by row. Each block processes rows
 * blockIdx.y, blockIdx.y + gridDim.y, ... below M. For every row it:
 *   1. accumulates, per thread, partial sums over the N columns of
 *        sum_loss1 = sum(dout * gamma)
 *        sum_loss2 = sum(dout * gamma * (input - mean) * rstd)
 *      (gamma is left out of both when `gamma` is null);
 *   2. combines the partial sums with a shuffle-based reduction along
 *      threadIdx.x followed by a shared-memory reduction along threadIdx.y,
 *      then hands the totals back to every thread of the block;
 *   3. writes
 *        grad_input = (N * dout * gamma - sum_loss1
 *                      - (input - mean) * rstd * sum_loss2) * rstd / N.
 *
 * Launch requirements visible in this code:
 *   - When blockDim.y > 1 the kernel uses dynamic shared memory holding at
 *     least 2 * (blockDim.y / 2) * blockDim.x values of T_ACC.
 *   - The halving loop over threadIdx.y only folds in every thread row when
 *     blockDim.y is a power of two.
 *   - NOTE(review): the shuffle reduction presumably requires blockDim.x to
 *     equal the width WARP_SHFL_XOR operates on (the warp / wavefront size);
 *     WARP_SHFL_XOR is defined outside this block -- confirm.
 *   - NOTE(review): column indices are kept as 32-bit `int` as in the
 *     original, which assumes N < 2^31 -- confirm with callers.
 *
 * @param dout        upstream gradient, M x N
 * @param input       forward input, M x N
 * @param M           number of rows
 * @param N           number of columns (normalized elements per row)
 * @param mean        per-row mean, M entries
 * @param rstd        per-row reciprocal standard deviation, M entries
 * @param gamma       optional per-column scale, N entries; may be null
 * @param grad_input  output gradient, M x N
 */
template<typename T, typename T_ACC> __global__
void cuComputeGradInput(
    const T* __restrict__ dout,
    const T* __restrict__ input,
    const int64_t M,
    const int64_t N,
    const T_ACC* __restrict__ mean,
    const T_ACC* __restrict__ rstd,
    const T* gamma,
    T* grad_input)
{
  // Row-invariant values: block-wide thread count, this thread's flat index
  // inside the block, and N converted to the accumulation type.
  const int numx = blockDim.x * blockDim.y;
  const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
  const T_ACC fH = static_cast<T_ACC>(N);

  // The row counter is 64-bit: M is int64_t and the stride is unsigned, so a
  // 32-bit counter would wrap to a negative row once M exceeds INT_MAX.
  for (int64_t i1 = blockIdx.y; i1 < M; i1 += gridDim.y) {
    T_ACC sum_loss1 = T_ACC(0);
    T_ACC sum_loss2 = T_ACC(0);
    const T_ACC c_mean = mean[i1];
    const T_ACC c_rstd = rstd[i1];
    const T* k_input = input + i1*N;
    const T* k_dout = dout + i1*N;

    if (gamma != nullptr) {
      // Optimization for ROCm MI100
      // Out-of-range lanes load a zero instead of skipping the iteration, so
      // every thread of the block executes the same number of iterations.
      for (int l = 0; l < N; l += numx) {
        const int idx = l + thrx;
        const T_ACC gamma_idx = static_cast<T_ACC>((idx<N) ? gamma[idx] : T(0));
        const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
        const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
        sum_loss1 += c_loss * gamma_idx;
        sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
      }
    } else {
      for (int l = 0; l < N; l += numx) {
        const int idx = l + thrx;
        const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
        const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
      }
    }

    // intra-warp reductions: butterfly exchange along threadIdx.x
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }

    // inter-warp reductions: fold the threadIdx.y rows together via shared memory
    if (blockDim.y > 1) {
      alignas(sizeof(double)) extern __shared__ char shared[];
      T_ACC* buf = reinterpret_cast<T_ACC*>(shared);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of the still-active rows publishes its two sums
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half adds the published sums to its own
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        // keep the next round's writes from overtaking this round's reads
        __syncthreads();
      }
      // Row 0 now holds the block totals; broadcast them to the other rows.
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }

    // all threads now have the two sums over l
    const T_ACC term1 = (T_ACC(1) / fH) * c_rstd;
    T* k_grad_input = grad_input + i1*N;
    if (gamma != nullptr) {
      for (int l = thrx; l < N; l += numx) {
        const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
        const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
        // gamma is converted to T_ACC explicitly, as in the reduction pass
        T_ACC f_grad_input = fH * c_loss * static_cast<T_ACC>(gamma[l]);
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < N; l += numx) {
        const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
        const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
        T_ACC f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}

template <typename T>
void LayerNormBackwardKernelImplInternal(
const Tensor& dY,
Expand Down Expand Up @@ -1059,11 +1163,39 @@ void LayerNormBackwardKernelImplInternal(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
const int warp_size = at::cuda::warp_size();
if (dX_data != nullptr) {
#if defined __HIP_PLATFORM_HCC__
if (M >= 32768) {
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1);
dim3 threads1(warp_size, 4, 1);
threads1.y = 2; // Optimization for ROCm
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(T_ACC) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, cuda_stream>>>(
dY_data,
X_data,
M, N,
mean_data,
rstd_data,
gamma_data,
dX_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
const dim3 blocks(M);
int nshared = (num_threads()/warp_size) * sizeof(T_ACC);
layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
X_data, mean_data, rstd_data, gamma_data, dX_data, N);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#else
const dim3 blocks(M);
int nshared = (num_threads()/warp_size) * sizeof(T_ACC);
layer_norm_grad_input_kernel<<<blocks, num_threads(), nshared, cuda_stream>>>(dY_data,
X_data, mean_data, rstd_data, gamma_data, dX_data, N);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
}

if (dgamma->defined() || dbeta->defined()) {
Expand Down

0 comments on commit cf4969d

Please sign in to comment.