Skip to content

Commit c73d921

Browse files
committed
Small matmul optimization
1 parent 94a2908 commit c73d921

File tree

3 files changed

+82
-72
lines changed

3 files changed

+82
-72
lines changed

exllama_ext/cuda_buffers.cu

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
#define _cuda_buffers_cu
12
#include "cuda_buffers.cuh"
23

34
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
5+
// __constant__ half2 q4_table[16][256];
6+
// half2 q4_table_host[16][256];
7+
// bool q4_table_init = false;
48

59
CudaBuffers::CudaBuffers
610
(
@@ -64,4 +68,23 @@ void prepare_buffers_cuda
6468
);
6569

6670
g_buffers[_device] = buffers;
71+
72+
// if (!q4_table_init)
73+
// {
74+
// for (uint v_zero = 0; v_zero < 16; v_zero++)
75+
// {
76+
// for (uint v_read = 0; v_read < 256; v_read++)
77+
// {
78+
// half v_0 = __float2half((float)((int)((v_read ) & 0x0f) - v_zero - 1));
79+
// half v_1 = __float2half((float)((int)((v_read >> 4) & 0x0f) - v_zero - 1));
80+
// half2 v_01 = {v_0, v_1};
81+
// q4_table_host[v_zero][v_read] = v_01;
82+
// }
83+
// }
84+
// q4_table_init = true;
85+
// }
86+
//
87+
// cudaSetDevice(_device);
88+
// cudaMemcpyToSymbol(q4_table, q4_table_host, 16 * 256 * sizeof(half2));
89+
// cudaDeviceSynchronize();
6790
}

exllama_ext/cuda_buffers.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
const int CUDA_MAX_DEVICES = 16;
1010

11+
// #ifndef _cuda_buffers_cu
12+
// extern __constant__ half2 q4_table[16][256];
13+
// #endif
14+
1115
class CudaBuffers
1216
{
1317
public:

exllama_ext/matrix.cuh

Lines changed: 55 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@
44
#include <cuda_runtime.h>
55
#include <cuda_fp16.h>
66

7+
//#include "cuda_buffers.cuh"
8+
79
class MatrixView_half
810
{
911
public:
1012
const half* data;
1113
const int height;
1214
const int width;
1315

14-
__device__ inline MatrixView_half(const half* data, const int height, const int width)
16+
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
1517
: data(data), height(height), width(width)
1618
{ }
1719

18-
__device__ inline half item(int row, int column) const { return data[row * width + column]; }
19-
__device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
20-
__device__ inline half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
21-
__device__ inline const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
20+
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
21+
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
22+
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
23+
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
2224
};
2325

2426
class MatrixView_half_rw
@@ -28,15 +30,15 @@ public:
2830
const int height;
2931
const int width;
3032

31-
__device__ inline MatrixView_half_rw(half* data, const int height, const int width)
33+
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
3234
: data(data), height(height), width(width)
3335
{ }
3436

35-
__device__ inline half item(int row, int column) const { return data[row * width + column]; }
36-
__device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
37-
__device__ inline half* item_ptr(int row, int column) { return &data[row * width + column]; }
38-
__device__ inline void set(int row, int column, half value) { data[row * width + column] = value; }
39-
__device__ inline void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
37+
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
38+
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
39+
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
40+
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
41+
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
4042
};
4143

4244
class MatrixView_q4_row
@@ -46,11 +48,11 @@ public:
4648
const int height;
4749
const int width;
4850

49-
__device__ inline MatrixView_q4_row(const uint32_t* data, const int height, const int width)
51+
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
5052
: data(data), height(height), width(width)
5153
{ }
5254

53-
__device__ inline int item(int row, int column) const
55+
__device__ __forceinline__ int item(int row, int column) const
5456
{
5557
int shift = (column & 0x07) * 4;
5658
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
@@ -64,25 +66,25 @@ public:
6466
const int height;
6567
const int width;
6668

67-
__device__ inline MatrixView_q4_column(const uint32_t* data, const int height, const int width)
69+
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
6870
: data(data), height(height), width(width)
6971
{ }
7072

71-
__device__ inline int item(int row, int column) const
73+
__device__ __forceinline__ int item(int row, int column) const
7274
{
7375
int shift = (row & 0x07) * 4;
7476
return (data[row / 8 * width + column] >> shift) & 0x0f;
7577
}
7678

77-
__device__ inline uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
78-
__device__ inline const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
79+
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
80+
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
7981
};
8082

8183
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
8284

8385
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
8486

85-
__device__ inline half2 dot_product_8
87+
__device__ __forceinline__ half2 dot_product_8
8688
(
8789
const half2 acc,
8890
MatrixView_half& h_,
@@ -118,21 +120,22 @@ __device__ inline half2 dot_product_8
118120
half2 v_45 = __halves2half2(v_4, v_5);
119121
half2 v_67 = __halves2half2(v_6, v_7);
120122

121-
v_01 = __hmul2(v_01, v_scale_2);
122-
v_23 = __hmul2(v_23, v_scale_2);
123-
v_45 = __hmul2(v_45, v_scale_2);
124-
v_67 = __hmul2(v_67, v_scale_2);
123+
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
124+
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
125+
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
126+
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
125127

126-
result = __hfma2(*h_ptr++, v_01, result);
127-
result = __hfma2(*h_ptr++, v_23, result);
128-
result = __hfma2(*h_ptr++, v_45, result);
129-
result = __hfma2(*h_ptr++, v_67, result);
128+
half2 tmp = __hmul2(*h_ptr++, v_01);
129+
tmp = __hfma2(*h_ptr++, v_23, tmp);
130+
tmp = __hfma2(*h_ptr++, v_45, tmp);
131+
tmp = __hfma2(*h_ptr++, v_67, tmp);
132+
result = __hfma2(v_scale_2, tmp, result);
130133
}
131134

132135
return result;
133136
}
134137

135-
__device__ inline half dot_product_8_h
138+
__device__ __forceinline__ half dot_product_8_h
136139
(
137140
const half acc,
138141
MatrixView_half& h_,
@@ -163,31 +166,23 @@ __device__ inline half dot_product_8_h
163166
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
164167
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
165168

166-
v_0 = __hmul(v_0, v_scale);
167-
v_1 = __hmul(v_1, v_scale);
168-
v_2 = __hmul(v_2, v_scale);
169-
v_3 = __hmul(v_3, v_scale);
170-
v_4 = __hmul(v_4, v_scale);
171-
v_5 = __hmul(v_5, v_scale);
172-
v_6 = __hmul(v_6, v_scale);
173-
v_7 = __hmul(v_7, v_scale);
174-
175-
result = __hfma(*h_ptr++, v_0, result);
176-
result = __hfma(*h_ptr++, v_1, result);
177-
result = __hfma(*h_ptr++, v_2, result);
178-
result = __hfma(*h_ptr++, v_3, result);
179-
result = __hfma(*h_ptr++, v_4, result);
180-
result = __hfma(*h_ptr++, v_5, result);
181-
result = __hfma(*h_ptr++, v_6, result);
182-
result = __hfma(*h_ptr++, v_7, result);
169+
half tmp = __hmul(*h_ptr++, v_0);
170+
tmp = __hfma(*h_ptr++, v_1, tmp);
171+
tmp = __hfma(*h_ptr++, v_2, tmp);
172+
tmp = __hfma(*h_ptr++, v_3, tmp);
173+
tmp = __hfma(*h_ptr++, v_4, tmp);
174+
tmp = __hfma(*h_ptr++, v_5, tmp);
175+
tmp = __hfma(*h_ptr++, v_6, tmp);
176+
tmp = __hfma(*h_ptr++, v_7, tmp);
177+
result = __hfma(v_scale, tmp, result);
183178
}
184179

185180
return result;
186181
}
187182

188183
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
189184

190-
__device__ inline half2 dot_product_8_x_map
185+
__device__ __forceinline__ half2 dot_product_8_x_map
191186
(
192187
const half2 acc,
193188
MatrixView_half& h_,
@@ -225,11 +220,6 @@ __device__ inline half2 dot_product_8_x_map
225220
half2 v_45 = __halves2half2(v_4, v_5);
226221
half2 v_67 = __halves2half2(v_6, v_7);
227222

228-
v_01 = __hmul2(v_01, v_scale_2);
229-
v_23 = __hmul2(v_23, v_scale_2);
230-
v_45 = __hmul2(v_45, v_scale_2);
231-
v_67 = __hmul2(v_67, v_scale_2);
232-
233223
half h_0 = h_ptr[*x_map_ptr++];
234224
half h_1 = h_ptr[*x_map_ptr++];
235225
half h_2 = h_ptr[*x_map_ptr++];
@@ -244,16 +234,17 @@ __device__ inline half2 dot_product_8_x_map
244234
half2 h_45 = __halves2half2(h_4, h_5);
245235
half2 h_67 = __halves2half2(h_6, h_7);
246236

247-
result = __hfma2(h_01, v_01, result);
248-
result = __hfma2(h_23, v_23, result);
249-
result = __hfma2(h_45, v_45, result);
250-
result = __hfma2(h_67, v_67, result);
237+
half2 tmp = __hmul2(h_01, v_01);
238+
tmp = __hfma2(h_23, v_23, tmp);
239+
tmp = __hfma2(h_45, v_45, tmp);
240+
tmp = __hfma2(h_67, v_67, tmp);
241+
result = __hfma2(v_scale_2, tmp, result);
251242
}
252243

253244
return result;
254245
}
255246

256-
__device__ inline half dot_product_8_x_map_h
247+
__device__ __forceinline__ half dot_product_8_x_map_h
257248
(
258249
const half acc,
259250
MatrixView_half& h_,
@@ -286,23 +277,15 @@ __device__ inline half dot_product_8_x_map_h
286277
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
287278
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
288279

289-
v_0 = __hmul(v_0, v_scale);
290-
v_1 = __hmul(v_1, v_scale);
291-
v_2 = __hmul(v_2, v_scale);
292-
v_3 = __hmul(v_3, v_scale);
293-
v_4 = __hmul(v_4, v_scale);
294-
v_5 = __hmul(v_5, v_scale);
295-
v_6 = __hmul(v_6, v_scale);
296-
v_7 = __hmul(v_7, v_scale);
297-
298-
result = __hfma(h_ptr[*x_map_ptr++], v_0, result);
299-
result = __hfma(h_ptr[*x_map_ptr++], v_1, result);
300-
result = __hfma(h_ptr[*x_map_ptr++], v_2, result);
301-
result = __hfma(h_ptr[*x_map_ptr++], v_3, result);
302-
result = __hfma(h_ptr[*x_map_ptr++], v_4, result);
303-
result = __hfma(h_ptr[*x_map_ptr++], v_5, result);
304-
result = __hfma(h_ptr[*x_map_ptr++], v_6, result);
305-
result = __hfma(h_ptr[*x_map_ptr++], v_7, result);
280+
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
281+
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
282+
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
283+
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
284+
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
285+
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
286+
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
287+
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
288+
result = __hfma(v_scale, tmp, result);
306289
}
307290

308291
return result;

0 commit comments

Comments
 (0)