 #define CUDA_Q8_0_NE_ALIGN 2048

 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);

-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }

-    const int64_t ib = i/qk; // block index
-    const int64_t iqs = (i%qk)/qr; // quant index
-    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
     const int64_t y_offset = qr == 1 ? 1 : qk/2;

     // dequantize
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);

-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = v.x;
+    y[iy0 + y_offset] = v.y;
 }

 template <bool need_check>
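Editorial aside, not part of the patch: the kernel above now derives a 4-D position (i00, i01, i02, i03) from the launch grid instead of a flat index, and the strides s01/s02/s03 appear to be counted in quantized blocks (the contiguous wrapper further down passes k/qk for all three). A minimal host-side mirror of that index math, with hypothetical names, under those assumptions:

```cpp
// Hypothetical host-side mirror of the kernel's index math, for illustration only.
// Grid coordinates follow the launcher below: gridDim = { ceil(ne00/(2*blockDim.x)), ne01, ne02*ne03 }.
// With ne01 = ne02 = 1 and s01 = s02 = s03 = ne00/qk this collapses to the old contiguous indexing.
#include <cstdint>

struct deq_coords {
    int64_t ib;   // quantized block passed to dequantize_kernel
    int64_t iqs;  // quant index within that block
    int64_t iy0;  // first of the two output elements written
};

static deq_coords resolve_coords(int qk, int qr,
        int64_t block_x, int64_t block_y, int64_t block_z, int64_t thread_x, int64_t block_dim_x,
        int64_t ne00, int64_t ne01, int64_t ne02,
        int64_t s01, int64_t s02, int64_t s03) {
    const int64_t i00 = 2*(block_dim_x*block_x + thread_x); // two values per thread
    const int64_t i01 = block_y;
    const int64_t i02 = block_z % ne02;
    const int64_t i03 = block_z / ne02;

    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;       // block offset of the row start

    deq_coords c;
    c.ib  = ibx0 + i00/qk;
    c.iqs = (i00 % qk)/qr;
    c.iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + (i00 - i00 % qk) + c.iqs;
    return c;
}
```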
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 }

 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
 }

 static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
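Usage sketch (an assumption about how a caller might wire this up; the helper below is hypothetical and not part of this change, and it relies on the surrounding ggml CUDA headers): the block strides expected by dequantize_block_cuda can be derived from a ggml tensor's byte strides nb[] by dividing by the per-block type size, while dequantize_block_cont_cuda is the degenerate contiguous case ne01 = ne02 = ne03 = 1, s01 = s02 = s03 = k/qk.

```cpp
// Hypothetical caller, for illustration only: derive the block strides that
// dequantize_block_cuda expects from a ggml tensor's byte strides (nb[]).
// Assumes src is GGML_TYPE_Q8_0 and its strides are whole multiples of the block size.
static void convert_q8_0_to_f16_nc_example(const ggml_tensor * src, half * dst, cudaStream_t stream) {
    const size_t bs = ggml_type_size(src->type); // bytes per quantized block (block_q8_0 here)

    const int64_t s01 = src->nb[1] / bs; // row stride, in blocks
    const int64_t s02 = src->nb[2] / bs; // plane stride, in blocks
    const int64_t s03 = src->nb[3] / bs; // sample stride, in blocks

    dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>(
        src->data, dst, src->ne[0], src->ne[1], src->ne[2], src->ne[3], s01, s02, s03, stream);
}
```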
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
             if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
@@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half, nv_bfloat16>;
         default:
@@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F16:
             return convert_unary_cuda<half, float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16, float>;
         default: