Skip to content

Commit f033ada

Browse files
committed
add cast and op_functor for cuda build-in types
1 parent a2878e3 commit f033ada

File tree

5 files changed

+148
-25
lines changed

5 files changed

+148
-25
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#pragma once
2+
3+
#include <cuda.h>
4+
#include <cuda_fp16.h>
5+
#include <cuda_runtime.h>
6+
7+
#include <functional>
8+
9+
#include "../utils/micros.h"
10+
11+
// Note(LiuYang): This file provides base math operation for data type
12+
// include POD and cuda built-in type such as half and __nv_bfloat16
13+
14+
namespace colossalAI {
15+
namespace cuda {
16+
namespace funcs {
17+
18+
// Get type2 from type or vice versa (applied to half and bfloat16)
19+
template <typename T>
20+
struct TypeConverter {
21+
using Type = half2;
22+
}; // keep for generality
23+
24+
template <>
25+
struct TypeConverter<half2> {
26+
using Type = at::Half;
27+
};
28+
29+
template <>
30+
struct TypeConverter<at::Half> {
31+
using Type = half2;
32+
};
33+
34+
template <>
35+
struct TypeConverter<__nv_bfloat162> {
36+
using Type = at::BFloat16;
37+
};
38+
39+
template <>
40+
struct TypeConverter<at::BFloat16> {
41+
using Type = __nv_bfloat162;
42+
};
43+
44+
template <typename From, typename To>
45+
struct CastFunctor : public std::unary_function<From, To> {
46+
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
47+
};
48+
49+
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
50+
FUNCTION_MODIFIER) \
51+
template <> \
52+
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
53+
FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
54+
};
55+
56+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
57+
DEVICE)
58+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
59+
DEVICE)
60+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
61+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
62+
DEVICE)
63+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
64+
DEVICE)
65+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
66+
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
67+
68+
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
69+
} // namespace funcs
70+
} // namespace cuda
71+
} // namespace colossalAI

extensions/csrc/cuda/funcs/op_functor.h

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,65 @@
11
#pragma once
22

33
#include <cuda.h>
4+
#include <cuda_bf16.h>
45
#include <cuda_fp16.h>
56
#include <cuda_runtime.h>
67

78
#include <functional>
89

10+
#include "../utils/micros.h"
11+
912
namespace colossalAI {
1013
namespace cuda {
1114
namespace funcs {
1215

13-
enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
16+
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
1417

15-
template <typename T, BinaryOpType Op>
18+
// Note(LiuYang): This file provides base math operation for data type
19+
// include POD and cuda built-in type such as half and __nv_bfloat16
20+
template <typename LT, typename RT, typename RET, BinaryOpType Op>
1621
struct BinaryOpFunctor;
1722

18-
template <typename T>
19-
struct BinaryOpFunctor<T, BinaryOpType::kAdd>
20-
: public std::binary_function<T, T, T> {
21-
__host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
22-
};
23-
24-
template <typename T>
25-
struct BinaryOpFunctor<T, BinaryOpType::kMax>
26-
: public std::binary_function<T, T, T> {
27-
__host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
28-
};
23+
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
24+
FUNCTION_MODIFIER, ARGS...) \
25+
template <ARGS> \
26+
struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
27+
: public std::binary_function<T, T, T> { \
28+
FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
29+
};
30+
31+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
32+
HOSTDEVICE, typename T)
33+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
34+
HOSTDEVICE, typename T)
35+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs,
36+
HOSTDEVICE, typename T)
37+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
38+
HOSTDEVICE, typename T)
39+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
40+
HOSTDEVICE, typename T)
41+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
42+
HOSTDEVICE, typename T)
43+
44+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
45+
__hadd(lhs, rhs), DEVICE)
46+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
47+
__hadd2(lhs, rhs), DEVICE)
48+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
49+
__hadd(lhs, rhs), DEVICE)
50+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
51+
__hadd2(lhs, rhs), DEVICE)
52+
53+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
54+
__hmul(lhs, rhs), DEVICE)
55+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
56+
__hmul2(lhs, rhs), DEVICE)
57+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
58+
__hmul(lhs, rhs), DEVICE)
59+
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
60+
__hmul2(lhs, rhs), DEVICE)
61+
62+
#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
2963

3064
} // namespace funcs
3165
} // namespace cuda

extensions/csrc/cuda/include/block_reduce.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ struct GetOpForReduceType;
2222

2323
template <typename T>
2424
struct GetOpForReduceType<T, ReduceType::kMax> {
25-
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
25+
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
2626
};
2727

2828
template <typename T>
2929
struct GetOpForReduceType<T, ReduceType::kSum> {
30-
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
30+
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
3131
};
3232

3333
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \

extensions/csrc/cuda/rms_layernorm_kernel.cu

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010

1111
#include "block_reduce.h"
1212
#include "../common/micros.h"
13-
#include "utils/cuda_type_utils.h"
13+
#include "funcs/cast_functor.h"
14+
#include "funcs/op_functor.h"
1415

1516
using colossalAI::cuda::utils::block_reduce;
1617
using colossalAI::cuda::utils::ReduceType;
18+
using colossalAI::cuda::funcs::TypeConverter;
19+
using colossalAI::cuda::funcs::CastFunctor;
20+
using colossalAI::cuda::funcs::BinaryOpFunctor;
21+
using colossalAI::cuda::funcs::BinaryOpType;
1722

1823
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
1924
if (DATA_SIZE == 2) { \
@@ -53,6 +58,9 @@ __global__ void rms_layernorm_kernel(
5358
const int num_tokens,
5459
const int hidden_size) {
5560
using scalar2_t = typename TypeConverter<scalar_t>::Type;
61+
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t();
62+
CastFunctor<scalar2_t, float> cast_scalar2t_2_float();
63+
CastFunctor<float, scalar2_t> cast_float_2_scalar2t();
5664
__shared__ float s_variance;
5765

5866
/*
@@ -72,12 +80,13 @@ __global__ void rms_layernorm_kernel(
7280
float variance = 0.0f;
7381
int row_offset = blockIdx.x * hidden_size / 2;
7482

83+
7584
#pragma unroll unroll_factor
7685
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
7786
int id = row_offset + idx;
7887
x_local[cnt] = input_ptr[id];
79-
float v1 = cuda_cast<float>(x_local[cnt].x);
80-
float v2 = cuda_cast<float>(x_local[cnt].y);
88+
float v1 = cast_scalar2t_2_float(x_local[cnt].x);
89+
float v2 = cast_scalar2t_2_float(x_local[cnt].y);
8190
variance += v1 * v1 + v2 * v2;
8291
}
8392
block_reduce<float, ReduceType::kSum,1>(&variance);
@@ -86,11 +95,11 @@ __global__ void rms_layernorm_kernel(
8695
}
8796
__syncthreads();
8897

89-
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
98+
scalar2_t s_variance_2 = cast_float_2_scalar2t(s_variance);
9099
#pragma unroll unroll_factor
91100
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
92101
int id = row_offset + idx;
93-
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
102+
out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
94103
}
95104
}
96105

@@ -137,6 +146,11 @@ __global__ void fused_add_rms_layernorm_kernel(
137146
const int num_tokens,
138147
const int hidden_size) {
139148
using scalar2_t = typename TypeConverter<scalar_t>::Type;
149+
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t();
150+
CastFunctor<scalar2_t, float> cast_scalar2t_2_float();
151+
CastFunctor<float, scalar2_t> cast_float_2_scalar2t();
152+
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t();
153+
140154
__shared__ float s_variance;
141155
scalar2_t x_local[4];
142156

@@ -151,9 +165,9 @@ __global__ void fused_add_rms_layernorm_kernel(
151165
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
152166
int id = row_offset + idx;
153167
x_local[cnt] = input_ptr[id];
154-
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
155-
float v1 = cuda_cast<float>(x_local[cnt].x);
156-
float v2 = cuda_cast<float>(x_local[cnt].y);
168+
x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
169+
float v1 = cast_scalar2t_2_float(x_local[cnt].x);
170+
float v2 = cast_scalar2t_2_float(x_local[cnt].y);
157171
variance += v1 * v1 + v2 * v2;
158172
residual_ptr[id] = x_local[cnt];
159173
}
@@ -163,11 +177,11 @@ __global__ void fused_add_rms_layernorm_kernel(
163177
}
164178
__syncthreads();
165179

166-
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
180+
scalar2_t s_variance_2 = cast_float_2_scalar2t(s_variance);
167181
#pragma unroll unroll_factor
168182
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
169183
int id = row_offset + idx;
170-
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
184+
input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
171185
}
172186
}
173187

extensions/csrc/cuda/utils/micros.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@
1212
throw std::runtime_error(cudaGetErrorString(status)); \
1313
} \
1414
}
15+
16+
#define HOST __host__
17+
#define DEVICE __device__
18+
#define HOSTDEVICE __host__ __device__

0 commit comments

Comments
 (0)