Commit 8a83e1f

Author: Iwan Kawrakow (committed)
cuda: re-add q8_0 -> q8_0 transpose
so that mla = 2 can be used with CUDA graphs and a q8_0 cache.
1 parent 1faa086 commit 8a83e1f
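
Why a dedicated kernel is needed at all: q8_0 stores each row as blocks of 32 consecutive values that share a single fp16 scale, so copying a transposed q8_0 view into a contiguous q8_0 tensor regroups elements into different blocks. Each destination block therefore has to be dequantized from several source blocks and requantized with a fresh scale; a plain byte copy cannot express that. For reference, a sketch of the q8_0 block layout, assuming the standard ggml definition (the real one lives in ggml's common headers):

// Sketch of ggml's q8_0 block layout; ggml_half is a 16-bit float,
// shown here as raw uint16_t storage.
#include <stdint.h>

#define QK8_0 32

typedef uint16_t ggml_half;

typedef struct {
    ggml_half d;          // per-block scale (fp16)
    int8_t    qs[QK8_0];  // 32 signed 8-bit quants; element i ~= d * qs[i]
} block_q8_0;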

File tree

ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/cpy.cu

2 files changed, +53 -38 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 0 additions & 38 deletions
@@ -2941,29 +2941,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     return fuse_down;
 }
 
-static void ggml_cuda_cpy_wrapper(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
-    auto src0 = dst->src[0];
-    auto src1 = dst->src[1];
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        CUDA_CHECK(cudaMemcpyAsync((char *)src1->data, (char *)src0->data, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, ctx.stream()));
-        return;
-    }
-#ifdef USE_CUDA_GRAPH
-    if (ctx.cuda_graph->use_cpy_indirection) {
-        GGML_ASSERT(ctx.cuda_graph->graph_cpynode_index < (int)ctx.cuda_graph->cpy_dest_ptrs.size());
-        auto dest_ptr = ctx.cuda_graph->cpy_dest_ptrs[ctx.cuda_graph->graph_cpynode_index];
-        ggml_tensor aux_src1 = *src1;
-        aux_src1.data = dest_ptr;
-        ggml_cuda_cpy(ctx, src0, &aux_src1);
-        ++ctx.cuda_graph->graph_cpynode_index;
-    } else {
-        ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
-    }
-#else
-    ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
-#endif
-}
-
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {

@@ -2985,7 +2962,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_CPY:
-            //ggml_cuda_cpy_wrapper(ctx, dst);
             ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
             break;
         case GGML_OP_CONT:

@@ -3269,20 +3245,6 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-//static void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs,
-//        const int host_dest_ptrs_size, cudaStream_t stream) {
-//    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
-//        CUDA_CHECK(cudaStreamSynchronize(stream));
-//        if (cuda_graph->dest_ptrs_d != nullptr) {
-//            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
-//        }
-//        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
-//        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
-//    }
-//    // copy destination pointers to GPU
-//    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
-//    cuda_graph->graph_cpynode_index = 0; // reset index
-//}
 
 static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
                                                                 bool use_cuda_graph) {

ggml/src/ggml-cuda/cpy.cu

Lines changed: 53 additions & 0 deletions
@@ -325,6 +325,54 @@ static void ggml_cpy_q6_0_f32_cuda(
         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
+        const int ne10, const int ne11, const int ne12,
+        const int nb01, const int nb02, const int nb03,
+        const int nb11, const int nb12, const int nb13) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+
+    //const int64_t ne00 = ne11;
+    //const int64_t ne01 = ne10;
+    //const int64_t ne02 = ne12;
+    const int64_t i03 = i13;
+    const int64_t i02 = i12;
+    const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
+    const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
+
+    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
+    const int ib0 = i00/QK8_0;
+    const int iq0 = i00%QK8_0;
+
+    float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
+    float amax = fabsf(xi);
+    amax = warp_reduce_max(amax);
+
+    float d = amax/127;
+    int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13);
+    dst[i10 / QK8_0].qs[i10 % QK8_0] = q;
+
+    if (threadIdx.x == 0) {
+        dst[i10 / QK8_0].d = __float2half(d);
+    }
+}
+
+static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
+    auto stream = ctx.stream();
+    auto num_blocks = ggml_nelements(dst)/QK8_0;
+    k_transpose_q8_0<<<num_blocks, QK8_0, 0, stream>>>(
+            (const char *)src->data, (char *)dst->data,
+            dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
+            dst->nb[1], dst->nb[2], dst->nb[3]);
+}
+
+
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));

@@ -428,6 +476,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
+        // This is needed for MLA with mla=2 when using q8_0 cache.
+        transpose_q8_0(ctx, src0, src1);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));

@@ -497,6 +548,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
+        return (void *)transpose_q8_0;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
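
For readability, here is what the added transpose path computes. transpose_q8_0 launches k_transpose_q8_0 with QK8_0 (= 32) threads per CUDA block and ggml_nelements(dst)/QK8_0 blocks, so each CUDA block is one warp that produces exactly one destination q8_0 block: every thread dequantizes one source element (using the scale of whichever source block that element lives in), warp_reduce_max finds the shared amax, each thread requantizes its value against the new scale d = amax/127, and thread 0 stores d. This relies on the destination's dim-0 extent being a multiple of 32, which q8_0 rows already guarantee. The host-side sketch below restates the same arithmetic for the 2D case; it is illustrative only (cpu_transpose_q8_0 and BlockQ8 are not part of the commit, and BlockQ8 uses a float scale instead of ggml's fp16 d):

// Illustrative CPU restatement of the requantizing transpose done by
// k_transpose_q8_0, for a 2D tensor: src has ne01 rows of ne00 elements,
// dst is its contiguous transpose, and both are quantized in blocks of 32
// along their rows. Assumes ne00 and ne01 are multiples of 32.
#include <cmath>
#include <cstdint>

constexpr int kQK8_0 = 32;

struct BlockQ8 {          // simplified stand-in for block_q8_0 (float scale)
    float  d;
    int8_t qs[kQK8_0];
};

static void cpu_transpose_q8_0(const BlockQ8 * src, BlockQ8 * dst,
                               int64_t ne00, int64_t ne01) {
    const int64_t src_row_blocks = ne00 / kQK8_0;  // blocks per source row
    const int64_t dst_row_blocks = ne01 / kQK8_0;  // blocks per destination row

    for (int64_t i11 = 0; i11 < ne00; ++i11) {               // destination row (= source column)
        for (int64_t jb = 0; jb < dst_row_blocks; ++jb) {    // one destination block
            float vals[kQK8_0];
            float amax = 0.0f;
            for (int k = 0; k < kQK8_0; ++k) {
                const int64_t i10 = jb*kQK8_0 + k;           // destination column (= source row)
                const BlockQ8 & b = src[i10*src_row_blocks + i11/kQK8_0];
                vals[k] = b.d * b.qs[i11 % kQK8_0];          // dequantize the source element
                amax = std::fmax(amax, std::fabs(vals[k]));
            }
            BlockQ8 & out = dst[i11*dst_row_blocks + jb];
            out.d = amax/127;                                // fresh per-block scale, as in the kernel
            for (int k = 0; k < kQK8_0; ++k) {
                out.qs[k] = amax == 0.0f ? 0 : (int8_t) std::round(vals[k]/out.d);
            }
        }
    }
}

Because the 32 values of each destination block come from 32 different source blocks, the result is a genuine requantization rather than a bit-exact shuffle, so a small additional rounding error is introduced on this path.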

0 commit comments
