@@ -325,6 +325,54 @@ static void ggml_cpy_q6_0_f32_cuda(
         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
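+// Copies a Q8_0 tensor whose first two dims are transposed relative to the destination:
+// each thread produces one destination element, and each warp (QK8_0 = 32 threads)
+// covers one destination Q8_0 block, re-deriving its scale via a warp reduction.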
+static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
+        const int ne10, const int ne11, const int ne12,
+        const int nb01, const int nb02, const int nb03,
+        const int nb11, const int nb12, const int nb13) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
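+    // decompose the flat index into coordinates of the (contiguous) destination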
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+
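+    // corresponding source coordinates: dims 0 and 1 are swapped relative to the destination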
+    // const int64_t ne00 = ne11;
+    // const int64_t ne01 = ne10;
+    // const int64_t ne02 = ne12;
+    const int64_t i03 = i13;
+    const int64_t i02 = i12;
+    const int64_t i01 = i10; // (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
+    const int64_t i00 = i11; // i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
+
+    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
+    const int ib0 = i00/QK8_0;
+    const int iq0 = i00%QK8_0;
+
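+    // dequantize the source value and take the absolute maximum across the warp,
+    // i.e. across the 32 values that form one destination Q8_0 block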
+    float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
+    float amax = fabsf(xi);
+    amax = warp_reduce_max(amax);
+
+    float d = amax/127;
+    int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13);
+    dst[i10 / QK8_0].qs[i10 % QK8_0] = q;
+
+    if (threadIdx.x == 0) {
+        dst[i10 / QK8_0].d = __float2half(d);
+    }
+}
+
+static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
+    auto stream = ctx.stream();
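+    // one thread per element: QK8_0 threads per CUDA block => one warp per destination Q8_0 block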
+    auto num_blocks = ggml_nelements(dst)/QK8_0;
+    k_transpose_q8_0<<<num_blocks, QK8_0, 0, stream>>>(
+            (const char *)src->data, (char *)dst->data,
+            dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
+            dst->nb[1], dst->nb[2], dst->nb[3]);
+}
+
+
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -428,6 +476,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
+        // This is needed for MLA with mla=2 when using q8_0 cache.
+        transpose_q8_0(ctx, src0, src1);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -497,6 +548,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void *) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void *) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
+        return (void *)transpose_q8_0;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));