
Commit a219123

refactor csrc (#5582)
1 parent 25928d8 commit a219123

17 files changed (+1109 / -1246 lines)

extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

-#include "utils/vector_copy_utils.h"
+#include "utils/vec_copy.h"
 #include "../common/micros.h"

 template<typename scalar_t, bool Aligned, int VecSize>

extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

-#include "utils/vector_copy_utils.h"
+#include "utils/vec_copy.h"
 #include "../common/micros.h"

 template<typename scalar_t, bool Aligned, int VecSize>

extensions/csrc/cuda/funcs/op_functor.h renamed to extensions/csrc/cuda/funcs/binary_functor.h

Lines changed: 4 additions & 2 deletions
@@ -16,8 +16,10 @@ namespace funcs {
 enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

 // Note(LiuYang): This file provides base math operation for data type
-// include POD and cuda built-in type such as half and __nv_bfloat16
-template <typename LT, typename RT, typename RET, BinaryOpType Op>
+// include POD and cuda built-in type such as half and __nv_bfloat16.
+// Implementation of common and simple binary operators should be placed here,
+// otherwise, they should be placed in a new file under functors dir.
+template <typename LT, typename RT, typename RET, BinaryOpType op_type>
 struct BinaryOpFunctor;

 #define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
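
For context, the declared BinaryOpFunctor template is specialized (via the COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION macro) and then used as a plain functor object. The following is a minimal usage sketch, not code from this commit: the demo_add helper is hypothetical, and it assumes a float/kAdd specialization exists and that HOSTDEVICE permits host-side calls.

```cuda
#include "binary_functor.h"

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

// Hypothetical call site: assuming the macro above generates a
// BinaryOpFunctor<float, float, float, BinaryOpType::kAdd> specialization,
// it is instantiated and invoked like any other functor object.
float demo_add() {
  BinaryOpFunctor<float, float, float, BinaryOpType::kAdd> add;
  return add(1.0f, 2.0f);  // 3.0f
}
```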

extensions/csrc/cuda/funcs/cast_functor.h

Lines changed: 0 additions & 26 deletions
@@ -16,32 +16,6 @@ namespace colossalAI {
 namespace cuda {
 namespace funcs {

-// Get type2 from type or vice versa (applied to half and bfloat16)
-template <typename T>
-struct TypeConverter {
-  using Type = half2;
-};  // keep for generality
-
-template <>
-struct TypeConverter<half2> {
-  using Type = at::Half;
-};
-
-template <>
-struct TypeConverter<at::Half> {
-  using Type = half2;
-};
-
-template <>
-struct TypeConverter<__nv_bfloat162> {
-  using Type = at::BFloat16;
-};
-
-template <>
-struct TypeConverter<at::BFloat16> {
-  using Type = __nv_bfloat162;
-};
-
 template <typename From, typename To>
 struct CastFunctor : public std::unary_function<From, To> {
   HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
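
Only the TypeConverter helpers are removed here; the generic CastFunctor shown in the surrounding context stays and simply wraps static_cast. A minimal sketch of that fallback behaviour follows; the demo_cast helper is hypothetical, and it assumes HOSTDEVICE permits host-side calls.

```cuda
#include "cast_functor.h"

// Hypothetical host-side check, only to illustrate the generic fallback:
// CastFunctor<From, To> without a dedicated specialization is just static_cast.
int demo_cast() {
  colossalAI::cuda::funcs::CastFunctor<float, int> to_int;
  return to_int(3.7f);  // static_cast<int>(3.7f) == 3
}
```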

extensions/csrc/cuda/funcs/ (new header; file name not shown in this view)

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <functional>
+
+#include "../utils/micros.h"
+
+namespace colossalAI {
+namespace cuda {
+namespace funcs {
+
+// Note(LiuYang): As a retrieved table to check which operation is supported
+// already
+enum class UnaryOpType { kLog2Ceil = 0 };
+
+// Note(LiuYang): Implementation of common and simple unary operators should be
+// placed here, otherwise, they should be placed in a new file under functors
+// dir.
+template <typename From, typename To, UnaryOpType op_type>
+struct UnaryOpFunctor;
+
+#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(                  \
+    FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
+  template <ARGS>                                               \
+  struct UnaryOpFunctor<FROM, TO, UNARY_OP_TYPE>                \
+      : public std::unary_function<FROM, TO> {                  \
+    FUNCTION_MODIFIER TO operator()(FROM val) STMTS             \
+  };
+
+COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
+                                      HOSTDEVICE, {
+                                        int log2_value = 0;
+                                        while ((1 << log2_value) < val)
+                                          ++log2_value;
+                                        return log2_value;
+                                      })
+
+#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION
+
+}  // namespace funcs
+}  // namespace cuda
+}  // namespace colossalAI
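
The single specialization added here, kLog2Ceil, returns the smallest exponent e such that (1 << e) >= val, i.e. ceil(log2(val)) for positive val. A minimal usage sketch follows; the include name and the demo_log2_ceil helper are hypothetical, and HOSTDEVICE is assumed to allow host-side calls.

```cuda
#include "unary_functor.h"  // assumed name of the new header shown above

using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;

// Hypothetical helper: returns the number of bits needed so that
// (1 << result) >= n.
int demo_log2_ceil() {
  UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil> log2_ceil;
  return log2_ceil(1000);  // 2^9 = 512 < 1000 <= 1024 = 2^10, so returns 10
}
```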

extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

-#include "utils/vector_copy_utils.h"
+#include "utils/vec_copy.h"
 #include "../common/micros.h"
 #include "../common/mp_type_traits.h"

extensions/csrc/cuda/get_cos_and_sin_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

-#include "utils/vector_copy_utils.h"
+#include "utils/vec_copy.h"
 #include "../common/micros.h"
 #include "stdio.h"

extensions/csrc/cuda/include/block_reduce.h

Lines changed: 29 additions & 33 deletions
@@ -4,15 +4,14 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>

-#include "../funcs/op_functor.h"
+#include "../funcs/binary_functor.h"

 namespace colossalAI {
 namespace cuda {
 namespace utils {

 const float kReduceFloatInfNeg = -100000000.f;
 const float kReduceFloatInfPos = 100000000.f;
-const int kWarpSize = 32;
 const unsigned int kWarpReduceMask = 0xffffffff;

 enum class ReduceType { kMax = 0, kSum };
@@ -31,44 +30,42 @@ struct GetOpForReduceType<T, ReduceType::kSum> {
 };

 #define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
-  for (int offset = 0; offset < LANES; ++offset) {                     \
+  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {   \
     *(VAL_PTR + offset) =                                              \
         OP(*(VAL_PTR + offset),                                        \
            __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH));  \
   }

-#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES)  \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)
-
-#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
-                                   DEFAULT_VALUE, REDUCE_TYPE)      \
-  __shared__ T shm[LANES][32];                                      \
-  int lane_id = threadIdx.x & 0x1f;                                 \
-  int warp_id = threadIdx.x >> 5;                                   \
-                                                                    \
-  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);                  \
-  if (lane_id == 0) {                                               \
-    for (int offset = 0; offset < LANES; ++offset) {                \
-      shm[offset][warp_id] = *(VAL_PTR + offset);                   \
-    }                                                               \
-  }                                                                 \
-  __syncthreads();                                                  \
-                                                                    \
-  for (int offset = 0; offset < LANES; ++offset) {                  \
-    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5))         \
-                              ? shm[offset][lane_id]                \
-                              : static_cast<T>(DEFAULT_VALUE);      \
-  }                                                                 \
+#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES)            \
+  _Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) {  \
+    COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES)            \
+  }
+
+#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \
+                                   REDUCE_TYPE)                              \
+  __shared__ T shm[LANES][32];                                               \
+  int lane_id = threadIdx.x & 0x1f;                                          \
+  int warp_id = threadIdx.x >> 5;                                            \
+                                                                             \
+  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);                           \
+  if (lane_id == 0) {                                                        \
+    for (int offset = 0; offset < LANES; ++offset) {                         \
+      shm[offset][warp_id] = *(VAL_PTR + offset);                            \
+    }                                                                        \
+  }                                                                          \
+  __syncthreads();                                                           \
+                                                                             \
+  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {         \
+    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5))                  \
+                              ? shm[offset][lane_id]                         \
+                              : static_cast<T>(DEFAULT_VALUE);               \
+  }                                                                          \
   warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);

-template <typename T, ReduceType rtype, int lanes>
+template <typename T, ReduceType rtype, int lanes, int width = 32>
 __forceinline__ __device__ void warp_reduce(T* pval) {
   typename GetOpForReduceType<T, rtype>::Op op;
-  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
+  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes);
 }

 template <typename T, ReduceType rtype>
@@ -84,8 +81,7 @@ template <typename T, ReduceType rtype, int lanes>
 __forceinline__ __device__ void block_reduce(T* pval) {
   constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
   typename GetOpForReduceType<T, rtype>::Op op;
-  COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
-                             rtype);
+  COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype);
 }

 #undef COLOSSAL_SHFL_FUNCTION
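
The rewritten COLOSSAL_WARP_REDUCE_IMPL replaces the five hard-coded shuffle offsets with an unrolled loop over halving offsets, and warp_reduce gains a width parameter (default 32) so sub-warp reductions become possible. Below is a minimal sketch of how a kernel would call block_reduce to sum one row per block; it is not code from this commit, the include path and row_sum kernel are hypothetical, and it assumes blockDim.x is a multiple of the warp size (32).

```cuda
#include "block_reduce.h"  // assumed include path, relative to include/

using colossalAI::cuda::utils::ReduceType;
using colossalAI::cuda::utils::block_reduce;

// Illustrative kernel: each block sums `cols` consecutive floats of one row
// into out[blockIdx.x].
__global__ void row_sum(const float* in, float* out, int cols) {
  float val = 0.f;
  size_t row_offset = static_cast<size_t>(blockIdx.x) * cols;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    val += in[row_offset + i];
  }
  // lanes == 1: each thread holds a single partial value to be reduced.
  block_reduce<float, ReduceType::kSum, 1>(&val);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val;
  }
}
```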

extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp

Lines changed: 5 additions & 21 deletions
@@ -6,19 +6,15 @@

 #include <vector>

-namespace multihead_attn {
-namespace fused_softmax {
-namespace scaled_masked_softmax {
-
 torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                        float scale_factor);

 torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                        torch::Tensor const& softmax_results,
                        float scale_factor);

-int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
-                             int attn_heads);
+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
+                        int attn_heads);

 torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
                   float scale_factor) {
@@ -46,25 +42,13 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

-int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
-                        int attn_heads) {
-  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
-                                  attn_heads);
-}
-
-}  // end namespace scaled_masked_softmax
-}  // end namespace fused_softmax
-}  // end namespace multihead_attn
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
+  m.def("forward", &fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");

-  m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
+  m.def("backward", &bwd,
         "Self Multihead Attention scaled, time masked softmax -- Backward.");

-  m.def("get_batch_per_block",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::
-            get_batch_per_block,
+  m.def("get_batch_per_block", &get_batch_per_block,
         "Return Batch per block size.");
 }

extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp

Lines changed: 2 additions & 12 deletions
@@ -6,10 +6,6 @@

 #include <vector>

-namespace multihead_attn {
-namespace fused_softmax {
-namespace scaled_upper_triang_masked_softmax {
-
 torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

 torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
@@ -40,15 +36,9 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

-}  // end namespace scaled_upper_triang_masked_softmax
-}  // end namespace fused_softmax
-}  // end namespace multihead_attn
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",
-        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
+  m.def("forward", &fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
-  m.def("backward",
-        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
+  m.def("backward", &bwd,
         "Self Multihead Attention scaled, time masked softmax -- Backward.");
 }
