
Commit 7ebdf48

add cast and op_functor for cuda build-in types (#5546)
1 parent 4bb5d89 commit 7ebdf48

File tree

6 files changed: +174 -147 lines changed
extensions/csrc/cuda/funcs/cast_functor.h

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <functional>
+
+#include "../utils/micros.h"
+
+// Note(LiuYang): This file provides base math operation for data type
+// include POD and cuda built-in type such as half and __nv_bfloat16
+
+namespace colossalAI {
+namespace cuda {
+namespace funcs {
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = at::Half;
+};
+
+template <>
+struct TypeConverter<at::Half> {
+  using Type = half2;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = at::BFloat16;
+};
+
+template <>
+struct TypeConverter<at::BFloat16> {
+  using Type = __nv_bfloat162;
+};
+
+template <typename From, typename To>
+struct CastFunctor : public std::unary_function<From, To> {
+  HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
+};
+
+#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT,              \
+                                             FUNCTION_MODIFIER)           \
+  template <>                                                             \
+  struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> {   \
+    FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; }            \
+  };
+
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
+                                     __float2bfloat162_rn(val), DEVICE)
+
+#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
+}  // namespace funcs
+}  // namespace cuda
+}  // namespace colossalAI

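For orientation, here is a minimal usage sketch (not part of the commit) showing how the new CastFunctor and TypeConverter templates are meant to be used inside a device kernel; the kernel name, buffer names, and launch shape below are illustrative assumptions.

// Hypothetical example: cast a packed half2 buffer to float2 through the
// DEVICE specialization CastFunctor<half2, float2>, which expands to
// __half22float2(val) per the specialization list above.
#include <cuda_fp16.h>

#include "funcs/cast_functor.h"

using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::TypeConverter;

__global__ void half2_to_float2_kernel(const half2* in, float2* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = CastFunctor<half2, float2>()(in[i]);
  }
}

// TypeConverter maps a scalar type to its packed pair type and back, e.g.
// TypeConverter<at::Half>::Type is half2 and TypeConverter<half2>::Type is at::Half.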
extensions/csrc/cuda/funcs/op_functor.h

Lines changed: 73 additions & 13 deletions
@@ -1,31 +1,91 @@
 #pragma once

 #include <cuda.h>
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>

 #include <functional>

+#include "../utils/micros.h"
+
 namespace colossalAI {
 namespace cuda {
 namespace funcs {

-enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
+enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

-template <typename T, BinaryOpType Op>
+// Note(LiuYang): This file provides base math operation for data type
+// include POD and cuda built-in type such as half and __nv_bfloat16
+template <typename LT, typename RT, typename RET, BinaryOpType Op>
 struct BinaryOpFunctor;

-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kAdd>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
-};
-
-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kMax>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
-};
+#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT,    \
+                                               FUNCTION_MODIFIER, ARGS...) \
+  template <ARGS>                                                          \
+  struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE>                          \
+      : public std::binary_function<T, T, T> {                             \
+    FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; }          \
+  };
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
+                                       HOSTDEVICE, typename T)
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
+                                       __hadd(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
+                                       __hadd2(lhs, rhs), DEVICE)
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
+                                       __hadd(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
+                                       __hadd2(lhs, rhs), DEVICE)
+#else
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
+                                       __float2bfloat16(__bfloat162float(lhs) +
+                                                        __bfloat162float(rhs)),
+                                       DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat162, BinaryOpType::kAdd,
+    __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
+                          __high2float(lhs) + __high2float(rhs)),
+    DEVICE)
+#endif
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
+                                       __hmul(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
+                                       __hmul2(lhs, rhs), DEVICE)
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
+                                       __hmul(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
+                                       __hmul2(lhs, rhs), DEVICE)
+#else
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
+                                       __float2bfloat16(__bfloat162float(lhs) *
+                                                        __bfloat162float(rhs)),
+                                       DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat162, BinaryOpType::kMul,
+    __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
+                          __high2float(lhs) * __high2float(rhs)),
+    DEVICE)
+#endif
+
+#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION

 } // namespace funcs
 } // namespace cuda

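A similar sketch (again not part of the commit) for the reworked BinaryOpFunctor, which now takes left, right, and return types plus the op tag; the kernel and names below are illustrative assumptions.

// Hypothetical example: element-wise addition of two half2 buffers through
// BinaryOpFunctor<half2, half2, half2, BinaryOpType::kAdd>, which the macro
// above specializes to __hadd2(lhs, rhs) on the device.
#include <cuda_fp16.h>

#include "funcs/op_functor.h"

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

__global__ void add_half2_kernel(const half2* a, const half2* b, half2* out, int n) {
  BinaryOpFunctor<half2, half2, half2, BinaryOpType::kAdd> add_op;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = add_op(a[i], b[i]);
  }
}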
extensions/csrc/cuda/include/block_reduce.h

Lines changed: 2 additions & 2 deletions
@@ -22,12 +22,12 @@ struct GetOpForReduceType;
 
 template <typename T>
 struct GetOpForReduceType<T, ReduceType::kMax> {
-  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
+  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
 };

 template <typename T>
 struct GetOpForReduceType<T, ReduceType::kSum> {
-  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
+  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
 };

 #define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \

extensions/csrc/cuda/rms_layernorm_kernel.cu

Lines changed: 21 additions & 10 deletions
@@ -10,10 +10,15 @@

 #include "block_reduce.h"
 #include "../common/micros.h"
-#include "utils/cuda_type_utils.h"
+#include "funcs/cast_functor.h"
+#include "funcs/op_functor.h"

 using colossalAI::cuda::utils::block_reduce;
 using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::TypeConverter;
+using colossalAI::cuda::funcs::CastFunctor;
+using colossalAI::cuda::funcs::BinaryOpFunctor;
+using colossalAI::cuda::funcs::BinaryOpType;

 #define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
   if (DATA_SIZE == 2) { \
@@ -53,6 +58,7 @@ __global__ void rms_layernorm_kernel(
     const int num_tokens,
     const int hidden_size) {
   using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
   __shared__ float s_variance;

   /*
@@ -72,12 +78,13 @@ __global__ void rms_layernorm_kernel(
   float variance = 0.0f;
   int row_offset = blockIdx.x * hidden_size / 2;

+
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
     x_local[cnt] = input_ptr[id];
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
     variance += v1 * v1 + v2 * v2;
   }
   block_reduce<float, ReduceType::kSum,1>(&variance);
@@ -86,11 +93,11 @@ __global__ void rms_layernorm_kernel(
   }
   __syncthreads();

-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float,scalar2_t>()(s_variance);
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
-    out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
   }
 }

@@ -137,6 +144,9 @@ __global__ void fused_add_rms_layernorm_kernel(
     const int num_tokens,
     const int hidden_size) {
   using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
+  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
+
   __shared__ float s_variance;
   scalar2_t x_local[4];

@@ -151,9 +161,9 @@ __global__ void fused_add_rms_layernorm_kernel(
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
     x_local[cnt] = input_ptr[id];
-    x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
+    float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
     variance += v1 * v1 + v2 * v2;
     residual_ptr[id] = x_local[cnt];
   }
@@ -163,11 +173,12 @@ __global__ void fused_add_rms_layernorm_kernel(
   }
   __syncthreads();

-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
+
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
-    input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
   }
 }

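One detail worth noting about the kernel changes above: the removed utils/cuda_type_utils.h helper mul took three operands in this kernel, while the new BinaryOpFunctor is strictly binary, so the scale-and-weight product is now written as two nested calls. A small sketch of that pattern follows; the wrapper function itself is hypothetical and not part of the commit.

// Hypothetical helper showing the nested-multiply pattern used in the diff:
// mul_scalar2t(mul_scalar2t(x, s), w) replaces the old three-operand mul(x, s, w).
template <typename scalar2_t>
__device__ scalar2_t scale_and_weight(scalar2_t x, scalar2_t s, scalar2_t w) {
  colossalAI::cuda::funcs::BinaryOpFunctor<
      scalar2_t, scalar2_t, scalar2_t,
      colossalAI::cuda::funcs::BinaryOpType::kMul>
      mul_op;
  return mul_op(mul_op(x, s), w);
}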