Skip to content

Commit

Permalink
split embdding to groups
Browse files Browse the repository at this point in the history
  • Loading branch information
AmineDiro committed Aug 17, 2023
1 parent 138404c commit e532678
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions crates/models/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ impl KnownModel for Llama {
n_rot,
file_type: _,
} = self.hyperparameters;
let n_embd_gqa = n_embd / (n_head / n_head_kv);

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let mut ctx0 = builder.ctx0.borrow_mut();
Expand Down Expand Up @@ -220,21 +221,21 @@ impl KnownModel for Llama {
// compute the transposed [N, n_embd] V matrix
let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d(
&ctx0.op_mul_mat(&self.layers[il].wv, &current),
n_embd,
n_embd_gqa,
input_len,
));

let k = ctx0.op_view_1d(
builder.memory_k,
input_len * n_embd,
(builder.memory_k.element_size() * n_embd) * (il * ctx_size + session_len),
input_len * n_embd_gqa,
(builder.memory_k.element_size() * n_embd_gqa) * (il * ctx_size + session_len),
);

let v = ctx0.op_view_2d(
builder.memory_v,
(input_len, n_embd),
(input_len, n_embd_gqa),
ctx_size * builder.memory_v.element_size(),
(il * ctx_size) * builder.memory_v.element_size() * n_embd
(il * ctx_size) * builder.memory_v.element_size() * n_embd_gqa
+ session_len * builder.memory_v.element_size(),
);

Expand All @@ -249,8 +250,8 @@ impl KnownModel for Llama {
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
builder.memory_k,
(session_len + input_len) * n_embd,
il * ctx_size * builder.memory_k.element_size() * n_embd,
(session_len + input_len) * n_embd_gqa,
il * ctx_size * builder.memory_k.element_size() * n_embd_gqa,
),
n_embd / n_head,
n_head_kv,
Expand Down Expand Up @@ -288,7 +289,7 @@ impl KnownModel for Llama {
ctx_size * builder.memory_v.element_size(),
ctx_size * builder.memory_v.element_size() * n_embd / n_head,
),
il * ctx_size * builder.memory_v.element_size() * n_embd,
il * ctx_size * builder.memory_v.element_size() * n_embd_gqa,
)
.set_name("V");

Expand Down

0 comments on commit e532678

Please sign in to comment.