#include <cuda_runtime.h>
#include <cuda_fp16.h>

+// #include "cuda_buffers.cuh"
+
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

-    __device__ inline MatrixView_half(const half* data, const int height, const int width)
+    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

-    __device__ inline half item(int row, int column) const { return data[row * width + column]; }
-    __device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
-    __device__ inline half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
-    __device__ inline const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
+    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
+    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
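
// Illustrative sketch (not part of this file): the view classes are non-owning wrappers
// around raw device pointers, so they can be constructed freely inside kernels. Note that
// item_half2 reinterprets the half array as half2 pairs, which assumes the element index
// row * width + column is even and the buffer is 4-byte aligned. The kernel name and
// shape below are made up for illustration.
__global__ void view_demo(const half* buf)
{
    MatrixView_half view(buf, 4, 8);     // 4 x 8 row-major matrix
    half  a = view.item(1, 3);           // single element
    half2 b = view.item_half2(2, 4);     // elements (2,4) and (2,5) as one half2
    (void)a; (void)b;
}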

class MatrixView_half_rw
@@ -28,15 +30,15 @@ public:
    const int height;
    const int width;

-    __device__ inline MatrixView_half_rw(half* data, const int height, const int width)
+    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

-    __device__ inline half item(int row, int column) const { return data[row * width + column]; }
-    __device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
-    __device__ inline half* item_ptr(int row, int column) { return &data[row * width + column]; }
-    __device__ inline void set(int row, int column, half value) { data[row * width + column] = value; }
-    __device__ inline void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
+    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
+    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
+    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};

class MatrixView_q4_row
@@ -46,11 +48,11 @@ public:
    const int height;
    const int width;

-    __device__ inline MatrixView_q4_row(const uint32_t* data, const int height, const int width)
+    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

-    __device__ inline int item(int row, int column) const
+    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
@@ -64,25 +66,25 @@ public:
    const int height;
    const int width;

-    __device__ inline MatrixView_q4_column(const uint32_t* data, const int height, const int width)
+    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

-    __device__ inline int item(int row, int column) const
+    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

-    __device__ inline uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
-    __device__ inline const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
+    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
+    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
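
// Illustrative sketch (not part of this file): each uint32_t in these q4 views packs eight
// 4-bit values, least-significant nibble first -- along a row in MatrixView_q4_row, along a
// column in MatrixView_q4_column. Host-side unpacking of one made-up word:
#include <cstdint>
#include <cstdio>
inline void unpack_demo()
{
    uint32_t packed = 0x76543210;                 // nibbles 0..7, lowest nibble = index 0
    for (int i = 0; i < 8; i++)
    {
        int shift = (i & 0x07) * 4;
        printf("%d ", (packed >> shift) & 0x0f);  // prints 0 1 2 3 4 5 6 7
    }
}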

// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale

-__device__ inline half2 dot_product_8
+__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
@@ -118,21 +120,22 @@ __device__ inline half2 dot_product_8
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

-        v_01 = __hmul2(v_01, v_scale_2);
-        v_23 = __hmul2(v_23, v_scale_2);
-        v_45 = __hmul2(v_45, v_scale_2);
-        v_67 = __hmul2(v_67, v_scale_2);
+        // half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff]; // (constant memory is too slow apparently)
+        // half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
+        // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
+        // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

-        result = __hfma2(*h_ptr++, v_01, result);
-        result = __hfma2(*h_ptr++, v_23, result);
-        result = __hfma2(*h_ptr++, v_45, result);
-        result = __hfma2(*h_ptr++, v_67, result);
+        half2 tmp = __hmul2(*h_ptr++, v_01);
+        tmp = __hfma2(*h_ptr++, v_23, tmp);
+        tmp = __hfma2(*h_ptr++, v_45, tmp);
+        tmp = __hfma2(*h_ptr++, v_67, tmp);
+        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
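
// Illustrative sketch (not part of this file): dot_product_8 accumulates the eight products
// unscaled into tmp and applies the group scale once per 8-element block with a single fused
// multiply-add, relying on sum_i h_i * (s * v_i) == s * sum_i h_i * v_i. In plain C++, with
// made-up values:
inline float dot8_factored_demo()
{
    float h[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float v[8] = {8, 7, 6, 5, 4, 3, 2, 1};
    float s = 0.125f;
    float tmp = 0.0f;
    for (int i = 0; i < 8; i++) tmp += h[i] * v[i];  // accumulate unscaled products
    return s * tmp;                                  // scale applied once per group
}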

-__device__ inline half dot_product_8_h
+__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
@@ -163,31 +166,23 @@ __device__ inline half dot_product_8_h
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

-        v_0 = __hmul(v_0, v_scale);
-        v_1 = __hmul(v_1, v_scale);
-        v_2 = __hmul(v_2, v_scale);
-        v_3 = __hmul(v_3, v_scale);
-        v_4 = __hmul(v_4, v_scale);
-        v_5 = __hmul(v_5, v_scale);
-        v_6 = __hmul(v_6, v_scale);
-        v_7 = __hmul(v_7, v_scale);
-
-        result = __hfma(*h_ptr++, v_0, result);
-        result = __hfma(*h_ptr++, v_1, result);
-        result = __hfma(*h_ptr++, v_2, result);
-        result = __hfma(*h_ptr++, v_3, result);
-        result = __hfma(*h_ptr++, v_4, result);
-        result = __hfma(*h_ptr++, v_5, result);
-        result = __hfma(*h_ptr++, v_6, result);
-        result = __hfma(*h_ptr++, v_7, result);
+        half tmp = __hmul(*h_ptr++, v_0);
+        tmp = __hfma(*h_ptr++, v_1, tmp);
+        tmp = __hfma(*h_ptr++, v_2, tmp);
+        tmp = __hfma(*h_ptr++, v_3, tmp);
+        tmp = __hfma(*h_ptr++, v_4, tmp);
+        tmp = __hfma(*h_ptr++, v_5, tmp);
+        tmp = __hfma(*h_ptr++, v_6, tmp);
+        tmp = __hfma(*h_ptr++, v_7, tmp);
+        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

-__device__ inline half2 dot_product_8_x_map
+__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
@@ -225,11 +220,6 @@ __device__ inline half2 dot_product_8_x_map
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

-        v_01 = __hmul2(v_01, v_scale_2);
-        v_23 = __hmul2(v_23, v_scale_2);
-        v_45 = __hmul2(v_45, v_scale_2);
-        v_67 = __hmul2(v_67, v_scale_2);
-
        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
@@ -244,16 +234,17 @@ __device__ inline half2 dot_product_8_x_map
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

-        result = __hfma2(h_01, v_01, result);
-        result = __hfma2(h_23, v_23, result);
-        result = __hfma2(h_45, v_45, result);
-        result = __hfma2(h_67, v_67, result);
+        half2 tmp = __hmul2(h_01, v_01);
+        tmp = __hfma2(h_23, v_23, tmp);
+        tmp = __hfma2(h_45, v_45, tmp);
+        tmp = __hfma2(h_67, v_67, tmp);
+        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
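
// Illustrative sketch (not part of this file): the x_map variants read h through an index
// map instead of sequentially, so a permuted (e.g. activation-order) layout of the weight
// columns can be undone on the fly while v is still consumed linearly. With made-up values:
inline float gather_dot_demo()
{
    float h[4]     = {10.0f, 20.0f, 30.0f, 40.0f};
    float v[4]     = {1.0f, 2.0f, 3.0f, 4.0f};
    int   x_map[4] = {2, 0, 3, 1};                           // permuted read order for h
    float dot = 0.0f;
    for (int i = 0; i < 4; i++) dot += h[x_map[i]] * v[i];   // 30 + 20 + 120 + 80 = 250
    return dot;
}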

-__device__ inline half dot_product_8_x_map_h
+__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
@@ -286,23 +277,15 @@ __device__ inline half dot_product_8_x_map_h
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

-        v_0 = __hmul(v_0, v_scale);
-        v_1 = __hmul(v_1, v_scale);
-        v_2 = __hmul(v_2, v_scale);
-        v_3 = __hmul(v_3, v_scale);
-        v_4 = __hmul(v_4, v_scale);
-        v_5 = __hmul(v_5, v_scale);
-        v_6 = __hmul(v_6, v_scale);
-        v_7 = __hmul(v_7, v_scale);
-
-        result = __hfma(h_ptr[*x_map_ptr++], v_0, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_1, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_2, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_3, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_4, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_5, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_6, result);
-        result = __hfma(h_ptr[*x_map_ptr++], v_7, result);
+        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
+        result = __hfma(v_scale, tmp, result);
    }

    return result;