Skip to content

Commit 9154494

Browse files
CUDA: mul_mat_id always on GPU for batches >= 32 (#4553)
1 parent c083718 commit 9154494

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

ggml-cuda.cu

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8773,8 +8773,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87738773
// TODO: mmq/mmv support
87748774
#endif
87758775

8776-
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8777-
87788776
const int64_t nb11 = src1->nb[1];
87798777
const int64_t nb1 = dst->nb[1];
87808778

@@ -8803,13 +8801,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88038801
ggml_tensor src1_row = *src1;
88048802
ggml_tensor dst_row = *dst;
88058803

8804+
src1_row.backend = GGML_BACKEND_GPU;
8805+
dst_row.backend = GGML_BACKEND_GPU;
8806+
88068807
src1_row.extra = &src1_row_extra;
88078808
dst_row.extra = &dst_row_extra;
88088809

8809-
char * src1_original = (char *) src1_extra->data_device[g_main_device];
8810-
char * dst_original = (char *) dst_extra->data_device[g_main_device];
8810+
char * src1_original = src1->backend == GGML_BACKEND_CPU ?
8811+
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
8812+
char * dst_original = dst->backend == GGML_BACKEND_CPU ?
8813+
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
88118814

88128815
if (src1->ne[1] == 1) {
8816+
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
8817+
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8818+
88138819
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
88148820
//int32_t row_id;
88158821
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8837,6 +8843,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88378843
src1_row_extra.data_device[g_main_device] = src1_contiguous;
88388844
dst_row_extra.data_device[g_main_device] = dst_contiguous;
88398845

8846+
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
8847+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8848+
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ?
8849+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8850+
88408851
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
88418852
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
88428853

@@ -8851,7 +8862,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88518862
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88528863

88538864
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
8854-
nb11, cudaMemcpyDeviceToDevice, stream));
8865+
nb11, src1_kind, stream));
88558866
num_src1_rows++;
88568867
}
88578868

@@ -8883,14 +8894,18 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88838894
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88848895

88858896
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
8886-
nb1, cudaMemcpyDeviceToDevice, stream));
8897+
nb1, dst_kind, stream));
88878898
num_src1_rows++;
88888899
}
88898900
}
88908901

88918902
ggml_cuda_pool_free(src1_contiguous, as_src1);
88928903
ggml_cuda_pool_free(dst_contiguous, as_dst);
88938904
}
8905+
8906+
if (dst->backend == GGML_BACKEND_CPU) {
8907+
CUDA_CHECK(cudaStreamSynchronize(stream));
8908+
}
88948909
}
88958910

88968911
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9289,7 +9304,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
92899304
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
92909305
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
92919306

9292-
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
9307+
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
92939308
return false;
92949309
}
92959310

0 commit comments

Comments
 (0)