
Commit 002ef93
gqa in self attention
AmineDiro committed Aug 17, 2023
1 parent dadd757 commit 002ef93
Showing 1 changed file with 5 additions and 4 deletions.
crates/models/llama/src/lib.rs (5 additions & 4 deletions)
@@ -95,7 +95,7 @@ impl KnownModel for Llama {
         let context = tl.finish();
 
         // TODO: read from file
-        let version = match hyperparameters.n_layer {
+        let mut version = match hyperparameters.n_layer {
            26 => LlamaModelVersion::Model3b,
            32 => LlamaModelVersion::Model7b,
            40 => LlamaModelVersion::Model13b,
@@ -112,6 +112,7 @@ impl KnownModel for Llama {
                 "assuming 70B Llama2 model based on GQA == 8"
             );
             hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa;
+            version = LlamaModelVersion::Model70b;
         }
     }
 
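For context: grouped-query attention (GQA) shares each key/value head across a group of query heads, so the 70B model carries far fewer K/V heads than query heads. A minimal sketch of the head arithmetic this hunk relies on; the concrete values are assumptions based on Llama 2 70B (n_head = 64, GQA factor 8), not taken from the diff:

```rust
// Sketch only, not repository code: how the GQA factor shrinks the
// number of key/value heads.
fn kv_heads(n_head: usize, n_gqa: usize) -> usize {
    // Mirrors the diff's `n_head / n_gqa`; the head count must divide evenly.
    assert_eq!(n_head % n_gqa, 0, "n_head must be divisible by the GQA factor");
    n_head / n_gqa
}

fn main() {
    assert_eq!(kv_heads(64, 8), 8); // assumed 70B Llama 2: 8 query heads per KV head
    assert_eq!(kv_heads(32, 1), 32); // n_gqa == 1 degenerates to ordinary multi-head attention
}
```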
@@ -205,7 +206,7 @@ impl KnownModel for Llama {
             &ctx0.op_reshape_3d(
                 &ctx0.op_mul_mat(&self.layers[il].wk, &current),
                 n_embd / n_head,
-                n_head,
+                n_head_kv,
                 input_len,
             ),
             session_len,
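Queries are still reshaped to `n_head` heads, but under GQA the `wk` (and `wv`) projections emit only `n_head_kv` heads' worth of values per token, which is why this reshape, and the cache views below, switch to `n_head_kv`. A rough sketch of the shape bookkeeping, with assumed 70B dimensions:

```rust
// Shape bookkeeping sketch (assumed 70B values, not repository code):
// under GQA the K/V projections produce n_head_kv * head_dim values per
// token, so reshaping their output with n_head would mis-slice the tensor.
fn main() {
    let n_embd = 8192_usize; // assumed embedding width for Llama 2 70B
    let n_head = 64_usize;
    let n_head_kv = 8_usize;
    let head_dim = n_embd / n_head; // 128

    let q_values_per_token = head_dim * n_head; // 8192: queries keep every head
    let kv_values_per_token = head_dim * n_head_kv; // 1024: K and V are 8x smaller
    println!("Q: {q_values_per_token} values/token; K or V: {kv_values_per_token}");
}
```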
@@ -252,7 +253,7 @@ impl KnownModel for Llama {
                         il * ctx_size * builder.memory_k.element_size() * n_embd,
                     ),
                     n_embd / n_head,
-                    n_head,
+                    n_head_kv,
                     session_len + input_len,
                 ),
                 (0, 2, 1, 3),
@@ -282,7 +283,7 @@ impl KnownModel for Llama {
             let v = ctx0
                 .op_view_3d(
                     builder.memory_v,
-                    (session_len + input_len, n_embd / n_head, n_head),
+                    (session_len + input_len, n_embd / n_head, n_head_kv),
                     (
                         ctx_size * builder.memory_v.element_size(),
                         ctx_size * builder.memory_v.element_size() * n_embd / n_head,
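Taken together, the changes size the key/value reshapes and cache views for `n_head_kv` heads while queries keep `n_head`. An illustrative sketch of how a query head maps onto its shared KV head under GQA (the helper `kv_head_for` is hypothetical, not part of this file):

```rust
// Illustrative only: query head `h` attends using KV head `h / group_size`,
// where group_size = n_head / n_head_kv.
fn kv_head_for(query_head: usize, n_head: usize, n_head_kv: usize) -> usize {
    let group_size = n_head / n_head_kv;
    query_head / group_size
}

fn main() {
    // With 64 query heads and 8 KV heads, heads 0..=7 share KV head 0, and so on.
    assert_eq!(kv_head_for(0, 64, 8), 0);
    assert_eq!(kv_head_for(7, 64, 8), 0);
    assert_eq!(kv_head_for(8, 64, 8), 1);
    assert_eq!(kv_head_for(63, 64, 8), 7);
}
```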
