Skip to content

Commit 8b6962d

Browse files
struct for qk, qr, qi
1 parent bd8422d commit 8b6962d

File tree

4 files changed

+146
-56
lines changed

4 files changed

+146
-56
lines changed

ggml-cuda/common.cuh

Lines changed: 139 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -484,57 +484,147 @@ static __device__ __forceinline__ float get_alibi_slope(
484484
return powf(base, exph);
485485
}
486486

487-
static constexpr __device__ int ggml_blck_size_device(ggml_type type) {
488-
return type == GGML_TYPE_F16 ? 1 :
489-
type == GGML_TYPE_Q4_0 ? QK4_0 :
490-
type == GGML_TYPE_Q4_1 ? QK4_1 :
491-
type == GGML_TYPE_Q5_0 ? QK5_0 :
492-
type == GGML_TYPE_Q5_1 ? QK5_1 :
493-
type == GGML_TYPE_Q8_0 ? QK8_0 :
494-
type == GGML_TYPE_Q2_K ? QK_K :
495-
type == GGML_TYPE_Q3_K ? QK_K :
496-
type == GGML_TYPE_Q4_K ? QK_K :
497-
type == GGML_TYPE_Q5_K ? QK_K :
498-
type == GGML_TYPE_Q6_K ? QK_K :
499-
type == GGML_TYPE_IQ2_XXS ? QK_K :
500-
type == GGML_TYPE_IQ2_XS ? QK_K :
501-
type == GGML_TYPE_IQ2_S ? QK_K :
502-
type == GGML_TYPE_IQ3_XXS ? QK_K :
503-
type == GGML_TYPE_IQ1_S ? QK_K :
504-
type == GGML_TYPE_IQ1_M ? QK_K :
505-
type == GGML_TYPE_IQ4_NL ? QK4_NL :
506-
type == GGML_TYPE_IQ4_XS ? QK_K :
507-
type == GGML_TYPE_IQ3_S ? QK_K :
508-
0;
509-
}
487+
template <ggml_type type>
488+
struct ggml_cuda_type_traits;
510489

511-
static constexpr __device__ int get_qr_device(ggml_type type) {
512-
return type == GGML_TYPE_F16 ? 1 :
513-
type == GGML_TYPE_Q4_0 ? QR4_0 :
514-
type == GGML_TYPE_Q4_1 ? QR4_1 :
515-
type == GGML_TYPE_Q5_0 ? QR5_0 :
516-
type == GGML_TYPE_Q5_1 ? QR5_1 :
517-
type == GGML_TYPE_Q8_0 ? QR8_0 :
518-
type == GGML_TYPE_Q2_K ? QR2_K :
519-
type == GGML_TYPE_Q3_K ? QR3_K :
520-
type == GGML_TYPE_Q4_K ? QR4_K :
521-
type == GGML_TYPE_Q5_K ? QR5_K :
522-
type == GGML_TYPE_Q6_K ? QR6_K :
523-
type == GGML_TYPE_IQ2_XXS ? QR2_XXS :
524-
type == GGML_TYPE_IQ2_XS ? QR2_XS :
525-
type == GGML_TYPE_IQ2_S ? QR2_S :
526-
type == GGML_TYPE_IQ3_XXS ? QR3_XXS :
527-
type == GGML_TYPE_IQ1_S ? QR1_S :
528-
type == GGML_TYPE_IQ1_M ? QR1_M :
529-
type == GGML_TYPE_IQ4_NL ? QR4_NL :
530-
type == GGML_TYPE_IQ4_XS ? QR4_XS :
531-
type == GGML_TYPE_IQ3_S ? QR3_S :
532-
0;
533-
}
490+
template<>
491+
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
492+
static constexpr int qk = 1;
493+
static constexpr int qr = 1;
494+
};
534495

535-
static constexpr __device__ int get_qi_device(ggml_type type) {
536-
return ggml_blck_size_device(type) / (sizeof(int)*get_qr_device(type));
537-
}
496+
template<>
497+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
498+
static constexpr int qk = QK4_0;
499+
static constexpr int qr = QR4_0;
500+
static constexpr int qi = QI4_0;
501+
};
502+
503+
template<>
504+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
505+
static constexpr int qk = QK4_1;
506+
static constexpr int qr = QR4_1;
507+
static constexpr int qi = QI4_1;
508+
};
509+
510+
template<>
511+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
512+
static constexpr int qk = QK5_0;
513+
static constexpr int qr = QR5_0;
514+
static constexpr int qi = QI5_0;
515+
};
516+
517+
template<>
518+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
519+
static constexpr int qk = QK5_1;
520+
static constexpr int qr = QR5_1;
521+
static constexpr int qi = QI5_1;
522+
};
523+
524+
template<>
525+
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
526+
static constexpr int qk = QK8_0;
527+
static constexpr int qr = QR8_0;
528+
static constexpr int qi = QI8_0;
529+
};
530+
531+
template<>
532+
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
533+
static constexpr int qk = QK_K;
534+
static constexpr int qr = QR2_K;
535+
static constexpr int qi = QI2_K;
536+
};
537+
538+
template<>
539+
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
540+
static constexpr int qk = QK_K;
541+
static constexpr int qr = QR3_K;
542+
static constexpr int qi = QI3_K;
543+
};
544+
545+
template<>
546+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
547+
static constexpr int qk = QK_K;
548+
static constexpr int qr = QR4_K;
549+
static constexpr int qi = QI4_K;
550+
};
551+
552+
template<>
553+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
554+
static constexpr int qk = QK_K;
555+
static constexpr int qr = QR5_K;
556+
static constexpr int qi = QI5_K;
557+
};
558+
559+
template<>
560+
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
561+
static constexpr int qk = QK_K;
562+
static constexpr int qr = QR6_K;
563+
static constexpr int qi = QI6_K;
564+
};
565+
566+
template<>
567+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
568+
static constexpr int qk = QK_K;
569+
static constexpr int qr = QR2_XXS;
570+
static constexpr int qi = QI2_XXS;
571+
};
572+
573+
template<>
574+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
575+
static constexpr int qk = QK_K;
576+
static constexpr int qr = QR2_XS;
577+
static constexpr int qi = QI2_XS;
578+
};
579+
580+
template<>
581+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
582+
static constexpr int qk = QK_K;
583+
static constexpr int qr = QR2_S;
584+
static constexpr int qi = QI2_S;
585+
};
586+
587+
template<>
588+
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
589+
static constexpr int qk = QK_K;
590+
static constexpr int qr = QR3_XXS;
591+
static constexpr int qi = QI3_XXS;
592+
};
593+
594+
template<>
595+
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
596+
static constexpr int qk = QK_K;
597+
static constexpr int qr = QR1_S;
598+
static constexpr int qi = QI1_S;
599+
};
600+
601+
template<>
602+
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
603+
static constexpr int qk = QK_K;
604+
static constexpr int qr = QR1_M;
605+
static constexpr int qi = QI1_M;
606+
};
607+
608+
template<>
609+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
610+
static constexpr int qk = QK4_NL;
611+
static constexpr int qr = QR4_NL;
612+
static constexpr int qi = QI4_NL;
613+
};
614+
615+
template<>
616+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
617+
static constexpr int qk = QK_K;
618+
static constexpr int qr = QR4_XS;
619+
static constexpr int qi = QI4_XS;
620+
};
621+
622+
template<>
623+
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
624+
static constexpr int qk = QK_K;
625+
static constexpr int qr = QR3_S;
626+
static constexpr int qi = QI3_S;
627+
};
538628

539629
static int get_mmq_x_max_host(const int cc) {
540630
#ifdef CUDA_USE_TENSOR_CORES

ggml-cuda/dmmv.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type
434434

435435
template <ggml_type type>
436436
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
437-
constexpr int qk = ggml_blck_size_device(type); // quantized weights per x block
438-
constexpr int qr = get_qr_device(type); // number of quantized weights per data value in x block
437+
constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
438+
constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
439439
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
440440

441441
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;

ggml-cuda/mmq.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,10 +1033,10 @@ static __global__ void mul_mat_q(
10331033
return;
10341034
}
10351035

1036+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
1037+
constexpr int qr = ggml_cuda_type_traits<type>::qr;
1038+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
10361039
constexpr int mmq_y = get_mmq_y_device(mmq_x);
1037-
constexpr int qk = ggml_blck_size_device(type);
1038-
constexpr int qr = get_qr_device(type);
1039-
constexpr int qi = get_qi_device(type);
10401040
constexpr bool need_sum = get_need_sum(type);
10411041
constexpr int vdr = get_vdr_mmq(type);
10421042

ggml-cuda/mmvq.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ static __global__ void mul_mat_vec_q(
5050
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5151
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
5252

53-
constexpr int qk = ggml_blck_size_device(type);
54-
constexpr int qi = get_qi_device(type);
53+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
54+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
5555
constexpr int vdr = get_vdr_mmvq(type);
5656

5757
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

0 commit comments

Comments
 (0)