From d67148ebb2983e50c165c53c1955256c47aaa333 Mon Sep 17 00:00:00 2001
From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>
Date: Sun, 18 Jun 2023 16:12:08 +0200
Subject: [PATCH] Bugfix: Mat-Mul and wrong memory

---
 crates/llm/src/lib.rs           |  9 +++++----
 crates/models/falcon/src/lib.rs | 18 +++++++-----------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index d40f37b7..d792b64b 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -92,6 +92,8 @@ use serde::Serialize;
 pub mod models {
     #[cfg(feature = "bloom")]
     pub use llm_bloom::{self as bloom, Bloom};
+    #[cfg(feature = "falcon")]
+    pub use llm_falcon::{self as falcon, Falcon};
     #[cfg(feature = "gpt2")]
     pub use llm_gpt2::{self as gpt2, Gpt2};
     #[cfg(feature = "gptj")]
@@ -102,8 +104,6 @@ pub mod models {
     pub use llm_llama::{self as llama, Llama};
     #[cfg(feature = "mpt")]
     pub use llm_mpt::{self as mpt, Mpt};
-    #[cfg(feature = "falcon")]
-    pub use llm_falcon::{self as falcon, Falcon};
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
@@ -276,8 +276,9 @@ pub fn load_dynamic(
         #[cfg(feature = "mpt")]
         Mpt => load_model::<models::Mpt>(path, vocabulary_source, params, load_progress_callback)?,
         #[cfg(feature = "falcon")]
-        Falcon => load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?,
-
+        Falcon => {
+            load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?
+        }
     };
 
     Ok(model)
diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs
index 06505f78..78a26b0f 100644
--- a/crates/models/falcon/src/lib.rs
+++ b/crates/models/falcon/src/lib.rs
@@ -197,14 +197,14 @@ impl KnownModel for Falcon {
 
             // store key and value to memory
             let k = ctx0.op_view_1d(
-                &memory_k,
+                memory_k,
                 n * head_dim,
                 (memory_k_size * head_dim) * (il * ctx_size + session_len),
             );
             let v = ctx0.op_view_1d(
-                &memory_v,
+                memory_v,
                 n * head_dim,
-                (memory_k_size * head_dim) * (il * ctx_size + session_len),
+                (memory_v_size * head_dim) * (il * ctx_size + session_len),
             );
 
             gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
@@ -216,7 +216,7 @@ impl KnownModel for Falcon {
             let mut bigk = ctx0.op_permute(
                 &ctx0.op_reshape_3d(
                     &ctx0.op_view_1d(
-                        &memory_k,
+                        memory_k,
                         (session_len + n) * head_dim,
                         il * ctx_size * memory_k_size * head_dim,
                     ),
@@ -228,7 +228,7 @@ impl KnownModel for Falcon {
             );
             // K * Q
             bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
-            let big_kq = ctx0.op_mul(&bigk, &bigq);
+            let big_kq = ctx0.op_mul_mat(&bigk, &bigq);
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             let big_kq_scaled = ctx0.op_scale_inplace(
@@ -243,7 +243,7 @@ impl KnownModel for Falcon {
             let mut bigv = ctx0.op_permute(
                 &ctx0.op_reshape_3d(
                     &ctx0.op_view_1d(
-                        &memory_v,
+                        memory_v,
                         (session_len + n) * head_dim,
                         il * ctx_size * memory_v_size * head_dim,
                     ),
@@ -326,7 +326,7 @@ impl KnownModel for Falcon {
     }
 
     fn bot_token_id(&self) -> Option<TokenId> {
-        self.vocabulary.id("<|padding|>".as_bytes())
+        None
     }
 
     fn eot_token_id(&self) -> TokenId {
@@ -347,8 +347,6 @@ impl KnownModel for Falcon {
 pub struct Hyperparameters {
     /// Size of the model's vocabulary
     n_vocab: usize,
-    /// Maximum sequence length
-    n_ctx: usize,
     /// Size of the model's embedding layer
     n_embd: usize,
     /// n_heads
@@ -363,7 +361,6 @@ impl llm_base::Hyperparameters for Hyperparameters {
     fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
         let hyperparameters = Hyperparameters {
             n_vocab: util::read_i32(reader)?.try_into()?,
-            n_ctx: util::read_i32(reader)?.try_into()?,
             n_embd: util::read_i32(reader)?.try_into()?,
             n_head: util::read_i32(reader)?.try_into()?,
             n_layer: util::read_i32(reader)?.try_into()?,
@@ -376,7 +373,6 @@ impl llm_base::Hyperparameters for Hyperparameters {
     fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
         util::write_i32(writer, self.n_vocab.try_into()?)?;
         util::write_i32(writer, self.n_embd.try_into()?)?;
-        util::write_i32(writer, self.n_embd.try_into()?)?;
         util::write_i32(writer, self.n_head.try_into()?)?;
         util::write_i32(writer, self.n_layer.try_into()?)?;
         util::write_i32(writer, self.file_type.into())?;
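
Notes on the fixes (illustrative sketches, not part of the patch):

1. The K*Q fix replaces ggml's element-wise op_mul with op_mul_mat when forming the attention scores. The sketch below, in plain Rust with Vec<f32> and no ggml (all names are illustrative, not the ggml-rs API), shows why the two are not interchangeable: an element-wise product is shape-preserving and computes no dot products, while attention needs one dot product per (key, query) pair.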
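/// Element-wise product: only defined for identically shaped tensors and
/// produces no cross-token interaction. This is what the buggy op_mul did.
fn elementwise(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}

/// Row-by-row dot products: scores[i][j] = dot(k[i], q[j]), i.e. the
/// (seq_k x seq_q) attention-score matrix that a matrix product computes.
fn matmul_scores(k: &[Vec<f32>], q: &[Vec<f32>]) -> Vec<Vec<f32>> {
    k.iter()
        .map(|krow| {
            q.iter()
                .map(|qrow| krow.iter().zip(qrow).map(|(a, b)| a * b).sum())
                .collect()
        })
        .collect()
}

fn main() {
    // Two cached keys and two queries, head_dim = 2.
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let q = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

    // Element-wise multiply of the flattened tensors: shape-preserving,
    // no dot products, so it cannot yield attention scores.
    println!("{:?}", elementwise(&[1.0, 0.0, 0.0, 1.0], &[1.0, 2.0, 3.0, 4.0]));

    // The matrix product gives one score per (key, query) pair.
    println!("{:?}", matmul_scores(&k, &q)); // [[1.0, 3.0], [2.0, 4.0]]
}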
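2. The memory fix indexes the V-cache view with memory_v_size instead of memory_k_size. A hypothetical helper mirroring the byte-offset arithmetic of the patched op_view_1d calls (the function name and the concrete numbers below are assumptions for illustration) shows why each cache must use its own element size: each layer owns a contiguous region, and an offset computed with the wrong element size points into the wrong region.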
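// Hypothetical helper mirroring `(elem_size * head_dim) * (il * ctx_size + session_len)`.
fn cache_view_offset(
    elem_size: usize,   // element size of *this* cache (memory_k_size / memory_v_size)
    head_dim: usize,    // per-head embedding width
    il: usize,          // layer index
    ctx_size: usize,    // maximum context length
    session_len: usize, // tokens already stored this session
) -> usize {
    (elem_size * head_dim) * (il * ctx_size + session_len)
}

fn main() {
    // Example only: an f16 K cache (2 bytes) vs an f32 V cache (4 bytes)
    // place the same logical position at different byte offsets.
    let (head_dim, il, ctx_size, session_len) = (64, 3, 2048, 10);
    assert_eq!(cache_view_offset(2, head_dim, il, ctx_size, session_len), 787_712);
    assert_eq!(cache_view_offset(4, head_dim, il, ctx_size, session_len), 1_575_424);
    println!("offsets diverge, so each cache must use its own element size");
}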
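3. Dropping n_ctx touches both read_ggml and write_ggml (the old writer also emitted n_embd twice) because GGML hyperparameters are stored as a bare sequence of little-endian i32 values with no field tags: reader and writer must agree on exact order and count, or every later field is misread. A small round-trip sketch, assuming helpers that behave like util::read_i32/util::write_i32 and using Falcon-7B-like values purely for illustration: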
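use std::io::{Cursor, Read, Write};

// Assumed to behave like the crate's util::write_i32: little-endian i32.
fn write_i32(w: &mut impl Write, v: i32) -> std::io::Result<()> {
    w.write_all(&v.to_le_bytes())
}

// Assumed to behave like the crate's util::read_i32.
fn read_i32(r: &mut impl Read) -> std::io::Result<i32> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    Ok(i32::from_le_bytes(buf))
}

fn main() -> std::io::Result<()> {
    // Writer emits n_vocab, n_embd, n_head, n_layer, file_type -- no n_ctx.
    let mut buf = Vec::new();
    for v in [65024, 4544, 71, 32, 1] {
        write_i32(&mut buf, v)?;
    }

    // A reader that still expected n_ctx here would consume n_embd in its
    // place and misread every following field.
    let mut cur = Cursor::new(buf);
    let n_vocab = read_i32(&mut cur)?;
    let n_embd = read_i32(&mut cur)?;
    assert_eq!((n_vocab, n_embd), (65024, 4544));
    Ok(())
}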