Skip to content

Commit 819c40e

Browse files
committed
use fastdiv equivalents already in cuda backend
1 parent 7fa80c1 commit 819c40e

File tree

1 file changed

+18
-43
lines changed

1 file changed

+18
-43
lines changed

ggml/src/ggml-cuda/conv2d-mm.cu

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,14 @@ __align__(16) struct Params {
5454
uint32_t nb2;
5555
uint32_t nb3;
5656

57-
uint32_t KWmp;
58-
uint32_t KWL;
59-
uint32_t KWKHmp;
60-
uint32_t KWKHL;
61-
uint32_t OWmp;
62-
uint32_t OWL;
63-
uint32_t OWOHmp;
64-
uint32_t OWOHL;
57+
uint3 KW_fastdiv;
58+
uint3 KWKH_fastdiv;
59+
uint3 OW_fastdiv;
60+
uint3 OWOH_fastdiv;
6561
};
6662

6763
__constant__ __device__ Params dp;
6864

69-
// see init_fastdiv_values in ggml-vulkan.cpp
70-
__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
71-
return (__umulhi(n, mp) + n) >> L;
72-
}
73-
7465
// --> conv_2d kernel modified to function as a matmul
7566
template <typename T, uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
7667
__global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
@@ -139,10 +130,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
139130
#else
140131
uint32_t CRS_idx_a = idx_CRS + Ac; //Global CRS_idx (column index of A)
141132
//uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH);
142-
uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
133+
uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKH_fastdiv); // divide by (p.KW * p.KH); / (p.KW * p.KH);
143134
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH;
144135
//uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW;
145-
uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW;
136+
uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KW_fastdiv); // divide by p.KW;
146137
//uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused
147138
#endif
148139

@@ -177,10 +168,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
177168
// Compute indices for N, OH, OW from NPQ_idx
178169
const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */
179170
//const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW);
180-
uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW;
171+
uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOH_fastdiv); // divide by p.OH * p.OW;
181172
uint32_t NPQ_remainder = NPQ_idx - N_idx * dp.OH * dp.OW;
182173
//const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW;
183-
uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OWmp, dp.OWL); // divide by p.OW;
174+
uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OW_fastdiv); // divide by p.OW;
184175
const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW;
185176

186177
#ifdef USE_COLLECTIVES
@@ -192,10 +183,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
192183
// Compute indices KH, KW, Cin from CRS_idx
193184
uint32_t CRS_idx_b = idx_CRS + r_offset + Br;
194185
//uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH);
195-
uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
186+
uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKH_fastdiv); // divide by (p.KW * p.KH); / (p.KW * p.KH);
196187
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH;
197188
//uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW;
198-
uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW;
189+
uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KW_fastdiv); // divide by p.KW;
199190
uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW;
200191
#endif
201192

@@ -271,9 +262,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
271262
const uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
272263
const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
273264
//const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW);
274-
const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW;
265+
const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOH_fastdiv); // divide by p.OH * p.OW;
275266
//const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW;
276-
const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OWmp, dp.OWL); // divide by p.OW;
267+
const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OW_fastdiv); // divide by p.OW;
277268
const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW;
278269
const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3;
279270
if (K_idx < K && NPQ_idx_c < NPQ) {
@@ -283,22 +274,6 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
283274
}
284275
}
285276

286-
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
287-
// Precompute mp (m' in the paper) and L such that division
288-
// can be computed using a multiply (high 32b of 64b result)
289-
// and a shift:
290-
//
291-
// n/d = (mulhi(n, mp) + n) >> L;
292-
static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
293-
// compute L = ceil(log2(d));
294-
L = 0;
295-
while (L < 32 && (uint32_t{ 1 } << L) < d) {
296-
L++;
297-
}
298-
299-
mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
300-
}
301-
302277
constexpr int conv_shapes[][NUM_VARIANTS] = {
303278
{ 128, 64, 32 }, // BS_K
304279
{ 16, 32, 16 }, // BS_CRS
@@ -382,13 +357,13 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
382357
ggml_tensor * src0 = dst->src[0];
383358
ggml_tensor * src1 = dst->src[1];
384359

385-
// GGML_ASSERT(src0->type == GGML_TYPE_F32);
360+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
386361
GGML_ASSERT(src1->type == GGML_TYPE_F32);
387362
GGML_ASSERT(dst->type == GGML_TYPE_F32);
388363

389364
GGML_TENSOR_BINARY_OP_LOCALS
390365

391-
// GGML_ASSERT(nb00 == sizeof(float));
366+
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(half));
392367
GGML_ASSERT(nb10 == sizeof(float));
393368
GGML_ASSERT(nb0 == sizeof(float));
394369

@@ -423,10 +398,10 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
423398
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
424399
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
425400

426-
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
427-
init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL);
428-
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
429-
init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL);
401+
p.KW_fastdiv = init_fastdiv_values(p.KW);
402+
p.KWKH_fastdiv = init_fastdiv_values(p.KW * p.KH);
403+
p.OW_fastdiv = init_fastdiv_values(p.OW);
404+
p.OWOH_fastdiv = init_fastdiv_values(p.OW * p.OH);
430405

431406
GGML_ASSERT(ne03 == ne2);
432407
GGML_ASSERT(ne02 == ne12);

0 commit comments

Comments
 (0)