74 changes: 74 additions & 0 deletions extensions/csrc/cuda/funcs/cast_functor.h
@@ -0,0 +1,74 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

#include "../utils/micros.h"

// Note(LiuYang): This file provides cast functors between data types,
// including PODs and CUDA built-in types such as half and __nv_bfloat16

namespace colossalAI {
namespace cuda {
namespace funcs {

// Get type2 from type or vice versa (applies to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality

template <>
struct TypeConverter<half2> {
using Type = at::Half;
};

template <>
struct TypeConverter<at::Half> {
using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
using Type = at::BFloat16;
};

template <>
struct TypeConverter<at::BFloat16> {
using Type = __nv_bfloat162;
};

template <typename From, typename To>
struct CastFunctor : public std::unary_function<From, To> {
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};

#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
FUNCTION_MODIFIER) \
template <> \
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
};

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162,
__float2bfloat162_rn(val), DEVICE)

#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
} // namespace funcs
} // namespace cuda
} // namespace colossalAI
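For reviewers, a quick sketch of how the two templates above compose. This is a hypothetical illustration, not code from this PR; the function name pack_broadcast and its arguments are made up, and it assumes a .cu translation unit that already includes this header, the ATen scalar types, and using-declarations for colossalAI::cuda::funcs:

// Hypothetical usage sketch (not part of this PR).
// TypeConverter pairs a scalar type with its packed CUDA type
// (at::Half <-> half2, at::BFloat16 <-> __nv_bfloat162), and
// CastFunctor<float, half2> broadcasts one float into both lanes
// via __float2half2_rn.
template <typename scalar_t>
__device__ void pack_broadcast(const float* in, scalar_t* out, int idx) {
  using scalar2_t = typename TypeConverter<scalar_t>::Type;
  scalar2_t v = CastFunctor<float, scalar2_t>()(in[idx]);
  reinterpret_cast<scalar2_t*>(out)[idx] = v;  // writes elements 2*idx, 2*idx+1
}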
86 changes: 73 additions & 13 deletions extensions/csrc/cuda/funcs/op_functor.h
@@ -1,31 +1,91 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

#include "../utils/micros.h"

namespace colossalAI {
namespace cuda {
namespace funcs {

-enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
+enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

-template <typename T, BinaryOpType Op>
+// Note(LiuYang): This file provides base math operations for data types,
+// including PODs and CUDA built-in types such as half and __nv_bfloat16
+template <typename LT, typename RT, typename RET, BinaryOpType Op>
struct BinaryOpFunctor;

-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kAdd>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
-};
-
-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kMax>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
-};
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
FUNCTION_MODIFIER, ARGS...) \
template <ARGS> \
struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
: public std::binary_function<T, T, T> { \
FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
};

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs * rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
HOSTDEVICE, typename T)

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
__float2bfloat16(__bfloat162float(lhs) +
__bfloat162float(rhs)),
DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, BinaryOpType::kAdd,
__floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
__high2float(lhs) + __high2float(rhs)),
DEVICE)
#endif

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
__float2bfloat16(__bfloat162float(lhs) *
__bfloat162float(rhs)),
DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, BinaryOpType::kMul,
__floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
__high2float(lhs) * __high2float(rhs)),
DEVICE)
#endif

#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION

} // namespace funcs
} // namespace cuda
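The widened template (LT, RT, RET instead of a single T) leaves room for mixed-type specializations later, although every specialization in this PR still uses one type. Below is a hedged sketch of how a kernel might invoke these functors; fma_pair is a hypothetical name, not part of the PR, and it assumes using-declarations for colossalAI::cuda::funcs:

// Hypothetical usage sketch (not part of this PR).
// The POD specializations are HOSTDEVICE; the half/bfloat16 ones are
// DEVICE-only and map to intrinsics such as __hmul2/__hadd2.
__device__ half2 fma_pair(half2 a, half2 x, half2 y) {
  BinaryOpFunctor<half2, half2, half2, BinaryOpType::kMul> mul;
  BinaryOpFunctor<half2, half2, half2, BinaryOpType::kAdd> add;
  return add(mul(a, x), y);  // two fp16 lanes per call
}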
4 changes: 2 additions & 2 deletions extensions/csrc/cuda/include/block_reduce.h
@@ -22,12 +22,12 @@ struct GetOpForReduceType;

template <typename T>
struct GetOpForReduceType<T, ReduceType::kMax> {
-  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
+  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
};

template <typename T>
struct GetOpForReduceType<T, ReduceType::kSum> {
-  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
+  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
};

#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
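The body of COLOSSAL_SHFL_FUNCTION is collapsed in this diff. As a generic, hedged illustration of the pattern the Op alias feeds into (not the PR's exact macro), a warp-level reduction applies the selected functor between each lane and its shuffle partner; warp_reduce is a hypothetical name, and the sketch assumes GetOpForReduceType and ReduceType are in scope:

// Generic warp-reduce sketch (illustration only; the actual
// COLOSSAL_SHFL_FUNCTION body is collapsed above).
template <typename T, ReduceType rtype>
__device__ T warp_reduce(T val) {
  using Op = typename GetOpForReduceType<T, rtype>::Op;
  Op op;
  // Butterfly reduction across the 32 lanes of a warp.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = op(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}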
31 changes: 21 additions & 10 deletions extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -10,10 +10,15 @@

#include "block_reduce.h"
#include "../common/micros.h"
#include "utils/cuda_type_utils.h"
#include "funcs/cast_functor.h"
#include "funcs/op_functor.h"

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;
using colossalAI::cuda::funcs::TypeConverter;
using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \
@@ -53,6 +58,7 @@ __global__ void rms_layernorm_kernel(
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
__shared__ float s_variance;

/*
@@ -72,12 +78,13 @@
float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;


#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    float v1 = CastFunctor<scalar_t, float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t, float>()(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
}
block_reduce<float, ReduceType::kSum,1>(&variance);
@@ -86,11 +93,11 @@
}
__syncthreads();

-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-    out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}

@@ -137,6 +144,9 @@ __global__ void fused_add_rms_layernorm_kernel(
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;

__shared__ float s_variance;
scalar2_t x_local[4];

@@ -151,9 +161,9 @@
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
-    x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
+    float v1 = CastFunctor<scalar_t, float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t, float>()(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
residual_ptr[id] = x_local[cnt];
}
@@ -163,11 +173,12 @@
}
__syncthreads();

-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);

#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-    input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}

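For reference, the math both kernels vectorize over half2/__nv_bfloat162 pairs. The epsilon/rsqrt line is collapsed in this diff, so the sketch below states the standard RMSNorm definition rather than quoting the kernel; rms_norm_reference and its parameters are illustrative names:

// Scalar reference (illustration only): per token,
//   out[i] = x[i] * rsqrt(mean(x^2) + eps) * weight[i]
// fused_add_rms_layernorm first does x[i] += residual[i] and writes the
// sum back to residual before normalizing.
#include <cmath>

void rms_norm_reference(const float* x, const float* weight, float* out,
                        int hidden_size, float eps) {
  float sq_sum = 0.0f;
  for (int i = 0; i < hidden_size; ++i) sq_sum += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(sq_sum / hidden_size + eps);
  for (int i = 0; i < hidden_size; ++i) out[i] = x[i] * inv_rms * weight[i];
}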