Skip to content

Commit

Permalink
Eliminated need for ggml_repeat2 by using a modified version of #224
Browse files Browse the repository at this point in the history
…instead
  • Loading branch information
jploski committed Jun 25, 2023
1 parent d8c51b2 commit 2e30a2b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 1 addition & 4 deletions examples/falcon/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ bool falcon_eval(

// wte
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
struct ggml_tensor* repeat_dummy = ggml_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head);

ggml_type wtype = GGML_TYPE_F32;
const int sizeof_wtype = ggml_type_sizef(wtype);
Expand Down Expand Up @@ -539,8 +538,6 @@ bool falcon_eval(

// K * Q

K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));

struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

Expand Down Expand Up @@ -570,7 +567,7 @@ bool falcon_eval(
head_dim, n_head_kv, n_past + N),
0, 2, 1, 3);

V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
V = ggml_cont(ctx0, ggml_transpose(ctx0, V));

// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
Expand Down
6 changes: 5 additions & 1 deletion src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -10735,7 +10735,11 @@ static void ggml_compute_forward_mul_mat_f32(

const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
const int64_t i03 = (ir0/(ne02));
const int64_t i02 = (ir0 - i03*ne02);
// Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
const int64_t i02 = (i12 / (ne12 / ne02));
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
// const int64_t i02 = (ir0 - i03*ne02);

const int64_t i1 = i11;
const int64_t i2 = i12;
Expand Down

0 comments on commit 2e30a2b

Please sign in to comment.