@@ -54,23 +54,14 @@ __align__(16) struct Params {
5454 uint32_t nb2;
5555 uint32_t nb3;
5656
57- uint32_t KWmp;
58- uint32_t KWL;
59- uint32_t KWKHmp;
60- uint32_t KWKHL;
61- uint32_t OWmp;
62- uint32_t OWL;
63- uint32_t OWOHmp;
64- uint32_t OWOHL;
57+ uint3 KW_fastdiv;
58+ uint3 KWKH_fastdiv;
59+ uint3 OW_fastdiv;
60+ uint3 OWOH_fastdiv;
6561};
6662
6763__constant__ __device__ Params dp;
6864
69- // Computes n / d for a fixed divisor d using a multiply-high and a shift; mp and L are the values precomputed for d (see init_fastdiv_values in ggml-vulkan.cpp).
70- __inline__ __device__ uint fastdiv (uint n, uint mp, uint L) { // n: dividend; mp: precomputed multiplier; L: precomputed shift amount
71- return (__umulhi (n, mp) + n) >> L; // high 32 bits of n * mp, plus n, shifted right by L. NOTE(review): the 32-bit add may wrap for large n -- confirm callers keep n in a safe range
72- }
73-
7465// --> conv_2d kernel modified to function as a matmul
7566template <typename T, uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
7667__global__ void __launch_bounds__ (WG_SIZE, 1 ) mm(uint K,
@@ -139,10 +130,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
139130#else
140131 uint32_t CRS_idx_a = idx_CRS + Ac; // Global CRS_idx (column index of A)
141132 // uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH);
142- uint32_t Cin_idx_a = fastdiv (CRS_idx_a, dp.KWKHmp , dp. KWKHL ); // divide by (p.KW * p.KH); / (p.KW * p.KH);
133+ uint32_t Cin_idx_a = fastdiv (CRS_idx_a, dp.KWKH_fastdiv ); // divide by (p.KW * p.KH); / (p.KW * p.KH);
143134 uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH ;
144135 // uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW;
145- uint32_t KH_idx_a = fastdiv (CRS_remainder, dp.KWmp , dp. KWL ); // divide by p.KW;
136+ uint32_t KH_idx_a = fastdiv (CRS_remainder, dp.KW_fastdiv ); // divide by p.KW;
146137// uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused
147138#endif
148139
@@ -177,10 +168,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
177168 // Compute indices for N, OH, OW from NPQ_idx
178169 const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */
179170 // const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW);
180- uint32_t N_idx = fastdiv (NPQ_idx, dp.OWOHmp , dp. OWOHL ); // divide by p.OH * p.OW;
171+ uint32_t N_idx = fastdiv (NPQ_idx, dp.OWOH_fastdiv ); // divide by p.OH * p.OW;
181172 uint32_t NPQ_remainder = NPQ_idx - N_idx * dp.OH * dp.OW ;
182173 // const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW;
183- uint32_t OH_idx = fastdiv (NPQ_remainder, dp.OWmp , dp. OWL ); // divide by p.OW;
174+ uint32_t OH_idx = fastdiv (NPQ_remainder, dp.OW_fastdiv ); // divide by p.OW;
184175 const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW ;
185176
186177#ifdef USE_COLLECTIVES
@@ -192,10 +183,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
192183 // Compute indices KH, KW, Cin from CRS_idx
193184 uint32_t CRS_idx_b = idx_CRS + r_offset + Br;
194185 // uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH);
195- uint32_t Cin_idx_b = fastdiv (CRS_idx_b, dp.KWKHmp , dp. KWKHL ); // divide by (p.KW * p.KH); / (p.KW * p.KH);
186+ uint32_t Cin_idx_b = fastdiv (CRS_idx_b, dp.KWKH_fastdiv ); // divide by (p.KW * p.KH); / (p.KW * p.KH);
196187 uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH ;
197188 // uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW;
198- uint32_t KH_idx_b = fastdiv (CRS_remainder, dp.KWmp , dp. KWL ); // divide by p.KW;
189+ uint32_t KH_idx_b = fastdiv (CRS_remainder, dp.KW_fastdiv ); // divide by p.KW;
199190 uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW ;
200191#endif
201192
@@ -271,9 +262,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
271262 const uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
272263 const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
273264 // const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW);
274- const uint32_t N_idx_c = fastdiv (NPQ_idx_c, dp.OWOHmp , dp. OWOHL ); // divide by p.OH * p.OW;
265+ const uint32_t N_idx_c = fastdiv (NPQ_idx_c, dp.OWOH_fastdiv ); // divide by p.OH * p.OW;
275266 // const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW;
276- const uint32_t OH_idx_c = fastdiv (NPQ_idx_c - N_idx_c * dp.OH * dp.OW , dp.OWmp , dp. OWL ); // divide by p.OW;
267+ const uint32_t OH_idx_c = fastdiv (NPQ_idx_c - N_idx_c * dp.OH * dp.OW , dp.OW_fastdiv ); // divide by p.OW;
277268 const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW ;
278269 const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3 ;
279270 if (K_idx < K && NPQ_idx_c < NPQ) {
@@ -283,22 +274,6 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
283274 }
284275}
285276
286- // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
287- // Precompute mp (m' in the paper) and L such that division
288- // can be computed using a multiply (high 32b of 64b result)
289- // and a shift:
290- //
291- // n/d = (mulhi(n, mp) + n) >> L;
292- static void init_fastdiv_values (uint32_t d, uint32_t & mp, uint32_t & L) { // d: divisor (input); mp, L: multiplier and shift written through the references (outputs)
293- // compute L = ceil(log2(d)), i.e. the smallest L in [0, 32] with 2^L >= d
294- L = 0 ;
295- while (L < 32 && (uint32_t { 1 } << L) < d) { // the L < 32 guard keeps the 32-bit shift in range; L ends at 32 when d > 2^31
296- L++;
297- }
298-
299- mp = (uint32_t ) ((uint64_t { 1 } << 32 ) * ((uint64_t { 1 } << L) - d) / d + 1 ); // mp = floor(2^32 * (2^L - d) / d) + 1, evaluated in 64 bits then truncated to 32. NOTE(review): d == 0 divides by zero here -- callers must pass d >= 1
300- }
301-
302277constexpr int conv_shapes[][NUM_VARIANTS] = {
303278 { 128 , 64 , 32 }, // BS_K
304279 { 16 , 32 , 16 }, // BS_CRS
@@ -382,13 +357,13 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
382357 ggml_tensor * src0 = dst->src [0 ];
383358 ggml_tensor * src1 = dst->src [1 ];
384359
385- // GGML_ASSERT(src0->type == GGML_TYPE_F32);
360+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0-> type == GGML_TYPE_F16 );
386361 GGML_ASSERT (src1->type == GGML_TYPE_F32);
387362 GGML_ASSERT (dst->type == GGML_TYPE_F32);
388363
389364 GGML_TENSOR_BINARY_OP_LOCALS
390365
391- // GGML_ASSERT(nb00 == sizeof(float));
366+ GGML_ASSERT (nb00 == sizeof (float ) || nb00 == sizeof (half ));
392367 GGML_ASSERT (nb10 == sizeof (float ));
393368 GGML_ASSERT (nb0 == sizeof (float ));
394369
@@ -423,10 +398,10 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
423398 p.nb2 = static_cast <uint32_t >(nb2 / nb0);
424399 p.nb3 = static_cast <uint32_t >(nb3 / nb0);
425400
426- init_fastdiv_values (p. KW , p. KWmp , p. KWL );
427- init_fastdiv_values (p.KW * p.KH , p. KWKHmp , p. KWKHL );
428- init_fastdiv_values (p. OW , p. OWmp , p. OWL );
429- init_fastdiv_values (p.OW * p.OH , p. OWOHmp , p. OWOHL );
401+ p. KW_fastdiv = init_fastdiv_values (p. KW );
402+ p. KWKH_fastdiv = init_fastdiv_values (p.KW * p.KH );
403+ p. OW_fastdiv = init_fastdiv_values (p. OW );
404+ p. OWOH_fastdiv = init_fastdiv_values (p.OW * p.OH );
430405
431406 GGML_ASSERT (ne03 == ne2);
432407 GGML_ASSERT (ne02 == ne12);
0 commit comments