@@ -484,57 +484,147 @@ static __device__ __forceinline__ float get_alibi_slope(
484
484
return powf (base, exph);
485
485
}
486
486
487
- static constexpr __device__ int ggml_blck_size_device (ggml_type type) {
488
- return type == GGML_TYPE_F16 ? 1 :
489
- type == GGML_TYPE_Q4_0 ? QK4_0 :
490
- type == GGML_TYPE_Q4_1 ? QK4_1 :
491
- type == GGML_TYPE_Q5_0 ? QK5_0 :
492
- type == GGML_TYPE_Q5_1 ? QK5_1 :
493
- type == GGML_TYPE_Q8_0 ? QK8_0 :
494
- type == GGML_TYPE_Q2_K ? QK_K :
495
- type == GGML_TYPE_Q3_K ? QK_K :
496
- type == GGML_TYPE_Q4_K ? QK_K :
497
- type == GGML_TYPE_Q5_K ? QK_K :
498
- type == GGML_TYPE_Q6_K ? QK_K :
499
- type == GGML_TYPE_IQ2_XXS ? QK_K :
500
- type == GGML_TYPE_IQ2_XS ? QK_K :
501
- type == GGML_TYPE_IQ2_S ? QK_K :
502
- type == GGML_TYPE_IQ3_XXS ? QK_K :
503
- type == GGML_TYPE_IQ1_S ? QK_K :
504
- type == GGML_TYPE_IQ1_M ? QK_K :
505
- type == GGML_TYPE_IQ4_NL ? QK4_NL :
506
- type == GGML_TYPE_IQ4_XS ? QK_K :
507
- type == GGML_TYPE_IQ3_S ? QK_K :
508
- 0 ;
509
- }
487
+ template <ggml_type type>
488
+ struct ggml_cuda_type_traits ;
510
489
511
- static constexpr __device__ int get_qr_device (ggml_type type) {
512
- return type == GGML_TYPE_F16 ? 1 :
513
- type == GGML_TYPE_Q4_0 ? QR4_0 :
514
- type == GGML_TYPE_Q4_1 ? QR4_1 :
515
- type == GGML_TYPE_Q5_0 ? QR5_0 :
516
- type == GGML_TYPE_Q5_1 ? QR5_1 :
517
- type == GGML_TYPE_Q8_0 ? QR8_0 :
518
- type == GGML_TYPE_Q2_K ? QR2_K :
519
- type == GGML_TYPE_Q3_K ? QR3_K :
520
- type == GGML_TYPE_Q4_K ? QR4_K :
521
- type == GGML_TYPE_Q5_K ? QR5_K :
522
- type == GGML_TYPE_Q6_K ? QR6_K :
523
- type == GGML_TYPE_IQ2_XXS ? QR2_XXS :
524
- type == GGML_TYPE_IQ2_XS ? QR2_XS :
525
- type == GGML_TYPE_IQ2_S ? QR2_S :
526
- type == GGML_TYPE_IQ3_XXS ? QR3_XXS :
527
- type == GGML_TYPE_IQ1_S ? QR1_S :
528
- type == GGML_TYPE_IQ1_M ? QR1_M :
529
- type == GGML_TYPE_IQ4_NL ? QR4_NL :
530
- type == GGML_TYPE_IQ4_XS ? QR4_XS :
531
- type == GGML_TYPE_IQ3_S ? QR3_S :
532
- 0 ;
533
- }
490
+ template <>
491
+ struct ggml_cuda_type_traits <GGML_TYPE_F16> {
492
+ static constexpr int qk = 1 ;
493
+ static constexpr int qr = 1 ;
494
+ };
534
495
535
- static constexpr __device__ int get_qi_device (ggml_type type) {
536
- return ggml_blck_size_device (type) / (sizeof (int )*get_qr_device (type));
537
- }
496
+ template <>
497
+ struct ggml_cuda_type_traits <GGML_TYPE_Q4_0> {
498
+ static constexpr int qk = QK4_0;
499
+ static constexpr int qr = QR4_0;
500
+ static constexpr int qi = QI4_0;
501
+ };
502
+
503
+ template <>
504
+ struct ggml_cuda_type_traits <GGML_TYPE_Q4_1> {
505
+ static constexpr int qk = QK4_1;
506
+ static constexpr int qr = QR4_1;
507
+ static constexpr int qi = QI4_1;
508
+ };
509
+
510
+ template <>
511
+ struct ggml_cuda_type_traits <GGML_TYPE_Q5_0> {
512
+ static constexpr int qk = QK5_0;
513
+ static constexpr int qr = QR5_0;
514
+ static constexpr int qi = QI5_0;
515
+ };
516
+
517
+ template <>
518
+ struct ggml_cuda_type_traits <GGML_TYPE_Q5_1> {
519
+ static constexpr int qk = QK5_1;
520
+ static constexpr int qr = QR5_1;
521
+ static constexpr int qi = QI5_1;
522
+ };
523
+
524
+ template <>
525
+ struct ggml_cuda_type_traits <GGML_TYPE_Q8_0> {
526
+ static constexpr int qk = QK8_0;
527
+ static constexpr int qr = QR8_0;
528
+ static constexpr int qi = QI8_0;
529
+ };
530
+
531
+ template <>
532
+ struct ggml_cuda_type_traits <GGML_TYPE_Q2_K> {
533
+ static constexpr int qk = QK_K;
534
+ static constexpr int qr = QR2_K;
535
+ static constexpr int qi = QI2_K;
536
+ };
537
+
538
+ template <>
539
+ struct ggml_cuda_type_traits <GGML_TYPE_Q3_K> {
540
+ static constexpr int qk = QK_K;
541
+ static constexpr int qr = QR3_K;
542
+ static constexpr int qi = QI3_K;
543
+ };
544
+
545
+ template <>
546
+ struct ggml_cuda_type_traits <GGML_TYPE_Q4_K> {
547
+ static constexpr int qk = QK_K;
548
+ static constexpr int qr = QR4_K;
549
+ static constexpr int qi = QI4_K;
550
+ };
551
+
552
+ template <>
553
+ struct ggml_cuda_type_traits <GGML_TYPE_Q5_K> {
554
+ static constexpr int qk = QK_K;
555
+ static constexpr int qr = QR5_K;
556
+ static constexpr int qi = QI5_K;
557
+ };
558
+
559
+ template <>
560
+ struct ggml_cuda_type_traits <GGML_TYPE_Q6_K> {
561
+ static constexpr int qk = QK_K;
562
+ static constexpr int qr = QR6_K;
563
+ static constexpr int qi = QI6_K;
564
+ };
565
+
566
+ template <>
567
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ2_XXS> {
568
+ static constexpr int qk = QK_K;
569
+ static constexpr int qr = QR2_XXS;
570
+ static constexpr int qi = QI2_XXS;
571
+ };
572
+
573
+ template <>
574
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ2_XS> {
575
+ static constexpr int qk = QK_K;
576
+ static constexpr int qr = QR2_XS;
577
+ static constexpr int qi = QI2_XS;
578
+ };
579
+
580
+ template <>
581
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ2_S> {
582
+ static constexpr int qk = QK_K;
583
+ static constexpr int qr = QR2_S;
584
+ static constexpr int qi = QI2_S;
585
+ };
586
+
587
+ template <>
588
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ3_XXS> {
589
+ static constexpr int qk = QK_K;
590
+ static constexpr int qr = QR3_XXS;
591
+ static constexpr int qi = QI3_XXS;
592
+ };
593
+
594
+ template <>
595
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ1_S> {
596
+ static constexpr int qk = QK_K;
597
+ static constexpr int qr = QR1_S;
598
+ static constexpr int qi = QI1_S;
599
+ };
600
+
601
+ template <>
602
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ1_M> {
603
+ static constexpr int qk = QK_K;
604
+ static constexpr int qr = QR1_M;
605
+ static constexpr int qi = QI1_M;
606
+ };
607
+
608
+ template <>
609
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ4_NL> {
610
+ static constexpr int qk = QK4_NL;
611
+ static constexpr int qr = QR4_NL;
612
+ static constexpr int qi = QI4_NL;
613
+ };
614
+
615
+ template <>
616
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ4_XS> {
617
+ static constexpr int qk = QK_K;
618
+ static constexpr int qr = QR4_XS;
619
+ static constexpr int qi = QI4_XS;
620
+ };
621
+
622
+ template <>
623
+ struct ggml_cuda_type_traits <GGML_TYPE_IQ3_S> {
624
+ static constexpr int qk = QK_K;
625
+ static constexpr int qr = QR3_S;
626
+ static constexpr int qi = QI3_S;
627
+ };
538
628
539
629
static int get_mmq_x_max_host (const int cc) {
540
630
#ifdef CUDA_USE_TENSOR_CORES
0 commit comments