
Commit 19596b1

CUDA: conv2d with tensor cores
1 parent 5d6688d commit 19596b1

2 files changed: +265 −72 lines changed


ggml/src/ggml-cuda/conv2d.cu

Lines changed: 257 additions & 71 deletions
@@ -1,6 +1,9 @@
 #include "conv2d.cuh"
 #include "convert.cuh"
 
+#include <mma.h>
+using namespace nvcuda;
+
 struct conv_params {
     const int64_t IW, IH;
     const int64_t OW, OH;
@@ -11,112 +14,292 @@ struct conv_params {
     const int64_t IC, OC;
     const int64_t B;
     const int64_t TOTAL;
+    // helpers
+    const int64_t IC_KH_KW, N_OH_OW;
 };
 
-struct kernel_bounds {
-    int64_t y_min, y_max;
-    int64_t x_min, x_max;
+auto ceil_div = [](int a, int b) {
+    return (a + b - 1) / b;
 };
 
-__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
-    return (a > b) ? a : b;
-}
-
-__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
-    return (a < b) ? a : b;
-}
-
-__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
-    kernel_bounds bounds;
-    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    return bounds;
-}
-
-__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
-                                                     int64_t kern_coord,
-                                                     int64_t stride,
-                                                     int64_t dilation,
-                                                     int64_t padding) {
+__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord,
+                                                            int64_t kern_coord,
+                                                            int64_t stride,
+                                                            int64_t dilation,
+                                                            int64_t padding) {
     return out_coord * stride + kern_coord * dilation - padding;
 }
 
 struct whcn_layout {
-    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t input_index(int64_t n,
+                                                          int64_t c,
+                                                          int64_t y,
+                                                          int64_t x,
+                                                          const conv_params & P) {
         return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
     }
 
-    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+    __device__ __forceinline__ static int64_t kernel_index(int64_t c_out,
+                                                           int64_t c_in,
+                                                           int64_t ky,
+                                                           int64_t kx,
+                                                           const conv_params & P) {
         return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
     }
 
-    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t output_index(int64_t n,
+                                                           int64_t c,
+                                                           int64_t y,
+                                                           int64_t x,
+                                                           const conv_params & P) {
         return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
     }
 
-    __device__ static void unpack_indices(int64_t global_idx,
-                                          const conv_params & P,
-                                          int64_t & n,
-                                          int64_t & c,
-                                          int64_t & out_y,
-                                          int64_t & out_x) {
-        out_x = global_idx % P.OW;
-        out_y = (global_idx / P.OW) % P.OH;
-        c     = (global_idx / (P.OW * P.OH)) % P.OC;
-        n     = global_idx / (P.OW * P.OH * P.OC);
+    __device__ __forceinline__ static void unpack_ickhkw(int64_t idx,
+                                                         int64_t & ic,
+                                                         int64_t & kh,
+                                                         int64_t & kw,
+                                                         const conv_params & P) {
+        ic        = idx / (P.KW * P.KH);
+        int64_t r = idx - ic * (P.KW * P.KH);
+        kh        = r / P.KW;
+        kw        = r - kh * P.KW;
+    }
+
+    __device__ __forceinline__ static void unpack_nohow(int64_t idx,
+                                                        int64_t & n,
+                                                        int64_t & oh,
+                                                        int64_t & ow,
+                                                        const conv_params & P) {
+        n         = idx / (P.OH * P.OW);
+        int64_t r = idx - n * (P.OH * P.OW);
+        oh        = r / P.OW;
+        ow        = r - oh * P.OW;
     }
 };
+
+class float_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ float_mma(float * scratch) {
+        buf               = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+#pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
+    __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+#pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int   m   = e / WMMA_N;
+            int   n   = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+#pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                float a = A_sh[m * strideA + k];
+                float b = B_sh[k * strideB + n];
+                sum     = fmaf(a, b, sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
     }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
 };
 
-template <typename T, typename Layout>
-static __global__ void conv2d_kernel(const float * __restrict__ input,
-                                     const T * __restrict__ kernel,
-                                     float * __restrict__ output,
-                                     const conv_params P) {
-    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+class half_mma {
+  private:
+    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float>              acc;
+    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
+    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf = scratch;
+        wmma::fill_fragment(acc, 0.0f);
+    }
+
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        wmma::load_matrix_sync(a_frag, A_sh, strideA);
+        wmma::load_matrix_sync(b_frag, B_sh, strideB);
+        wmma::mma_sync(acc, a_frag, b_frag, acc);
+    }
 
-    if (global_idx >= P.TOTAL) {
-        return;
+    __device__ __forceinline__ float * store_result() const {
+        wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major);
+        return buf;
     }
+};
+
+template <typename T, typename layout, typename mma>
+static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
+    extern __shared__ unsigned char smem_raw[];
+
+    const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N;
+    const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW;
+
+    const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N);
+
+    const int64_t NUM_BL_NOHOW     = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+    const int64_t tile_id          = blockIdx.x;
+    const int64_t tile_oc          = tile_id / NUM_BL_NOHOW;
+    const int64_t tile_nohow       = tile_id % NUM_BL_NOHOW;
+    const int64_t BLOCK_OC_BASE    = tile_oc * BS_OC;
+    const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW;
+
+    const int64_t laneId = threadIdx.x % WARP_SIZE;
+    const int64_t warpId = threadIdx.x / WARP_SIZE;
+
+    const int64_t WARP_OC    = warpId / WARPS_PER_NOHOW;
+    const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW;
 
-    int64_t n, c_out, out_y, out_x;
-    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+    const int64_t OC_BASE    = BLOCK_OC_BASE + WARP_OC * WMMA_M;
+    const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
 
-    float acc = 0.0f;
+    unsigned char * ptr  = smem_raw;
+    T *             A_sh = reinterpret_cast<T *>(ptr);
 
-    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
-        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+    size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T);
+    ptr += offsetA;
 
-        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
-            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+    T * B_sh = reinterpret_cast<T *>(ptr);
+    ptr += BS_ICKHKW * BS_NOHOW * sizeof(T);
 
-            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
-                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+    float * shared_scratch = reinterpret_cast<float *>(ptr);
+    float * warp_scratch   = shared_scratch + warpId * (WMMA_M * WMMA_N);
 
-                const float input_val  = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                const T     kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
-                acc += (input_val * ggml_cuda_cast<float>(kernel_val));
+    const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
+    const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
+
+    mma acc(warp_scratch);
+
+    const int64_t A_total = BS_OC * BS_ICKHKW;
+    const int64_t B_total = BS_ICKHKW * BS_NOHOW;
+
+#pragma unroll
+    for (int64_t t = 0; t < NUM_IC_TILES; ++t) {
+#pragma unroll
+        for (int64_t tid = (threadIdx.x); tid < A_total; tid += blockDim.x) {
+            const int row = tid / BS_ICKHKW;
+            const int col = tid % BS_ICKHKW;
+
+            int64_t shared_oc     = BLOCK_OC_BASE + row;
+            int64_t shared_ickhkw = t * BS_ICKHKW + col;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) {
+                int64_t ic, kh, kw;
+                layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P);
+
+                const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P);
+                val                = IK[kidx];
             }
+            A_sh[row * BS_ICKHKW + col] = val;
+        }
+
+#pragma unroll
+        for (int64_t tid = (threadIdx.x); tid < B_total; tid += blockDim.x) {
+            const int brow = tid / BS_NOHOW;
+            const int bcol = tid % BS_NOHOW;
+
+            int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow;
+            int64_t N_OH_OW_IDX  = BLOCK_NOHOW_BASE + bcol;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) {
+                int64_t n, oh, ow;
+                layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P);
+                int64_t ic, kh, kw;
+                layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P);
+                int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y);
+                int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X);
+                if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) {
+                    const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P);
+                    val                  = ggml_cuda_cast<T>(IN[in_idx]);
+                }
+            }
+            B_sh[brow * BS_NOHOW + bcol] = val;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) {
+            const T * A_k_ptr = A_warp_base + k_tile;
+            const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW;
+
+            acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW);
         }
+        __syncthreads();
     }
 
-    // [N, OC, OH, OW]
-    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
+    const float * out_buf = acc.store_result();
+#pragma unroll
+    for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) {
+        const int m = e / WMMA_N;
+        const int n = e % WMMA_N;
+
+        const int64_t oc    = OC_BASE + m;
+        const int64_t nohow = NOHOW_BASE + n;
+
+        if (oc < P.OC && nohow < (P.N_OH_OW)) {
+            int64_t n, oh, ow;
+            layout::unpack_nohow(nohow, n, oh, ow, P);
+            const int64_t out_idx = layout::output_index(n, oc, oh, ow, P);
+            OUT[out_idx]          = out_buf[e];
+        }
+    }
 }
 
-template <typename T>
-static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
-    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+template <typename T, typename mma>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) {
+    const int64_t NUM_BL_OC    = (P.OC + BS_OC - 1) / BS_OC;
+    const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+
+    int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW;
+    TOTAL_TILES         = std::min(TOTAL_TILES, (int64_t) INT_MAX);
+
+    const int WARPS_PER_OC    = std::max(1, ceil_div(BS_OC, WMMA_M));
+    const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N));
+    const int EXPECTED_WARPS  = WARPS_PER_OC * WARPS_PER_NOHOW;
+    int       N_THREADS       = EXPECTED_WARPS * WARP_SIZE;
+
+    const int MAX_TPB = 1024;
+    if (N_THREADS > MAX_TPB) {
+        N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE;
+    }
+
+    if (N_THREADS < WARP_SIZE) {
+        N_THREADS = WARP_SIZE;
+    }
+
+    const int N_WARPS = N_THREADS / WARP_SIZE;
+
+    // scratch buffer to store the output: it can't be stored directly with wmma,
+    // since the fragment-to-memory output mapping is unknown
+    const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float);
+
+    const int64_t A_bytes      = BS_OC * BS_ICKHKW * sizeof(T);
+    const int64_t B_bytes      = BS_ICKHKW * BS_NOHOW * sizeof(T);
+    const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes;
+
+    dim3 grid(TOTAL_TILES, 1, 1);
+    conv2d_kernel<T, whcn_layout, mma><<<grid, N_THREADS, shared_bytes, st>>>(X_D, K_D, Y_D, P);
 }
 
-static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) {
+    conv2d_cuda<half, half_mma>(X_D, K_D, Y_D, P, st);
 }
 
-static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) {
+    conv2d_cuda<float, float_mma>(X_D, K_D, Y_D, P, st);
 }
 
 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -155,11 +338,14 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int OC = kernel->ne[3];  // output_channels
     const int B  = input->ne[3];   // n_batches
 
-    const int64_t total  = B * OC * OH * OW;
-    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+    const int64_t TOTAL    = B * OC * OH * OW;
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t N_OH_OW  = B * OH * OW;
+    conv_params   params   = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
+                               PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW };
 
     if (kernel->type == GGML_TYPE_F16) {
-        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+        conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st);
     } else {
         conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
     }
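
Editor's note on the approach: the new kernel computes the convolution as an implicit GEMM. The flattened kernel acts as matrix A of shape OC × (IC·KH·KW), an im2col view of the input acts as matrix B of shape (IC·KH·KW) × (N·OH·OW) and is materialized tile by tile in shared memory, and each warp accumulates one WMMA_M × WMMA_N tile of C = A·B before scattering it back to the [N, OC, OH, OW] output. A minimal host-side sketch of that mapping, built only from the conv_params fields and the same index formulas as whcn_layout (illustrative reference code, not part of the commit):

// conv2d as GEMM, row-major:
//   A = kernel          OC x (IC*KH*KW)
//   B = im2col(input)   (IC*KH*KW) x (N*OH*OW)   (never materialized here either)
//   C = A * B           OC x (N*OH*OW)           (scattered to [N, OC, OH, OW])
static void conv2d_ref(const float * in, const float * kern, float * out, const conv_params & P) {
    for (int64_t oc = 0; oc < P.OC; ++oc) {
        for (int64_t col = 0; col < P.N_OH_OW; ++col) {        // col indexes (n, oh, ow)
            const int64_t n  = col / (P.OH * P.OW);
            const int64_t oh = (col / P.OW) % P.OH;
            const int64_t ow = col % P.OW;
            float acc = 0.0f;
            for (int64_t row = 0; row < P.IC_KH_KW; ++row) {   // row indexes (ic, kh, kw)
                const int64_t ic = row / (P.KH * P.KW);
                const int64_t kh = (row / P.KW) % P.KH;
                const int64_t kw = row % P.KW;
                const int64_t iy = oh * P.ST_Y + kh * P.DL_Y - P.PD_Y;
                const int64_t ix = ow * P.ST_X + kw * P.DL_X - P.PD_X;
                if (iy >= 0 && iy < P.IH && ix >= 0 && ix < P.IW) {  // zero padding
                    acc += in[n * (P.IC * P.IW * P.IH) + ic * P.IW * P.IH + iy * P.IW + ix] *
                           kern[oc * (P.IC * P.KH * P.KW) + ic * (P.KH * P.KW) + kh * P.KW + kw];
                }
            }
            out[n * (P.OC * P.OW * P.OH) + oc * P.OW * P.OH + oh * P.OW + ow] = acc;
        }
    }
}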
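The two accumulator classes share one interface: half_mma wraps nvcuda::wmma fragments so the F16 path runs on tensor cores, while float_mma keeps the F32 path on plain fmaf loops over the same shared-memory tiles, since the portable WMMA API exposes no plain-fp32 operand fragments (only TF32 on newer architectures). The per-warp scratch exists because wmma::store_matrix_sync must write a contiguous row-major tile before the bounds-checked scatter to the output. The launch math in conv2d_cuda then follows from the tile constants in conv2d.cuh; with assumed illustrative values (the real BS_* / WMMA_* constants may differ), it works out as:

// Illustrative launch configuration, assuming (hypothetically)
// BS_OC = 64, BS_NOHOW = 64, BS_ICKHKW = 16 and 16x16x16 WMMA tiles.
constexpr int WMMA_M = 16, WMMA_N = 16, WMMA_K = 16, WARP_SIZE = 32;
constexpr int BS_OC = 64, BS_NOHOW = 64, BS_ICKHKW = 16;

constexpr int N_WARPS   = (BS_OC / WMMA_M) * (BS_NOHOW / WMMA_N);  // 4 * 4 = 16 warps
constexpr int N_THREADS = N_WARPS * WARP_SIZE;                     // 512 threads per block

// grid.x = ceil_div(OC, BS_OC) * ceil_div(N*OH*OW, BS_NOHOW), one output tile per block.
// Shared memory per block on the F16 path (T = half, sizeof(half) = 2):
//   A tile:  BS_OC * BS_ICKHKW * 2                      = 64 * 16 * 2  =  2048 B
//   B tile:  BS_ICKHKW * BS_NOHOW * 2                   = 16 * 64 * 2  =  2048 B
//   scratch: N_WARPS * WMMA_M * WMMA_N * sizeof(float)  = 16 * 256 * 4 = 16384 B
//                                                          total         20480 B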
