@@ -910,6 +910,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ((ggml_tensor*)dst->extra)->ne);
             return;
         }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f16_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
         if (dst->type == GGML_TYPE_F16) {
             if (ggml_are_same_shape(src, dst)) {
                 cann_copy(ctx, acl_src, acl_dst);
@@ -971,6 +978,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ((ggml_tensor*)dst->extra)->ne);
             return;
         }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f32_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
         if (dst->type == GGML_TYPE_F32) {
             if (ggml_are_same_shape(src, dst)) {
                 cann_copy(ctx, acl_src, acl_dst);
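Both `ggml_cann_dup` branches above offload the conversion to an AscendC kernel (`aclrtlaunch_ascendc_quantize_f16_to_q4_0` / `..._f32_to_q4_0`), launched on 24 AI cores on the context's stream; the kernel sources are not part of this hunk. For intuition, ggml's CPU reference path quantizes each 32-element block roughly as in the sketch below. This is a simplified illustration only: the scale is returned separately here to mirror the split data/scale layout the CANN backend uses, and whether the NPU kernel packs nibbles the same way is not visible in this diff.

```cpp
#include <cmath>
#include <cstdint>

#define QK4_0 32

// Simplified sketch of ggml's reference Q4_0 quantization for one block of
// 32 floats. The scale is returned as fp32 for clarity (the backend stores
// fp16), separate from the packed nibbles.
void quantize_block_q4_0(const float* x, uint8_t* qs /* QK4_0/2 bytes */,
                         float* scale) {
    // Track the value with the largest magnitude, keeping its sign so that
    // it lands exactly on the -8 grid point.
    float amax = 0.0f;
    float max  = 0.0f;
    for (int j = 0; j < QK4_0; j++) {
        if (std::fabs(x[j]) > amax) {
            amax = std::fabs(x[j]);
            max  = x[j];
        }
    }

    const float d  = max / -8;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    *scale = d;

    // Two 4-bit values per byte: element j in the low nibble,
    // element j + 16 in the high nibble.
    for (int j = 0; j < QK4_0 / 2; j++) {
        const uint8_t x0 = (uint8_t)std::fmin(x[j] * id + 8.5f, 15.0f);
        const uint8_t x1 = (uint8_t)std::fmin(x[j + QK4_0 / 2] * id + 8.5f, 15.0f);
        qs[j] = x0 | (x1 << 4);
    }
}
```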
@@ -2463,21 +2477,33 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
  * @param dst The destination tensor where the result of the matrix
  *            multiplication will be stored.
  */
-static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
-                                   ggml_tensor* dst) {
+static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
+                                    ggml_tensor* dst,
+                                    const enum ggml_type type) {
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
     // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
     // is regarded as batch. weight need transpose.
     int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
-    size_t weight_elem_size = sizeof(uint8_t);
-    size_t weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+    float weight_elem_size;
+    if (type == GGML_TYPE_Q4_0) {
+        weight_elem_size = float(sizeof(uint8_t)) / 2;
+    }
+    else if (type == GGML_TYPE_Q8_0) {
+        weight_elem_size = float(sizeof(uint8_t));
+    }
+    else {
+        GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
+    }
+    float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+
     // size of one matrix is element_size * height * width.
     size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
     size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
 
     // scale stored at the end of weight. Also need transpose.
+    GGML_ASSERT(QK4_0 == QK8_0);
     int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
     size_t scale_elem_size = sizeof(uint16_t);
     size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
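The fractional `weight_elem_size` is why `weight_nb` switches from `size_t` to `float`: a Q4_0 element occupies half a byte, so strides only become integral byte counts after multiplying by a whole row. The strides above imply a flat buffer with all quantized elements first and one fp16 scale per 32-element group appended at the end (hence the `GGML_ASSERT(QK4_0 == QK8_0)` guard). A small self-contained sketch of that layout math, with hypothetical dimensions rather than the `src0->ne` values the patch actually uses:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Layout math implied by the strides above. Dimensions are hypothetical;
// the patch derives them from src0->ne instead.
int main() {
    const int64_t ne0 = 4096, ne1 = 4096, ne2 = 1, ne3 = 1;
    const int     qk  = 32;  // QK4_0 == QK8_0, hence the GGML_ASSERT above

    const float  weight_elem_size = 0.5f;  // Q4_0: half a byte per element
    const size_t weight_stride = (size_t)(weight_elem_size * ne0 * ne1);
    const size_t weight_size   = weight_stride * ne2 * ne3;
    const size_t scale_stride  = (size_t)(ne0 / qk) * ne1 * sizeof(uint16_t);

    // The scales live at the end of the weight buffer, so batch b's scales
    // start at (char*)src0->data + weight_size + b * scale_stride.
    std::printf("quantized data: %zu bytes, scales per batch: %zu bytes\n",
                weight_size, scale_stride);
    return 0;
}
```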
@@ -2541,8 +2567,9 @@ static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
             (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
             input_elem_size, input_ne, input_nb, 2);
         aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
-            (char*)src0->data + batch0 * weight_stride, ACL_INT8,
-            weight_elem_size, weight_ne, weight_nb, 2);
+            (char*)src0->data + batch0 * weight_stride,
+            ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
+            weight_nb, 2);
         aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
             scale_offset + batch0 * scale_stride, ACL_FLOAT16,
             scale_elem_size, scale_ne, scale_nb, 2);
@@ -2596,11 +2623,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         case GGML_TYPE_F16:
             ggml_cann_mat_mul_fp(ctx, dst);
             break;
-        // case GGML_TYPE_Q4_0:
-        //     ggml_cann_mul_mat_q4_0(ctx, dst);
-        //     break;
+        case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
-            ggml_cann_mul_mat_q8_0(ctx, dst);
+            ggml_cann_mul_mat_quant(ctx, dst, type);
             break;
         default:
             GGML_ABORT("fatal error");
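The weight view now gets its ACL dtype from `ggml_cann_type_mapping(type)` instead of a hard-coded `ACL_INT8`. That helper's body is outside this diff; a plausible shape for the cases this patch exercises, assuming the AscendCL `aclDataType` enum provides `ACL_INT4`, would be:

```cpp
#include <acl/acl.h>  // aclDataType (assumed header path)
#include "ggml.h"

// Not shown in this diff: the helper that now supplies the weight tensor's
// ACL dtype. Sketch only; the actual mapping lives elsewhere in the CANN
// backend and covers more types.
aclDataType ggml_cann_type_mapping(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:  return ACL_FLOAT;
        case GGML_TYPE_F16:  return ACL_FLOAT16;
        case GGML_TYPE_Q4_0: return ACL_INT4;  // 4-bit payload, scales kept separately
        case GGML_TYPE_Q8_0: return ACL_INT8;  // 8-bit payload, scales kept separately
        default:             return ACL_DT_UNDEFINED;
    }
}
```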