
cuda : improve text-generation and batched decoding performance #3776

Merged: 7 commits, Oct 27, 2023
cuda : add F32 sgemm branch
ggerganov committed Oct 25, 2023
commit 16b60dd75c8c89b726da5e9252454791fa1300b7
38 changes: 35 additions & 3 deletions ggml-cuda.cu
@@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    if (ggml_is_contiguous(src0)) {
+#if 0
+    {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
         size_t src0_as = 0;
@@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
         if (src1_as != 0) {
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
-    } else {
-        GGML_ASSERT(false && "not implemented");
-    }
+#else
+    {
+        // convert src0 to fp32, multiply as fp32
+        float * src0_as_f32 = nullptr;
+        size_t  src0_as = 0;
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            const size_t ne = ne01*ne00;
+            src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
+            to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
+        }
+
+        const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
+
+        const float * src1_ptr = (const float *) src1_ddf;
+
+        const float alpha = 1.0f;
+        const float beta  = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+        CUBLAS_CHECK(
+            cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    &alpha, src0_ptr, ne00,
+                            src1_ptr, ne10,
+                    &beta,  dst_ddf,  ne01));
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f32, src0_as);
+        }
+    }
+#endif
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
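A note on the new branch: ggml stores tensors row-major, while cuBLAS assumes column-major storage, so the fp32 branch expresses the row-major product dst = src1 · src0^T as a column-major SGEMM with src0 transposed (CUBLAS_OP_T). The following is a minimal standalone sketch of the same call pattern; the sizes, host-side setup, and file name are illustrative, not taken from ggml-cuda.cu.

// sgemm_sketch.cu (hypothetical file; build with: nvcc sgemm_sketch.cu -lcublas)
#include <cstdio>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
    // ggml-style dimensions: src0 is ne01 x ne00 (row-major),
    // src1 is ne11 x ne10 (row-major), with ne10 == ne00
    const int ne00 = 4, ne01 = 3, ne10 = 4, ne11 = 2;

    float src0[ne01*ne00], src1[ne11*ne10], dst[ne11*ne01];
    for (int i = 0; i < ne01*ne00; ++i) src0[i] = (float) i;
    for (int i = 0; i < ne11*ne10; ++i) src1[i] = 1.0f; // ones make the result easy to check

    float *d0, *d1, *dd;
    cudaMalloc(&d0, sizeof(src0));
    cudaMalloc(&d1, sizeof(src1));
    cudaMalloc(&dd, sizeof(dst));
    cudaMemcpy(d0, src0, sizeof(src0), cudaMemcpyHostToDevice);
    cudaMemcpy(d1, src1, sizeof(src1), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f;
    const float beta  = 0.0f;

    // same argument pattern as the diff: column-major C (ne01 x ne11) =
    // A^T (ne01 x ne10) * B (ne10 x ne11); reading C back as row-major
    // gives dst[i1][i0] = dot(row i0 of src0, row i1 of src1)
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
            ne01, ne11, ne10,
            &alpha, d0, ne00,
                    d1, ne10,
            &beta,  dd, ne01);

    cudaMemcpy(dst, dd, sizeof(dst), cudaMemcpyDeviceToHost);

    for (int i1 = 0; i1 < ne11; ++i1) {
        for (int i0 = 0; i0 < ne01; ++i0) {
            printf("%6.1f ", dst[i1*ne01 + i0]); // expect 6, 22, 38 on each row
        }
        printf("\n");
    }

    cublasDestroy(handle);
    cudaFree(d0); cudaFree(d1); cudaFree(dd);
    return 0;
}

The #if 0 keeps the fp16 path in place for comparison; the new branch replaces the fp16 conversion of both operands with a single dequantization of src0 to fp32 via the existing to_fp32_cuda kernels, feeding cublasSgemm directly.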