Commit 9df8ff4

optimize rmsnorm: add vectorized elementwise ops and loop unrolling
1 parent 368a2aa commit 9df8ff4
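
The title refers to two standard CUDA optimizations: reading the hidden dimension as packed half2 pairs (a vectorized elementwise op, so one 32-bit load covers two values) and issuing several iterations per loop trip (loop unrolling). The sketch below shows that pattern for the RMSNorm sum-of-squares pass; it is illustrative only, and the kernel name, the UNROLL parameter, and the launch shape are assumptions rather than code from this commit.

#include <cuda_fp16.h>

// Illustrative only: per-row sum of squares for RMSNorm. Each thread reads
// packed half2 pairs and the inner loop is unrolled so UNROLL loads are
// issued back to back. Assumes sumsq is zero-initialized before the launch.
template <int UNROLL>
__global__ void rmsnorm_sumsq_sketch(const half2* __restrict__ x,  // [rows, n2] packed input
                                     float* __restrict__ sumsq,    // [rows] per-row sum of x^2
                                     int n2) {                     // hidden_size / 2
  const half2* row = x + (size_t)blockIdx.x * n2;
  float acc = 0.f;
  for (int i = threadIdx.x; i < n2; i += blockDim.x * UNROLL) {
#pragma unroll
    for (int j = 0; j < UNROLL; ++j) {
      int idx = i + j * blockDim.x;
      if (idx < n2) {
        float2 v = __half22float2(row[idx]);  // unpack two halves to float
        acc += v.x * v.x + v.y * v.y;
      }
    }
  }
  // Crude reduction for the sketch; a real kernel would reduce with warp
  // shuffles or shared memory before a single atomic per block.
  atomicAdd(&sumsq[blockIdx.x], acc);
}

// Example launch (hypothetical): rmsnorm_sumsq_sketch<4><<<rows, 256>>>(x, sumsq, hidden_size / 2);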

File tree

2 files changed: +398 −30 lines
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
/*
 * This code is from NVIDIA FasterTransformer:
 * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
 */

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>  // __nv_bfloat16 / __nv_bfloat162
#endif
#include <ATen/ATen.h>  // at::Half / at::BFloat16, used by TypeConverter below

// Elementwise add; the packed specializations process two lanes per intrinsic.
template <typename T>
inline __device__ T add(T a, T b) {
  return a + b;
}

template <>
inline __device__ half2 add(half2 a, half2 b) {
  return __hadd2(a, b);
}

template <>
inline __device__ half add(half a, half b) {
  return __hadd(a, b);
}

#if ENABLE_BF16
// bf16hadd2 / bf16hadd / bf16hmul / bf16hmul2 are FasterTransformer's bfloat16
// helpers (defined in its bf16 fallback header, not in this file); they fall
// back to float math on devices without native bf16 arithmetic.
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return bf16hadd2(a, b);
}

template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
  return bf16hadd(a, b);
}
#endif  // ENABLE_BF16

// Three-operand multiply a * b * c, again with packed specializations.
template <typename T>
inline __device__ T mul(T a, T b, T c) {
  return a * b * c;
}

template <>
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
  return __hmul2(__hmul2(a, b), c);
}

#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
                                    __nv_bfloat16 c) {
  return bf16hmul(a, b, c);
}

template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
  return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16

// Converting casts between scalar and packed CUDA types.
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  return val;
}

template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
  return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val) {
  return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
  return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
  return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val) {
  return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val) {
  return __half2half2(val);
}
template <>
__device__ inline float cuda_cast<float, half>(half val) {
  return __half2float(val);
}

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = at::Half;
};

template <>
struct TypeConverter<at::Half> {
  using Type = half2;
};

#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = at::BFloat16;
};

template <>
struct TypeConverter<at::BFloat16> {
  using Type = __nv_bfloat162;
};
#endif  // ENABLE_BF16
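
For reference, a minimal sketch of how the utilities above compose in the normalization step of an RMSNorm kernel. The helper name and the assumption that this header is available as "cuda_type_utils.cuh" are illustrative, not taken from the commit.

#include <cuda_fp16.h>
#include "cuda_type_utils.cuh"  // the header added by this commit (filename assumed)

// Illustrative only: one packed RMSNorm output, out = x * rstd * weight,
// computed two lanes at a time. TypeConverter maps the scalar type (at::Half)
// to its packed counterpart (half2), cuda_cast broadcasts the float scale,
// and mul performs a * b * c on both lanes via __hmul2.
using T2 = TypeConverter<at::Half>::Type;  // resolves to half2

__device__ T2 rmsnorm_scale_sketch(T2 x2,         // two input elements
                                   T2 weight2,    // two gain values
                                   float rstd) {  // 1 / sqrt(mean(x^2) + eps)
  T2 rstd2 = cuda_cast<T2>(rstd);      // half2{rstd, rstd}
  return mul<T2>(x2, rstd2, weight2);  // packed multiply of all three operands
}

The float2 specializations of cuda_cast fit the same picture: they let the accumulation side of such a kernel be carried in float while inputs and outputs stay in half, which is the usual mixed-precision arrangement for norm layers.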
