fix RowwiseMoments vectorization issue on CPU (pytorch#84404)
Originally, `cpu/moments_utils.h` placed its code in the `at::native::utils` namespace. The file contains `Vectorized<>` code, which is compiled once per CPU architecture (scalar, AVX2, AVX-512, ...); for each build to keep its own copy of the symbols, the code must live in an anonymous or inline namespace. Otherwise the linker may resolve every call to the scalar version of the code.

This PR fixes the vectorization issue in `RowwiseMoments`, which is used to compute `mean` and `rstd` in the norm layers. Benchmark data is attached below; fp32 generally gets a 2-3x speedup, and bf16 an even larger one.

This patch improves float32 layer_norm inference (input size 32x128x1024):
* avx512 single socket: 2.1x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.439 ms; bf16: 2.479 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.210 ms; bf16: 0.770 ms
```
* avx512 single core: 3.2x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 6.308 ms; bf16: 39.765 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.661 ms; bf16: 12.267 ms
```
* avx2 single socket: 2.3x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 1.248 ms; bf16: 8.487 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.540 ms; bf16: 2.030 ms
```
* avx2 single core: 2.5x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 10.792 ms; bf16: 66.366 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.349 ms; bf16: 19.252 ms
```

Some VTune profiling results are attached to further illustrate the issue:

1. Original bottlenecks
![master_bottleneck](https://user-images.githubusercontent.com/20233731/180125611-deed41b7-dd2e-4437-a7d9-6ad0096e5850.png)

We can see that `RowwiseMomentsImpl<>` takes the majority of the runtime here.

2. Instruction-level breakdown of `RowwiseMomentsImpl<>`
![rowwise_momentum_impl](https://user-images.githubusercontent.com/20233731/180125759-a3b48bc4-8e54-4219-92b4-defde5e86046.png)

We can see that it executes only **scalar** instructions.

3. Bottlenecks after the fix
![fixed_bottleneck](https://user-images.githubusercontent.com/20233731/180125880-8d08eb1b-af09-4f80-ae58-80215365d407.png)

The bottleneck profile is getting better.

4. Instruction-level breakdown of `RowwiseMomentsImpl<>` after the fix
![fixed_rowwsie_momentum_impl](https://user-images.githubusercontent.com/20233731/180125989-b45db4ad-e6ed-460a-8d51-74fbeecf8b02.png)

Now it executes **vectorized** instructions throughout.

Pull Request resolved: pytorch#84404
Approved by: https://github.com/jgong5
mingfeima authored and pytorchmergebot committed Nov 30, 2022
1 parent 92f08f0 commit 87d18cf
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/group_norm_kernel.cpp
```diff
@@ -57,7 +57,7 @@ void GroupNormKernelImplInternal(
       const T* X_ptr = X_data + i * inner_size;
       T mean_val;
       T rstd_val;
-      std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, inner_size);
+      std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, inner_size);
       rstd_val = T(1) / std::sqrt(std::max(rstd_val, T(0)) + eps);
       if (gamma_null && beta_null) {
         T* Y_ptr = Y_data + i * inner_size;
```
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/layer_norm_kernel.cpp
```diff
@@ -55,7 +55,7 @@ void LayerNormKernelImplInternal(
       T* Y_ptr = Y_data + i * N;
       T mean_val;
       T rstd_val;
-      std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, N);
+      std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
       rstd_val = T(1) / std::sqrt(rstd_val + eps);
       const T_ACC scale = rstd_val;
       const T_ACC bias = -rstd_val * mean_val;
```
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cpu/moments_utils.h
```diff
@@ -14,7 +14,7 @@

 namespace at {
 namespace native {
-namespace utils {
+inline namespace CPU_CAPABILITY {

 constexpr int64_t kChunkSize = 16;

@@ -63,7 +63,7 @@ std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
   constexpr int64_t kVecSize = Vec::size();
   const int64_t n = N / kVecSize;
   const int64_t m = divup(n, kChunkSize);
-  const int64_t depth = CeilLog2(m);
+  const int64_t depth = utils::CeilLog2(m);

   const Vec kZeroVec(T(0));
   c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
@@ -136,7 +136,7 @@ std::pair<T, T> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
   constexpr int64_t kVecSize = Vec::size();
   const int64_t n = N / kVecSize;
   const int64_t m = divup(n, kChunkSize);
-  const int64_t depth = CeilLog2(m);
+  const int64_t depth = utils::CeilLog2(m);
   if (depth <= 4) {
     return RowwiseMomentsImpl<T, 4>(X, N, ddof);
   } else if (depth <= 8) {
@@ -150,6 +150,6 @@ std::pair<T, T> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
   }
 }

-} // namespace utils
+} // namespace CPU_CAPABILITY
 } // namespace native
 } // namespace at
```
