
Commit 93b0547

CUDA GPU acceleration for LoRAs + f16 models
Parent: 7487137

4 files changed: 83 additions & 19 deletions

examples/common.cpp

Lines changed: 0 additions & 7 deletions
@@ -414,13 +414,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         exit(1);
     }
 
-#ifdef GGML_USE_CUBLAS
-    if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) {
-        fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__);
-        exit(1);
-    }
-#endif // GGML_USE_CUBLAS
-
     if (escape_prompt) {
         process_escapes(params.prompt);
     }

ggml-cuda.cu

Lines changed: 47 additions & 11 deletions
@@ -194,6 +194,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] + __float2half(y[i]);
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1209,6 +1218,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
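The two hunks above add the mixed-precision kernel and its launcher: an f16 tensor accumulates an f32 delta and stays f16. The standalone sketch below is not part of the commit; the kernel body and the grid-size formula are copied from the diff, while the block size constant, the managed-memory setup, and the main() harness are illustrative choices made here. Like the original kernel, the half arithmetic needs a GPU architecture with native f16 support (sm_53 or newer).

// standalone_add_f16_f32_f16.cu -- illustrative only, not from this commit
#include <cstdio>
#include <cuda_fp16.h>

static const int BLOCK_SIZE = 256; // stand-in for CUDA_ADD_BLOCK_SIZE

// same kernel body as in the diff: f16 input, f32 addend, f16 output
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = x[i] + __float2half(y[i]);
}

int main() {
    const int k = 1024;
    half  * x; // f16 "weights"; also used as the destination, as in the LoRA path
    float * y; // f32 delta to accumulate
    cudaMallocManaged(&x, k*sizeof(half));
    cudaMallocManaged(&y, k*sizeof(float));
    for (int i = 0; i < k; ++i) {
        x[i] = __float2half(1.0f);
        y[i] = 0.5f;
    }

    // same launch-geometry calculation as add_f16_f32_f16_cuda in the diff
    const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
    add_f16_f32_f16<<<num_blocks, BLOCK_SIZE>>>(x, y, /*dst=*/x, k); // in place: dst aliases x
    cudaDeviceSynchronize();

    printf("x[0] = %.2f (expected 1.50)\n", __half2float(x[0]));
    cudaFree(x);
    cudaFree(y);
    return 0;
}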
@@ -1675,15 +1689,26 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne0 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) src0->extra;
+        // ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) src1->extra;
+        ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu * ) dst->extra;
+        GGML_ASSERT(src0_extra->data_device[g_main_device] == dst_extra->data_device[g_main_device]);
+        GGML_ASSERT(src0_ddq_i == (char *) dst_ddf_i);
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
@@ -2281,8 +2306,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
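Taken together with the ggml_cuda_op_add hunk above, the GPU dispatch for GGML_OP_ADD now looks like this (editor's summary, not text from the commit):

// src0  src1  dst   path taken in ggml_cuda_op_add
// F32   F32   F32   add_f32_cuda (previous behaviour, unchanged)
// F16   F32   F16   add_f16_f32_f16_cuda; the add must be in place, i.e. src0 and dst
//                   alias the same device buffer (GGML_ASSERT(src0_ddq_i == (char *) dst_ddf_i))
// other             rejected with GGML_ASSERT(false)

The call also flips the first boolean passed to ggml_cuda_op from true to false; judging from the new branch, that is the flag which previously requested an f32-converted copy of src0, which is no longer wanted now that the f16 data is consumed directly via src0_ddq_i.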
@@ -2535,7 +2566,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2544,22 +2575,23 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2598,11 +2630,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 void ggml_cuda_free_data(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void ggml_cuda_set_main_device(int main_device);
 void ggml_cuda_set_scratch_size(size_t scratch_size);
 void ggml_cuda_free_scratch(void);
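The new entry point differs from the existing two only in the flags it forwards: it forces the in-place path in ggml_cuda_assign_buffers_impl, so the tensor reuses its src0's existing device buffer instead of receiving its own allocation. Restated from the ggml-cuda.cu hunks above (no new behaviour):

// ggml_cuda_assign_buffers(t)               -> ggml_cuda_assign_buffers_impl(t, /*scratch=*/true,  /*force_inplace=*/false)
// ggml_cuda_assign_buffers_no_scratch(t)    -> ggml_cuda_assign_buffers_impl(t, /*scratch=*/false, /*force_inplace=*/false)
// ggml_cuda_assign_buffers_force_inplace(t) -> ggml_cuda_assign_buffers_impl(t, /*scratch=*/false, /*force_inplace=*/true)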

llama.cpp

Lines changed: 35 additions & 1 deletion
@@ -2907,14 +2907,15 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
                         return false;
                     }
         }
-        ggml_tensor* lora_tensor;
+        ggml_tensor * lora_tensor;
         if (n_dims == 2) {
             lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
         }
         else {
             fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
             return 1;
         }
+        ggml_set_name(lora_tensor, "lora_tensor");
 
         // load tensor data
         size_t offset = fin.tellg();
@@ -2930,6 +2931,21 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
 
            ggml_tensor * dest_t = model_tensors[base_name];
+
+            offload_func_t offload_func = llama_nop;
+            offload_func_t offload_func_force_inplace = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+            if (dest_t->type != GGML_TYPE_F16) {
+                throw std::runtime_error(format(
+                    "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
+            }
+            if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
+                offload_func = ggml_cuda_assign_buffers;
+                offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
+            }
+#endif // GGML_USE_CUBLAS
+
            ggml_tensor * base_t;
            if (model_loader) {
                // load from base model
@@ -2957,7 +2973,12 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
            }
 
            ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
+            GGML_ASSERT(loraA->type == GGML_TYPE_F32);
+            ggml_set_name(loraA, "loraA");
+
            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+            GGML_ASSERT(loraB->type == GGML_TYPE_F32);
+            ggml_set_name(loraB, "loraB");
 
            if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
@@ -2967,19 +2988,32 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
            // w = w + BA*s
            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+            offload_func(BA);
+            ggml_set_name(BA, "BA");
 
            if (scaling != 1.0f) {
                ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
+                ggml_set_name(scale_tensor, "scale_tensor");
+
                BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
+                offload_func(BA);
+                ggml_set_name(BA, "BA_scaled");
            }
 
            ggml_tensor * r;
            if (base_t == dest_t) {
                r = ggml_add_inplace(lora_ctx, dest_t, BA);
+                offload_func_force_inplace(r);
+                ggml_set_name(r, "r_add_inplace");
            }
            else {
                r = ggml_add(lora_ctx, base_t, BA);
+                offload_func(r);
+                ggml_set_name(r, "r_add");
+
                r = ggml_cpy(lora_ctx, r, dest_t);
+                offload_func(r);
+                ggml_set_name(r, "r_cpy");
            }
 
            struct ggml_cgraph gf = ggml_build_forward(r);
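The graph assembled above is the standard LoRA update, per the w = w + BA*s comment: the adapter's low-rank product is scaled and added to the base weight, with each node now optionally offloaded to the GPU via the offload_func hooks. The plain-CPU sketch below is not from the commit; the row-major layout and the names n_out/n_in/rank/apply_lora_update are illustrative assumptions and do not mirror ggml_mul_mat's layout conventions. It only spells out the arithmetic that the offloaded graph performs for f16 models.

// lora_update_sketch.cpp -- illustrative only, not part of this commit
#include <cstdio>
#include <vector>

// W: n_out x n_in (base weight), A: rank x n_in, B: n_out x rank, all row-major.
// Applies W <- W + scaling * (B * A), i.e. "w = w + BA*s".
void apply_lora_update(std::vector<float> & W,
                       const std::vector<float> & A,
                       const std::vector<float> & B,
                       int n_out, int n_in, int rank, float scaling) {
    for (int i = 0; i < n_out; ++i) {
        for (int j = 0; j < n_in; ++j) {
            float ba = 0.0f;
            for (int k = 0; k < rank; ++k) {
                ba += B[i*rank + k] * A[k*n_in + j];   // (B*A)[i][j]
            }
            W[i*n_in + j] += scaling * ba;             // w = w + BA*s
        }
    }
}

int main() {
    // 2x2 weight, rank-1 adapter: B = [1, 2]^T, A = [3, 4], scaling = 0.5
    std::vector<float> W = {0, 0, 0, 0};
    std::vector<float> A = {3, 4};
    std::vector<float> B = {1, 2};
    apply_lora_update(W, A, B, /*n_out=*/2, /*n_in=*/2, /*rank=*/1, 0.5f);
    printf("%.1f %.1f %.1f %.1f\n", W[0], W[1], W[2], W[3]); // 1.5 2.0 3.0 4.0
    return 0;
}

When base_t == dest_t, the add runs in place on the weight's device buffer via ggml_add_inplace plus ggml_cuda_assign_buffers_force_inplace; the restriction in the new error message (LoRAs plus GPU acceleration only for f16 models) follows from the f16 add path introduced in ggml-cuda.cu.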
