|
1 | 1 | #pragma once |
2 | 2 |
|
3 | 3 | #include <cuda.h> |
| 4 | +#include <cuda_bf16.h> |
4 | 5 | #include <cuda_fp16.h> |
5 | 6 | #include <cuda_runtime.h> |
6 | 7 |
|
7 | 8 | #include <functional> |
8 | 9 |
|
| 10 | +#include "../utils/micros.h" |
| 11 | + |
9 | 12 | namespace colossalAI { |
10 | 13 | namespace cuda { |
11 | 14 | namespace funcs { |
12 | 15 |
|
13 | | -enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin }; |
| 16 | +enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; |
14 | 17 |
|
15 | | -template <typename T, BinaryOpType Op> |
| 18 | +// Note(LiuYang): This file provides base math operations for data types,
| 19 | +// including POD and CUDA built-in types such as half and __nv_bfloat16.
| 20 | +template <typename LT, typename RT, typename RET, BinaryOpType Op> |
16 | 21 | struct BinaryOpFunctor; |
17 | 22 |
|
18 | | -template <typename T> |
19 | | -struct BinaryOpFunctor<T, BinaryOpType::kAdd> |
20 | | - : public std::binary_function<T, T, T> { |
21 | | - __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; } |
22 | | -}; |
23 | | - |
24 | | -template <typename T> |
25 | | -struct BinaryOpFunctor<T, BinaryOpType::kMax> |
26 | | - : public std::binary_function<T, T, T> { |
27 | | - __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); } |
28 | | -}; |
| 23 | +#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \ |
| 24 | + FUNCTION_MODIFIER, ARGS...) \ |
| 25 | + template <ARGS> \ |
| 26 | + struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \ |
| 27 | + : public std::binary_function<T, T, T> { \ |
| 28 | + FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \ |
| 29 | + }; |
| 30 | + |
| 31 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs, |
| 32 | + HOSTDEVICE, typename T) |
| 33 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs, |
| 34 | + HOSTDEVICE, typename T) |
| 35 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs * rhs,
| 36 | + HOSTDEVICE, typename T) |
| 37 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs, |
| 38 | + HOSTDEVICE, typename T) |
| 39 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs), |
| 40 | + HOSTDEVICE, typename T) |
| 41 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs), |
| 42 | + HOSTDEVICE, typename T) |
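
For readability, here is a sketch of what the first invocation above expands to, assuming HOSTDEVICE is defined as `__host__ __device__` in `../utils/micros.h` (the deleted pre-macro code above used exactly that modifier, so the assumption is consistent with the diff). One caveat: `std::binary_function` was deprecated in C++11 and removed in C++17, so this inheritance only compiles under older language standards.

    template <typename T>
    struct BinaryOpFunctor<T, T, T, BinaryOpType::kAdd>
        : public std::binary_function<T, T, T> {
      // HOSTDEVICE -> __host__ __device__ (assumed from ../utils/micros.h)
      __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
    };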
| 43 | + |
| 44 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd, |
| 45 | + __hadd(lhs, rhs), DEVICE) |
| 46 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd, |
| 47 | + __hadd2(lhs, rhs), DEVICE) |
| 48 | + |
| 49 | +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
| 50 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, |
| 51 | + __hadd(lhs, rhs), DEVICE) |
| 52 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd, |
| 53 | + __hadd2(lhs, rhs), DEVICE) |
| 54 | +#else |
| 55 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, |
| 56 | + __float2bfloat16(__bfloat162float(lhs) + |
| 57 | + __bfloat162float(rhs)), |
| 58 | + DEVICE) |
| 59 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( |
| 60 | + __nv_bfloat162, BinaryOpType::kAdd, |
| 61 | + __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs), |
| 62 | + __high2float(lhs) + __high2float(rhs)), |
| 63 | + DEVICE) |
| 64 | +#endif |
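
An aside on the guard above: the bf16 arithmetic intrinsics (`__hadd`/`__hadd2` on `__nv_bfloat16`/`__nv_bfloat162`) require compute capability 8.0 or newer, which is exactly what the `__CUDA_ARCH__ >= 800` check selects for. On older targets the fallback widens to float, operates in full precision, and rounds back. Unwrapped from the macro, the scalar fallback is equivalent to the following (the function name is illustrative, not part of the diff):

    __device__ __nv_bfloat16 bf16_add_via_float(__nv_bfloat16 lhs,
                                                __nv_bfloat16 rhs) {
      // Widen both operands to float, add, then round back to bf16.
      return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
    }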
| 65 | + |
| 66 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul, |
| 67 | + __hmul(lhs, rhs), DEVICE) |
| 68 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul, |
| 69 | + __hmul2(lhs, rhs), DEVICE) |
| 70 | + |
| 71 | +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
| 72 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, |
| 73 | + __hmul(lhs, rhs), DEVICE) |
| 74 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul, |
| 75 | + __hmul2(lhs, rhs), DEVICE) |
| 76 | +#else |
| 77 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, |
| 78 | + __float2bfloat16(__bfloat162float(lhs) * |
| 79 | + __bfloat162float(rhs)), |
| 80 | + DEVICE) |
| 81 | +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( |
| 82 | + __nv_bfloat162, BinaryOpType::kMul, |
| 83 | + __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs), |
| 84 | + __high2float(lhs) * __high2float(rhs)), |
| 85 | + DEVICE) |
| 86 | +#endif |
| 87 | + |
| 88 | +#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION |
29 | 89 |
|
30 | 90 | } // namespace funcs |
31 | 91 | } // namespace cuda |
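
Finally, a minimal usage sketch (an editorial addition, not part of the diff; the kernel name, header path, and launch configuration are illustrative assumptions). The operation is selected at compile time through the functor's non-type template parameter:

    #include "binary_functor.h"

    using namespace colossalAI::cuda::funcs;

    template <typename T, BinaryOpType Op>
    __global__ void elementwise_op(const T* a, const T* b, T* out, int n) {
      BinaryOpFunctor<T, T, T, Op> op;  // e.g. Op = BinaryOpType::kAdd
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = op(a[i], b[i]);
    }

    // Launch example:
    //   elementwise_op<float, BinaryOpType::kAdd>
    //       <<<(n + 255) / 256, 256>>>(a, b, out, n);
    // The packed specializations (half2, __nv_bfloat162) work the same way
    // but process two values per functor call.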
|