
Commit fb92acd

better solution for p0 computation
1 parent eec6b66 commit fb92acd

2 files changed: 25 additions & 9 deletions

ggml-cuda.cu

Lines changed: 19 additions & 9 deletions
@@ -5,7 +5,6 @@
 #include <stdio.h>
 #include <atomic>
 #include <assert.h>
-#include <vector>
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -440,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
+    bool copied;
 };
 
 // this is faster on Windows
@@ -4356,6 +4356,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
+static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n) {
+        int p = pos[i];
+        p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale;
+    }
+}
+
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -6091,18 +6099,20 @@ inline void ggml_cuda_op_rope(
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(src1->ne[0] == ne2);
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
 
-    std::vector<float> p0s(ne2);
-    for (int64_t i = 0; i < ne2; ++i) {
-        int n_past = ((int32_t *) src1->data)[i];
-        p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    if (!src1_extra->copied) {
+        CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
+        src1_extra->copied = true;
     }
 
     size_t p0d_as = 0;
-    float * p0d;
-
-    p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
-    CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream));
+    float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
+    compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale);
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
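
Taken together, the ggml-cuda.cu changes remove a per-call host round trip: previously the float p0 values were built in a std::vector on the CPU and copied to the GPU on every RoPE call; now the int32 positions in src1 are uploaded once per tensor (guarded by the new copied flag) and the compute_rope_p0 kernel derives p0 directly on the device. Below is a minimal, self-contained sketch of the same kernel-launch pattern; the host scaffolding (ROPE_BLOCK_SIZE, main) is illustrative only and not part of ggml:

// Standalone sketch of the on-device p0 computation added above.
// ROPE_BLOCK_SIZE stands in for CUDA_ROPE_BLOCK_SIZE; the host code is illustrative.
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

#define ROPE_BLOCK_SIZE 256

// same body as the new compute_rope_p0 kernel: p0 = pos * freq_scale,
// except when (mode & 1) is set, in which case the position is forced to 0
static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        int p = pos[i];
        p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale;
    }
}

int main() {
    const int n = 8;
    int32_t pos_host[n] = {0, 1, 2, 3, 4, 5, 6, 7};

    int32_t * pos_dev;
    float   * p0_dev;
    cudaMalloc(&pos_dev, n * sizeof(int32_t));
    cudaMalloc(&p0_dev,  n * sizeof(float));
    cudaMemcpy(pos_dev, pos_host, n * sizeof(int32_t), cudaMemcpyHostToDevice);

    // round the grid size up so every position gets a thread,
    // mirroring the (ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE launch above
    const int num_blocks = (n + ROPE_BLOCK_SIZE - 1) / ROPE_BLOCK_SIZE;
    compute_rope_p0<<<num_blocks, ROPE_BLOCK_SIZE>>>(pos_dev, p0_dev, n, /*mode=*/0, /*freq_scale=*/1.0f);

    float p0_host[n];
    cudaMemcpy(p0_host, p0_dev, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("p0[%d] = %f\n", i, p0_host[i]);
    }

    cudaFree(pos_dev);
    cudaFree(p0_dev);
    return 0;
}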

llama.cpp

Lines changed: 6 additions & 0 deletions
@@ -2705,6 +2705,7 @@ static struct ggml_cgraph * llm_build_llama(
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    offload_func_kq(KQ_pos);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
@@ -2715,6 +2716,7 @@
 
     // K_shift
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+    offload_func_kq(K_shift);
     ggml_allocr_alloc(lctx.alloc, K_shift);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) K_shift->data;
@@ -3087,6 +3089,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    offload_func_kq(KQ_pos);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
@@ -3097,6 +3100,7 @@
 
     // K_shift
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+    offload_func_kq(K_shift);
     ggml_allocr_alloc(lctx.alloc, K_shift);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) K_shift->data;
@@ -3486,6 +3490,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    offload_func_kq(KQ_pos);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
@@ -3496,6 +3501,7 @@
 
     // K_shift
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+    offload_func_kq(K_shift);
     ggml_allocr_alloc(lctx.alloc, K_shift);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) K_shift->data;
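
The llama.cpp side is what makes the new GGML_ASSERT(src1->backend == GGML_BACKEND_GPU) hold: each graph builder now passes the KQ_pos and K_shift position tensors through offload_func_kq, which assigns them a GPU backend when enough layers are offloaded. For context, offload_func_kq is selected earlier in these builders roughly as sketched below; this is paraphrased from the surrounding llama.cpp code of this period (the exact layer threshold is an assumption) and is not part of this commit:

// Sketch (assumption): how offload_func_kq is set up near the top of
// llm_build_llama and the other graph builders in this era of llama.cpp.
// llama_nop leaves the tensor on the CPU; ggml_cuda_assign_buffers_no_alloc
// marks it as GGML_BACKEND_GPU without allocating device data.
offload_func_t offload_func_kq = llama_nop;

#ifdef GGML_USE_CUBLAS
    if (n_gpu_layers > n_layer + 2) { // threshold is approximate
        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
    }
#endif // GGML_USE_CUBLAS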
