@@ -8773,8 +8773,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87738773    // TODO: mmq/mmv support
87748774#endif 
87758775
8776-     GGML_ASSERT (dst->backend  == GGML_BACKEND_GPU);
8777- 
87788776    const  int64_t  nb11 = src1->nb [1 ];
87798777    const  int64_t  nb1  =  dst->nb [1 ];
87808778
@@ -8803,13 +8801,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88038801    ggml_tensor src1_row = *src1;
88048802    ggml_tensor dst_row = *dst;
88058803
8804+     src1_row.backend  = GGML_BACKEND_GPU;
8805+     dst_row.backend   = GGML_BACKEND_GPU;
8806+ 
88068807    src1_row.extra  = &src1_row_extra;
88078808    dst_row.extra  = &dst_row_extra;
88088809
8809-     char  * src1_original = (char  *) src1_extra->data_device [g_main_device];
8810-     char  * dst_original  = (char  *)  dst_extra->data_device [g_main_device];
8810+     char  * src1_original = src1->backend  == GGML_BACKEND_CPU ?
8811+         (char  *) src1->data  : (char  *) src1_extra->data_device [g_main_device];
8812+     char  * dst_original  =  dst->backend  == GGML_BACKEND_CPU ?
8813+         (char  *)  dst->data  : (char  *)  dst_extra->data_device [g_main_device];
88118814
88128815    if  (src1->ne [1 ] == 1 ) {
8816+         GGML_ASSERT (src1->backend  == GGML_BACKEND_GPU);
8817+         GGML_ASSERT (dst->backend   == GGML_BACKEND_GPU);
8818+ 
88138819        for  (int64_t  i01 = 0 ; i01 < ids->ne [1 ]; i01++) {
88148820            // int32_t row_id;
88158821            // CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8837,6 +8843,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88378843        src1_row_extra.data_device [g_main_device] = src1_contiguous;
88388844        dst_row_extra.data_device [g_main_device]  =  dst_contiguous;
88398845
8846+         const  cudaMemcpyKind src1_kind = src1->backend  == GGML_BACKEND_CPU ?
8847+             cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8848+         const  cudaMemcpyKind dst_kind  =  dst->backend  == GGML_BACKEND_CPU ?
8849+             cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice; // copy is dst_contiguous (device) -> dst_original (host when dst is on CPU)
8850+ 
88408851        for  (int32_t  row_id = 0 ; row_id < n_as; ++row_id) {
88418852            const  struct  ggml_tensor  * src0_row = dst->src [row_id + 2 ];
88428853
@@ -8851,7 +8862,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88518862                GGML_ASSERT (row_id >= 0  && row_id < n_as);
88528863
88538864                CUDA_CHECK (cudaMemcpyAsync (src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
8854-                                         nb11, cudaMemcpyDeviceToDevice , stream));
8865+                                         nb11, src1_kind , stream));
88558866                num_src1_rows++;
88568867            }
88578868
@@ -8883,14 +8894,18 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88838894                GGML_ASSERT (row_id >= 0  && row_id < n_as);
88848895
88858896                CUDA_CHECK (cudaMemcpyAsync (dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
8886-                                         nb1, cudaMemcpyDeviceToDevice , stream));
8897+                                         nb1, dst_kind , stream));
88878898                num_src1_rows++;
88888899            }
88898900        }
88908901
88918902        ggml_cuda_pool_free (src1_contiguous, as_src1);
88928903        ggml_cuda_pool_free (dst_contiguous,  as_dst);
88938904    }
8905+ 
8906+     if  (dst->backend  == GGML_BACKEND_CPU) {
8907+         CUDA_CHECK (cudaStreamSynchronize (stream));
8908+     }
88948909}
88958910
88968911static  void  ggml_cuda_scale (const  ggml_tensor * src0, const  ggml_tensor * src1, ggml_tensor * dst) {
@@ -9289,7 +9304,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
92899304        || (tensor->src [0 ] != nullptr  && (tensor->src [0 ]->backend  == GGML_BACKEND_GPU || tensor->src [0 ]->backend  == GGML_BACKEND_GPU_SPLIT))
92909305        || (tensor->src [1 ] != nullptr  && tensor->src [1 ]->backend  == GGML_BACKEND_GPU);
92919306
9292-     if  (!any_on_device && tensor->op  != GGML_OP_MUL_MAT) {
9307+     if  (!any_on_device && tensor->op  != GGML_OP_MUL_MAT && tensor-> op  != GGML_OP_MUL_MAT_ID ) {
92939308        return  false ;
92949309    }
92959310
0 commit comments