// Inspired by NVIDIA's FasterTransformer
/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/
#include <torch/extension.h>
#include "gemm_cuda.h"
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
// Pack two half values into a single 32-bit word (x in the low 16 bits, y in the high 16 bits).
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Ceiling division, e.g. make_divisible(10, 8) == 2 and make_divisible(32, 8) == 4.
__device__ __forceinline__ int make_divisible(int c, int divisor){
  return (c + divisor - 1) / divisor;
}
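
// Kernel overview (derived from the code below): each thread block of 32 x 4 threads
// computes a 128 (M) x 64 (OC) output tile, walking the IC dimension in chunks of 32.
// A is staged in shared memory as fp16; B is stored as 4-bit weights packed 8-per-int32
// and is dequantized into shared memory using per-group scales and zero points
// (group size G). The inner product uses m16n8k16 tensor-core MMAs with fp32
// accumulation. The grid is additionally split split_k_iters ways along IC
// (blockIdx_z); the host sums the partial results afterwards.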
template <int G>
__global__ void __launch_bounds__(128) gemm_forward_4bit_cuda_m128n64k32(
    int split_k_iters, half* __restrict__ A, int* __restrict__ B,
    half* __restrict__ scaling_factors, int* zeros,
    int M, int IC, int OC, half* __restrict__ C)
{
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[64];
  __shared__ half A_shared[128 * (32 + 8)];
  __shared__ half B_shared[64 * (32 + 8)];
  // __shared__ half scaling_factors_shared[64];
  // __shared__ half zeros_shared[64];
  int j_factors1 = ((OC + 64 - 1) / 64);
  int blockIdx_x = 0;
  // Logical block coordinates: blockIdx_y indexes the (M, OC) tile, blockIdx_z the split-K slice.
  int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1);
  int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1);
  half A_shared_warp[32];
  half B_shared_warp[16];
  // Zero-initialize the per-thread accumulator fragments.
  for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) {
    for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
      for (int i = 0; i < 8; ++i) {
        C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0;
      }
    }
  }
  static constexpr int row_stride_warp = 32 * 8 / 32;
  static constexpr int row_stride_A = 4 * 32 * 8 / 32;
  static constexpr int row_stride = 4 * 32 * 8 / 32;
  const int make_divisible_multiplier = 128 / G;
  // Row widths of the packed zero-point and scale tensors (per output channel),
  // e.g. for G = 128 and IC = 4096: zeros_w = 4 int32 entries, sf_w = 32 half entries.
  const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multiplier) * make_divisible_multiplier;
  const int sf_w = zeros_w * 8;
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
  int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;
  half* A_ptr = A
      + (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
      + (((int)threadIdx.x) % (32 / 8)) * 8;
  int* B_ptr = B
      + ((int)threadIdx.y) * (IC / 8) * 8
      + (((int)threadIdx.x) / (32 / 8)) * (IC / 8)
      + (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8)
      + (((int)threadIdx.x) % (32 / 8)) * 1;
  // Why * 1 in the above line?
  half* A_shared_ptr = A_shared
      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
      + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
      + (((int)threadIdx.x) % (32 / 8)) * 8;
  half* B_shared_ptr = B_shared
      + ((int)threadIdx.y) * (row_stride / 4) * (32 + 8)
      + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
      + (((int)threadIdx.x) % (32 / 8)) * 8;
  int* zeros_ptr = zeros
      + ((int)threadIdx.y) * zeros_w * 8
      + (((int)threadIdx.x) / (32 / 8)) * zeros_w
      + (((int)blockIdx_y) % j_factors1) * 64 * zeros_w
      // this term is zero
      + (((int)threadIdx.x) % (32 / 8)) / G;
  half* scaling_factors_ptr = scaling_factors
      + ((int)threadIdx.y) * sf_w * 8
      + (((int)threadIdx.x) / (32 / 8)) * sf_w
      + (((int)blockIdx_y) % j_factors1) * (64) * sf_w
      // this term is zero
      + (((int)threadIdx.x) % (32 / 8)) * 8 / G;
  // Haotian: TBD, check, May 29 11:46 AM PST
  half* C_ptr = C
      + blockIdx_z * M * OC // blockIdx_z -> split_k dim
      + (((int)blockIdx_y) % j_factors1) * 64
      + (((int)threadIdx.y) / 2) * 32
      + (((int)threadIdx.x) % 4) * 2;
  // preload s.f. and zeros
  int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters;
  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
  // TODO (Haotian): load scales and zero points to smem
  // Main loop over the IC dimension (this block handles every split_k_iters-th chunk of 32).
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: Here we assume M % cta_M == 0.
    // Stage the A tile (128 x 32 fp16) in shared memory, zero-padding rows past M.
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0)
    {
      if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M)
      {
        *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32));
      }
      else
      {
        *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0);
      }
    }
    int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8;
    half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G;
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * (32 / 8);
    // Dequantize the B tile into shared memory: 64 x 40 (32 + 8 padding) float16.
    // Each thread reads one 32-bit word (8 packed int4 weights), applies its scale
    // and zero point, and writes back 8 fp16 values as one uint4.
    // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
      int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8));
      int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w);
      zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4;
      float current_zeros = (float)(zeros_loaded & 0xF);
      half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w);
      half B_loaded_fp16[8];
      #pragma unroll
      for (int ic_1 = 0; ic_1 < 8; ic_1++){
        float current_single_weight_fp = (float)(B_loaded_current & 0xF);
        half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros));
        B_loaded_current = B_loaded_current >> 4;
        B_loaded_fp16[ic_1] = dequantized_weight;
      }
      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast<uint4*>(B_loaded_fp16);
    }
    __syncthreads();
    // Two 16-wide K slices per 32-wide chunk.
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
      // Load four 16x16 A fragments from shared memory with ldmatrix.
      for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]),
              "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]),
              "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]),
              "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3])
            : "r"(addr)
          );
        }
      }
      // Load two 16x16 B fragments from shared memory with ldmatrix.
      for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) {
        {
          unsigned int addr;
          __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
            : "=r"(addr)
            : "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))
          );
          __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]),
              "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]),
              "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]),
              "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3])
            : "r"(addr)
          );
        }
      }
      // Accumulate with m16n8k16 tensor-core MMAs (two per 16-wide N fragment).
      for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) {
        for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
          {
            __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
              : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]),
                "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]),
                "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]),
                "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
                "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]),
                "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]),
                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]),
                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]),
                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]),
                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
          }
          {
            __asm__ __volatile__(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
              : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]),
                "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]),
                "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]),
                "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]),
                "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
                "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
                "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]),
                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]),
                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]),
                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]));
          }
        }
      }
    }
  }
  // Haotian: Here (May 29 11:46AM PST)
  // TODO: Shang: Hoist loop invariants.
  // Write the fp32 accumulators back to global memory as fp16, guarding rows past M.
  for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) {
    for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
      for (int local_id = 0; local_id < 8; ++local_id) {
        int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4;
        if (row_offset < M)
        {
          *(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]);
        }
      }
    }
  }
}
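
// Launch geometry used by the host wrapper below: ceil(M / 128) * (OC / 64) * split_k_iters
// blocks of 32 x 4 threads; e.g. M = 256, OC = 4096, split_k_iters = 8 gives 2 * 64 * 8 = 1024 blocks.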
// in_feats: M, IC [float16]
// kernel: OC, IC // 8 [int32] -> cast to OC, IC [uint4b]
// scaling_factors: OC, IC // G [float16]
// zeros: OC, IC // G // 8 [int32] -> cast to OC, IC // G [uint4b]
// assume that batch_size < 16 for now
torch::Tensor gemm_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size,
    int split_k_iters)
{
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
  auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
  // _kernel is packed 8 x int4 per int32 along IC, so _kernel.size(0) already equals OC.
  at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(0)}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);
  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  // blockIdx_x: i_factors[0] * j_factors[0]
  // blockIdx_y: i_factors[1] * j_factors[1]
  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  int j_factors1 = num_out_channels / 64 / 1;
  dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters);
  // threadIdx.x: 32
  // threadIdx.y: i_factors[2] * j_factors[2]
  dim3 threads_per_block(32, 4);
  if (group_size == 128)
  {
    gemm_forward_4bit_cuda_m128n64k32<128><<<num_blocks, threads_per_block>>>(
        split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  else if (group_size == 64)
  {
    gemm_forward_4bit_cuda_m128n64k32<64><<<num_blocks, threads_per_block>>>(
        split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  else
  {
    throw std::invalid_argument("Group size temporarily not supported.");
  }
  // Reduce the split-K partial results to the final [M, OC] output.
  return _out_feats.sum(0);
}
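
// Usage sketch (illustrative; exact shapes depend on how the weights were packed upstream).
// For IC = 4096, OC = 4096, group_size = 128, M = 256, split_k_iters = 8:
//   _in_feats        : [256, 4096]  float16 (CUDA)
//   _kernel          : [4096, 512]  int32   (8 x int4 per int32 along IC)
//   _scaling_factors : [4096, 32]   float16
//   _zeros           : [4096, 4]    int32
//   auto out = gemm_forward_cuda(_in_feats, _kernel, _scaling_factors, _zeros, 128, 8);
//   // out: [256, 4096] float16
// The Python-visible binding is registered elsewhere in the extension; this file only
// provides the kernel and the C++ entry point (see the included gemm_cuda.h).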