Commit ecf84ca

llama: rwkv6: Fix tensor loading for 7B/14B models
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent 8d6b2b1 commit ecf84ca

File tree: 1 file changed (+3, -4 lines)


src/llama.cpp

Lines changed: 3 additions & 4 deletions
@@ -7726,10 +7726,9 @@ static bool llm_load_tensors(
                 model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
                 model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});

-                // TODO: Parameterize this
-                const int time_mix_extra_dim = 32;
-                const int time_decay_extra_dim = 64;
-                const int head_size = 64;
+                const int time_mix_extra_dim = (n_embd == 4096) ? 64 : 32;
+                const int time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
+                const int head_size = hparams.wkv_head_size;
                 const int attn_hidden_size = n_embd;
                 const int ffn_size = (int)(n_embd * 3.5 / 32) * 32;
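To make the new sizing rule concrete, below is a minimal, self-contained sketch of the arithmetic this hunk performs. It is not llama.cpp code: rwkv6_dims, compute_rwkv6_dims, and the 2048 example width are hypothetical, and the head size is passed in as a stand-in for hparams.wkv_head_size, which the diff now reads from the model's metadata instead of hardcoding 64. The n_embd == 4096 case covers the 7B/14B RWKV6 models named in the commit title.

#include <cstdio>

// Hypothetical helper (not part of llama.cpp) mirroring this commit's rule:
// wider time-mix/time-decay dimensions for the 4096-wide 7B/14B checkpoints,
// head size taken from the model header rather than a constant.
struct rwkv6_dims {
    int time_mix_extra_dim;
    int time_decay_extra_dim;
    int head_size;
    int attn_hidden_size;
    int ffn_size;
};

static rwkv6_dims compute_rwkv6_dims(int n_embd, int wkv_head_size) {
    rwkv6_dims d;
    d.time_mix_extra_dim   = (n_embd == 4096) ? 64  : 32;
    d.time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
    d.head_size            = wkv_head_size;   // stand-in for hparams.wkv_head_size
    d.attn_hidden_size     = n_embd;
    // FFN width is 3.5x the embedding size, rounded down to a multiple of 32.
    d.ffn_size             = (int)(n_embd * 3.5 / 32) * 32;
    return d;
}

int main() {
    // 2048 is only an example of a smaller width; 4096 is the width
    // of the 7B/14B models this commit fixes.
    for (int n_embd : {2048, 4096}) {
        const rwkv6_dims d = compute_rwkv6_dims(n_embd, /*wkv_head_size=*/64);
        printf("n_embd=%d: time_mix_extra_dim=%d time_decay_extra_dim=%d ffn_size=%d\n",
               n_embd, d.time_mix_extra_dim, d.time_decay_extra_dim, d.ffn_size);
    }
    return 0;
}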
