119 changes: 118 additions & 1 deletion paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -16,6 +16,8 @@ limitations under the License. */
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.

#pragma once

#include <cuda_fp16.h>
#include <float.h>

@@ -91,6 +93,23 @@ using float16 = plat::float16;

#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
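// The two switches below enable fp32 accumulation for the fp16 Q*K^T dot
// product and the tensor-core (HMMA) reduction path for that dot product.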
#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_HMMA_FOR_REDUCTION

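// Maps a Paddle scalar type to its device-native counterpart
// (float16 -> half) so the kernels below can use CUDA half intrinsics.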
template <typename D>
class PDDataTypeTraits;

template <>
class PDDataTypeTraits<float> {
public:
typedef float DataType;
};

template <>
class PDDataTypeTraits<float16> {
public:
typedef half DataType;
};

template <typename T>
struct Masked_multihead_attention_params {
@@ -153,6 +172,17 @@ template <> struct V_vec_<float16, 2> { using Type = uint32_t; };
template <> struct V_vec_<float16, 4> { using Type = uint2; };
template <> struct V_vec_<float16, 8> { using Type = uint4; };

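// When fp32 accumulation is enabled for the Q*K^T FMA, a half2 key fragment
// (packed in a uint32_t) accumulates into a float2.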
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct K_vec_acum_fp32_ {
};

template<>
struct K_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
#endif

#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T> struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
@@ -321,6 +351,15 @@ inline __device__ uint32_t mul(uint32_t a, float b) {
return res;
}

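// half2 (packed in a uint32_t) scaled by a float, widened to float2.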
template <>
inline __device__ float2 mul(uint32_t a, float b) {
float2 tmp = half2_to_float2(a);
float2 res;
res.x = tmp.x * b;
res.y = tmp.y * b;
return res;
}

template <>
inline __device__ uint2 mul(uint2 a, float b) {
uint2 res;
@@ -347,6 +386,15 @@ inline __device__ float2 mul(float2 a, float b) {
return res;
}

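// float2 multiplied element-wise by a half2 packed in a uint32_t.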
template <>
inline __device__ float2 mul(float2 a, uint32_t b) {
float2 tmp_b = half2_to_float2(b);
float2 res;
res.x = a.x * tmp_b.x;
res.y = a.y * tmp_b.y;
return res;
}

template <>
inline __device__ float4 mul(float4 a, float b) {
float4 res;
@@ -406,6 +454,12 @@ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
return d;
}

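// fma whose second operand is a packed half2; it is widened to float2 first.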
inline __device__ float2 fma(float2 a, uint32_t b, float2 c) {
float2 tmp_b = half2_to_float2(b);
float2 d = fma(a, tmp_b, c);
return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
@@ -527,6 +581,49 @@ inline __device__ float qk_dot_(const K_vec (&q)[N],
return qk;
}

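// Single m16n8k8 HMMA (fp16 inputs, fp32 accumulation). Called below with b
// holding two fp16 ones, so it effectively sums the packed half2 partial
// products contributed by the threads that cooperate on one key.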
inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) {
float4 c;
float zero = 0.f;
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"

: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}

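// Q*K^T dot product for fp16 keys: scale q by inv_sqrt_dh in the chosen
// accumulator type, multiply-accumulate against k, then reduce with the HMMA
// above (0x3c003c00u packs two fp16 ones). Requires sm_75+, otherwise 0.f.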
template <int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
K_vec_acum inv_q = mul<K_vec_acum, uint32_t, float>(q[0], inv_sqrt_dh);
K_vec_acum qk_vec = mul<K_vec_acum, K_vec_acum, uint32_t>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
inv_q = mul<K_vec_acum, uint32_t, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}

template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
@@ -537,6 +634,20 @@ struct Qk_dot {
}
};

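// For fp16 with 4 threads per key, use the tensor-core reduction on sm_75+
// when MMHA_USE_HMMA_FOR_REDUCTION is defined; otherwise fall back to the
// generic qk_dot_ path.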
template <>
struct Qk_dot<float16, 4> {
template <int N>
static inline __device__ float dot(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
return qk_hmma_dot_(q, k, inv_sqrt_dh);
#else
return qk_dot_<4>(q, k, inv_sqrt_dh);
#endif
}
};

template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
@@ -609,6 +720,8 @@ template <typename T,
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
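// Resolve T (e.g. plat::float16) to its device-native counterpart (half).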
typedef PDDataTypeTraits<T> traits_;
typedef typename traits_::DataType DataType_;

static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
@@ -866,7 +979,7 @@ __global__ void masked_multihead_attention_kernel(
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
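// Cast the stored logit to the device-native type so that, when T is
// float16, the half fma overloads are selected.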
T logit = logits_smem[ti];
DataType_ logit = static_cast<DataType_>(logits_smem[ti]);
// Update the partial sums.
out = fma(logit, v, out);
#endif
@@ -990,7 +1103,11 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
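// When the HMMA reduction path is compiled in (sm_75+), launch with 4
// threads per key and 256 threads per block for this timestep range;
// otherwise keep 2 threads per key and 128 threads per block.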
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream);
#else
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
#endif
} else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
}