 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 
-#include "../funcs/op_functor.h"
+#include "../funcs/binary_functor.h"
 
 namespace colossalAI {
 namespace cuda {
 namespace utils {
 
 const float kReduceFloatInfNeg = -100000000.f;
 const float kReduceFloatInfPos = 100000000.f;
-const int kWarpSize = 32;
 const unsigned int kWarpReduceMask = 0xffffffff;
 
 enum class ReduceType { kMax = 0, kSum };
@@ -31,44 +30,42 @@ struct GetOpForReduceType<T, ReduceType::kSum> {
 };
 
 #define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
-  for (int offset = 0; offset < LANES; ++offset) {                     \
+  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {   \
     *(VAL_PTR + offset) =                                              \
         OP(*(VAL_PTR + offset),                                        \
            __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH));  \
   }
 
-#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES)  \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES)   \
-  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)
-
-#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
-                                   DEFAULT_VALUE, REDUCE_TYPE)      \
-  __shared__ T shm[LANES][32];                                      \
-  int lane_id = threadIdx.x & 0x1f;                                 \
-  int warp_id = threadIdx.x >> 5;                                   \
-                                                                    \
-  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);                  \
-  if (lane_id == 0) {                                               \
-    for (int offset = 0; offset < LANES; ++offset) {                \
-      shm[offset][warp_id] = *(VAL_PTR + offset);                   \
-    }                                                               \
-  }                                                                 \
-  __syncthreads();                                                  \
-                                                                    \
-  for (int offset = 0; offset < LANES; ++offset) {                  \
-    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5))         \
-                              ? shm[offset][lane_id]                \
-                              : static_cast<T>(DEFAULT_VALUE);      \
-  }                                                                 \
+#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES)           \
+  _Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \
+    COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES)           \
+  }
+
+#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \
+                                   REDUCE_TYPE)                              \
+  __shared__ T shm[LANES][32];                                               \
+  int lane_id = threadIdx.x & 0x1f;                                          \
+  int warp_id = threadIdx.x >> 5;                                            \
+                                                                             \
+  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);                           \
+  if (lane_id == 0) {                                                        \
+    for (int offset = 0; offset < LANES; ++offset) {                         \
+      shm[offset][warp_id] = *(VAL_PTR + offset);                            \
+    }                                                                        \
+  }                                                                          \
+  __syncthreads();                                                           \
+                                                                             \
+  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {         \
+    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5))                  \
+                              ? shm[offset][lane_id]                         \
+                              : static_cast<T>(DEFAULT_VALUE);               \
+  }                                                                          \
   warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);
 
-template <typename T, ReduceType rtype, int lanes>
+template <typename T, ReduceType rtype, int lanes, int width = 32>
 __forceinline__ __device__ void warp_reduce(T* pval) {
   typename GetOpForReduceType<T, rtype>::Op op;
-  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
+  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes);
 }
 
 template <typename T, ReduceType rtype>
@@ -84,8 +81,7 @@ template <typename T, ReduceType rtype, int lanes>
 __forceinline__ __device__ void block_reduce(T* pval) {
   constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
   typename GetOpForReduceType<T, rtype>::Op op;
-  COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
-                             rtype);
+  COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype);
 }
 
 #undef COLOSSAL_SHFL_FUNCTION
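
A minimal usage sketch of the refactored helpers (reviewer note, not part of the diff; the kernel names and the include path are hypothetical). It assumes blockDim.x is a multiple of 32 and at most 1024, since the block reduction stages partials through a 32-slot shared-memory row per lane:

// Hypothetical usage sketch; "block_reduce.h" stands in for this header's
// actual path, and `row_max` / `half_warp_sum` are illustrative names.
#include <cuda_runtime.h>

#include "block_reduce.h"

using namespace colossalAI::cuda::utils;

// One block per row: reduce `cols` floats to the row maximum.
__global__ void row_max(const float* in, float* out, int cols) {
  float val = kReduceFloatInfNeg;  // identity element for a max reduction
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    val = fmaxf(val, in[blockIdx.x * cols + i]);
  }
  // lanes = 1: each thread contributes a single partial value.
  block_reduce<float, ReduceType::kMax, 1>(&val);
  if (threadIdx.x == 0) out[blockIdx.x] = val;  // thread 0 holds the result
}

// The new `width` template parameter allows sub-warp reductions: every group
// of 16 consecutive lanes ends up holding the sum over its own group.
__device__ float half_warp_sum(float v) {
  warp_reduce<float, ReduceType::kSum, 1, 16>(&v);
  return v;
}

The `width = 32` default keeps existing call sites unchanged: the loop-based COLOSSAL_WARP_REDUCE_IMPL emits the same __shfl_xor_sync butterfly (offsets 16, 8, 4, 2, 1) that the five hand-unrolled invocations did, so full-warp behavior is preserved.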