port fix from nvidia megatron
sdtblck committed Jan 29, 2022
1 parent 98683ae commit b4aa1c7
Showing 9 changed files with 468 additions and 95 deletions.
34 changes: 27 additions & 7 deletions megatron/fused_kernels/scaled_masked_softmax.cpp
@@ -32,46 +32,66 @@ torch::Tensor bwd_cuda(
torch::Tensor const& softmax_results,
float scale_factor);

int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);

torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {

AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");

return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");

m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");

m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
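
For reference, here is a minimal sketch of calling the newly exported binding from Python. The module name scaled_masked_softmax_cuda follows the upstream Megatron fused-kernel build and is an assumption here, not something this diff establishes.

```python
# Hypothetical usage sketch; assumes the extension builds and imports as
# "scaled_masked_softmax_cuda", as in upstream Megatron's megatron/fused_kernels.
import scaled_masked_softmax_cuda

# Number of softmax rows each CUDA block will process for these shapes.
batch_per_block = scaled_masked_softmax_cuda.get_batch_per_block(
    1024,  # query_seq_len
    1024,  # key_seq_len
    4,     # batches (micro-batch size)
    16,    # attn_heads
)
print(batch_per_block)  # 4 whenever key_seq_len > 128 (128 threads/block, warp size 32)
```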
27 changes: 19 additions & 8 deletions megatron/fused_kernels/scaled_masked_softmax.h
@@ -1,4 +1,3 @@

/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
@@ -17,6 +16,7 @@

#pragma once

#include <stdio.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
@@ -112,7 +112,7 @@ __global__ void scaled_masked_softmax_warp_forward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
- constexpr int ELEMENTS_PER_LDG_STG = 4;
+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -231,7 +231,7 @@ __global__ void scaled_masked_softmax_warp_backward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
- constexpr int ELEMENTS_PER_LDG_STG = 4;
+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -311,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward(
}
}
}

} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;

int batch_count = batches * attn_heads * query_seq_len;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;

return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
@@ -326,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches)
{
- TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
@@ -345,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward(

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
- TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -415,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads)
{
- TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
@@ -490,4 +501,4 @@ void dispatch_scaled_masked_softmax_backward(
break;
}
}
}
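
The new get_batch_per_block helper simply reproduces the launch arithmetic used by the dispatchers above, so callers can check shape constraints before invoking the fused path. A small Python sketch of the same computation (assuming a warp size of 32; not part of the commit) can be used to sanity-check shapes without building the extension:

```python
# Standalone re-implementation of the batches-per-block arithmetic, for sanity checks only.
# Assumes C10_WARP_SIZE == 32 and the kernel's fixed 128 threads per block.
def batches_per_block(key_seq_len: int, threads_per_block: int = 128) -> int:
    next_power_of_two = 1 << (key_seq_len - 1).bit_length()  # 2 ** log2_ceil(key_seq_len)
    warp_size = min(next_power_of_two, 32)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp

print(batches_per_block(2048))  # 4  -> the masked kernel needs query_seq_len % 4 == 0
print(batches_per_block(128))   # 8
```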
19 changes: 12 additions & 7 deletions megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}


torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
@@ -46,9 +51,9 @@ torch::Tensor fwd_cuda(
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

// Softmax Intermediate Result Ptr
@@ -74,10 +79,10 @@ torch::Tensor fwd_cuda(
}

torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {

auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();

@@ -94,8 +99,8 @@ torch::Tensor bwd_cuda(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
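
For context, the fwd_cuda/bwd_cuda entry points above are normally reached through a torch.autograd.Function wrapper on the Python side. The sketch below shows that pattern; the extension and class names are assumptions (loosely mirroring upstream Megatron's fused_softmax.py), not part of this diff.

```python
# Hedged sketch of a Python autograd wrapper around the forward/backward bindings.
import torch
import scaled_masked_softmax_cuda  # assumed extension name


class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        # inputs: [batches, attn_heads, query_seq_len, key_seq_len], fp16 or bf16.
        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale)
        ctx.save_for_backward(softmax_results)
        ctx.scale = scale
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        (softmax_results,) = ctx.saved_tensors
        # The CUDA backward writes the input gradients into (a contiguous copy of)
        # the output_grads buffer and returns that tensor.
        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, ctx.scale
        )
        # No gradients flow to mask or scale.
        return input_grads, None, None
```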
50 changes: 23 additions & 27 deletions megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -33,7 +33,7 @@ __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, con

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

@@ -112,23 +112,23 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
- constexpr int ELEMENTS_PER_LDG_STG = 4;
+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
@@ -194,7 +194,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
@@ -211,7 +211,7 @@

if (element_index < local_seq) {

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
@@ -224,32 +224,32 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
- constexpr int ELEMENTS_PER_LDG_STG = 4;
+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
@@ -296,7 +296,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
}
}
}

acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
@@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride,
int attn_batches)
{
- TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
@@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
- TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
@@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride,
int attn_batches)
{
- TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
@@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
- TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
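
The dropped TORCH_INTERNAL_ASSERT checks are the counterpart of the new get_batch_per_block export: the sequence-length cap and the divisibility requirement move out of the dispatchers so the caller can test them up front and fall back to an unfused softmax instead of tripping an assert. A hedged sketch of such a caller-side check follows; the function name and structure loosely mirror upstream Megatron's fused_softmax.py and are assumptions here.

```python
# Hedged sketch of a caller-side availability check. It re-states the constraints that
# the removed asserts used to enforce inside the C++ dispatchers.
def fused_softmax_available(query_seq_len, key_seq_len, batches, attn_heads,
                            causal, get_batch_per_block):
    # The fused kernels only cover key sequence lengths up to 2048 elements.
    if not 0 < key_seq_len <= 2048:
        return False
    batch_per_block = get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads)
    if causal:
        # Upper-triangular kernel: attn_batches must tile evenly into thread blocks.
        return (batches * attn_heads) % batch_per_block == 0
    # General masked kernel: the query length must tile evenly into thread blocks.
    return query_seq_len % batch_per_block == 0
```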
(The diffs for the remaining five changed files were not loaded in this view.)