#include "conv2d.cuh"
#include "convert.cuh"

+#include <mma.h>
+using namespace nvcuda;
+
struct conv_params {
    const int64_t IW, IH;
    const int64_t OW, OH;
@@ -11,112 +14,292 @@ struct conv_params {
    const int64_t IC, OC;
    const int64_t B;
    const int64_t TOTAL;
+    // sizes of the implicit GEMM: K dimension = IC*KH*KW, N dimension = B*OH*OW
+    const int64_t IC_KH_KW, N_OH_OW;
};

-struct kernel_bounds {
-    int64_t y_min, y_max;
-    int64_t x_min, x_max;
+auto ceil_div = [](int a, int b) {
+    return (a + b - 1) / b;
};

-__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
-    return (a > b) ? a : b;
-}
-
-__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
-    return (a < b) ? a : b;
-}
-
-__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
-    kernel_bounds bounds;
-    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
-    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
-    return bounds;
-}
-
-__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
-                                                     int64_t kern_coord,
-                                                     int64_t stride,
-                                                     int64_t dilation,
-                                                     int64_t padding) {
+__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord,
+                                                            int64_t kern_coord,
+                                                            int64_t stride,
+                                                            int64_t dilation,
+                                                            int64_t padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

struct whcn_layout {
-    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t input_index(int64_t n,
+                                                          int64_t c,
+                                                          int64_t y,
+                                                          int64_t x,
+                                                          const conv_params & P) {
        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
    }

-    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+    __device__ __forceinline__ static int64_t kernel_index(int64_t c_out,
+                                                           int64_t c_in,
+                                                           int64_t ky,
+                                                           int64_t kx,
+                                                           const conv_params & P) {
        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
    }

-    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+    __device__ __forceinline__ static int64_t output_index(int64_t n,
+                                                           int64_t c,
+                                                           int64_t y,
+                                                           int64_t x,
+                                                           const conv_params & P) {
        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
    }

-    __device__ static void unpack_indices(int64_t global_idx,
-                                          const conv_params & P,
-                                          int64_t & n,
-                                          int64_t & c,
-                                          int64_t & out_y,
-                                          int64_t & out_x) {
-        out_x = global_idx % P.OW;
-        out_y = (global_idx / P.OW) % P.OH;
-        c     = (global_idx / (P.OW * P.OH)) % P.OC;
-        n     = global_idx / (P.OW * P.OH * P.OC);
+    __device__ __forceinline__ static void unpack_ickhkw(int64_t idx,
+                                                         int64_t & ic,
+                                                         int64_t & kh,
+                                                         int64_t & kw,
+                                                         const conv_params & P) {
+        ic        = idx / (P.KW * P.KH);
+        int64_t r = idx - ic * (P.KW * P.KH);
+        kh        = r / P.KW;
+        kw        = r - kh * P.KW;
+    }
+
+    __device__ __forceinline__ static void unpack_nohow(int64_t idx,
+                                                        int64_t & n,
+                                                        int64_t & oh,
+                                                        int64_t & ow,
+                                                        const conv_params & P) {
+        n         = idx / (P.OH * P.OW);
+        int64_t r = idx - n * (P.OH * P.OW);
+        oh        = r / P.OW;
+        ow        = r - oh * P.OW;
+    }
+};
+
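+// FP32 path: no tensor cores; emulates one WMMA_M x WMMA_N warp tile with plain
+// FMAs, accumulating into a per-warp scratch buffer held in shared memory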
+class float_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ float_mma(float * scratch) {
+        buf               = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+        #pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
+    __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+        #pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int   m   = e / WMMA_N;
+            int   n   = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+            #pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                float a = A_sh[m * strideA + k];
+                float b = B_sh[k * strideB + n];
+                sum     = fmaf(a, b, sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
    }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
};

-template <typename T, typename Layout>
-static __global__ void conv2d_kernel(const float * __restrict__ input,
-                                     const T * __restrict__ kernel,
-                                     float * __restrict__ output,
-                                     const conv_params P) {
-    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
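+// FP16 path: runs on the tensor cores through nvcuda::wmma; the accumulator
+// stays in registers until store_result() spills it to the shared scratch buffer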
+class half_mma {
+  private:
+    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
+    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
+    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf = scratch;
+        wmma::fill_fragment(acc, 0.0f);
+    }
+
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        wmma::load_matrix_sync(a_frag, A_sh, strideA);
+        wmma::load_matrix_sync(b_frag, B_sh, strideB);
+        wmma::mma_sync(acc, a_frag, b_frag, acc);
+    }

-    if (global_idx >= P.TOTAL) {
-        return;
+    __device__ __forceinline__ float * store_result() const {
+        wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major);
+        return buf;
    }
+};
+
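+// Implicit GEMM: A is the kernel viewed as [OC, IC*KH*KW], B is the im2col view
+// of the input, [IC*KH*KW, B*OH*OW], built tile-by-tile in shared memory; each
+// thread block computes one BS_OC x BS_NOHOW tile of the output C = A * B.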
+template <typename T, typename layout, typename mma>
+static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
+    extern __shared__ unsigned char smem_raw[];
+
+    const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N;
+    const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW;
+
+    const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N);
+
+    const int64_t NUM_BL_NOHOW     = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+    const int64_t tile_id          = blockIdx.x;
+    const int64_t tile_oc          = tile_id / NUM_BL_NOHOW;
+    const int64_t tile_nohow       = tile_id % NUM_BL_NOHOW;
+    const int64_t BLOCK_OC_BASE    = tile_oc * BS_OC;
+    const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW;
+
+    const int64_t laneId = threadIdx.x % WARP_SIZE;
+    const int64_t warpId = threadIdx.x / WARP_SIZE;
+
+    const int64_t WARP_OC    = warpId / WARPS_PER_NOHOW;
+    const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW;

-    int64_t n, c_out, out_y, out_x;
-    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
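+    // base coordinates of this warp's WMMA_M x WMMA_N tile in the output matrix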
+    const int64_t OC_BASE    = BLOCK_OC_BASE + WARP_OC * WMMA_M;
+    const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;

-    float acc = 0.0f;
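+    // carve dynamic shared memory into [ A tile | B tile | per-warp output scratch ]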
+    unsigned char * ptr  = smem_raw;
+    T *             A_sh = reinterpret_cast<T *>(ptr);

-    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
-        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+    size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T);
+    ptr += offsetA;

-        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
-            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+    T * B_sh = reinterpret_cast<T *>(ptr);
+    ptr += BS_ICKHKW * BS_NOHOW * sizeof(T);

-            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
-                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+    float * shared_scratch = reinterpret_cast<float *>(ptr);
+    float * warp_scratch   = shared_scratch + warpId * (WMMA_M * WMMA_N);

-                const float input_val  = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                const T     kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
-                acc += (input_val * ggml_cuda_cast<float>(kernel_val));
+    const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
+    const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
+
+    mma acc(warp_scratch);
+
+    const int64_t A_total = BS_OC * BS_ICKHKW;
+    const int64_t B_total = BS_ICKHKW * BS_NOHOW;
+
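+    // march over the K dimension (IC*KH*KW) one BS_ICKHKW-wide tile at a time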
+    #pragma unroll
+    for (int64_t t = 0; t < NUM_IC_TILES; ++t) {
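+        // stage the kernel tile A (BS_OC x BS_ICKHKW) in shared memory, zero-padded out of range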
+        #pragma unroll
+        for (int64_t tid = threadIdx.x; tid < A_total; tid += blockDim.x) {
+            const int row = tid / BS_ICKHKW;
+            const int col = tid % BS_ICKHKW;
+
+            int64_t shared_oc     = BLOCK_OC_BASE + row;
+            int64_t shared_ickhkw = t * BS_ICKHKW + col;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) {
+                int64_t ic, kh, kw;
+                layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P);
+
+                const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P);
+                val                = IK[kidx];
            }
+            A_sh[row * BS_ICKHKW + col] = val;
+        }
+
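+        // stage the input tile B (BS_ICKHKW x BS_NOHOW) via on-the-fly im2col;
+        // padded and out-of-bounds pixels are read as zero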
+        #pragma unroll
+        for (int64_t tid = threadIdx.x; tid < B_total; tid += blockDim.x) {
+            const int brow = tid / BS_NOHOW;
+            const int bcol = tid % BS_NOHOW;
+
+            int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow;
+            int64_t N_OH_OW_IDX  = BLOCK_NOHOW_BASE + bcol;
+
+            T val = ggml_cuda_cast<T>(0);
+            if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) {
+                int64_t n, oh, ow;
+                layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P);
+                int64_t ic, kh, kw;
+                layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P);
+                int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y);
+                int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X);
+                if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) {
+                    const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P);
+                    val                  = ggml_cuda_cast<T>(IN[in_idx]);
+                }
+            }
+            B_sh[brow * BS_NOHOW + bcol] = val;
+        }
+
+        __syncthreads();
+
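+        // both tiles staged: sweep the shared tiles in WMMA_K-wide steps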
+        #pragma unroll
+        for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) {
+            const T * A_k_ptr = A_warp_base + k_tile;
+            const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW;
+
+            acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW);
        }
+        __syncthreads();
    }

-    // [N, OC, OH, OW]
-    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
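+    // scatter the warp's accumulated tile to the [N, OC, OH, OW] output tensor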
+    const float * out_buf = acc.store_result();
+    #pragma unroll
+    for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) {
+        const int m = e / WMMA_N;
+        const int n = e % WMMA_N;
+
+        const int64_t oc    = OC_BASE + m;
+        const int64_t nohow = NOHOW_BASE + n;
+
+        if (oc < P.OC && nohow < P.N_OH_OW) {
+            int64_t n, oh, ow;
+            layout::unpack_nohow(nohow, n, oh, ow, P);
+            const int64_t out_idx = layout::output_index(n, oc, oh, ow, P);
+            OUT[out_idx]          = out_buf[e];
+        }
+    }
}

-template <typename T>
-static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
-    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+template <typename T, typename mma>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) {
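+    // one thread block per BS_OC x BS_NOHOW output tile, one warp per WMMA sub-tile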
+    const int64_t NUM_BL_OC    = (P.OC + BS_OC - 1) / BS_OC;
+    const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+
+    int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW;
+    TOTAL_TILES         = std::min(TOTAL_TILES, (int64_t) INT_MAX);
+
+    const int WARPS_PER_OC    = std::max(1, ceil_div(BS_OC, WMMA_M));
+    const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N));
+    const int EXPECTED_WARPS  = WARPS_PER_OC * WARPS_PER_NOHOW;
+    int       N_THREADS       = EXPECTED_WARPS * WARP_SIZE;
+
+    const int MAX_TPB = 1024;
+    if (N_THREADS > MAX_TPB) {
+        N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE;
+    }
+
+    if (N_THREADS < WARP_SIZE) {
+        N_THREADS = WARP_SIZE;
+    }
+
+    const int N_WARPS = N_THREADS / WARP_SIZE;
+
+    // per-warp scratch buffer for the output tile: wmma cannot store straight to
+    // global memory because the output mapping is only known after unpacking indices
+    const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float);
+
+    const int64_t A_bytes      = BS_OC * BS_ICKHKW * sizeof(T);
+    const int64_t B_bytes      = BS_ICKHKW * BS_NOHOW * sizeof(T);
+    const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes;
+
+    dim3 grid(TOTAL_TILES, 1, 1);
+    conv2d_kernel<T, whcn_layout, mma><<<grid, N_THREADS, shared_bytes, st>>>(X_D, K_D, Y_D, P);
}

-static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) {
+    conv2d_cuda<half, half_mma>(X_D, K_D, Y_D, P, st);
}

-static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
-    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) {
+    conv2d_cuda<float, float_mma>(X_D, K_D, Y_D, P, st);
}

void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -155,11 +338,14 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int OC = kernel->ne[3];  // output channels
    const int B  = input->ne[3];   // n_batches

-    const int64_t total = B * OC * OH * OW;
-    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+    const int64_t TOTAL    = B * OC * OH * OW;
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t N_OH_OW  = B * OH * OW;
+    conv_params   params   = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
+                               PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW };

    if (kernel->type == GGML_TYPE_F16) {
-        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+        conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st);
    } else {
        conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
    }