Skip to content

Commit

Permalink
Bugfix: Mat-Mul and wrong memory
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Jun 18, 2023
1 parent 611d245 commit d67148e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
9 changes: 5 additions & 4 deletions crates/llm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ use serde::Serialize;
pub mod models {
#[cfg(feature = "bloom")]
pub use llm_bloom::{self as bloom, Bloom};
#[cfg(feature = "falcon")]
pub use llm_falcon::{self as falcon, Falcon};
#[cfg(feature = "gpt2")]
pub use llm_gpt2::{self as gpt2, Gpt2};
#[cfg(feature = "gptj")]
Expand All @@ -102,8 +104,6 @@ pub mod models {
pub use llm_llama::{self as llama, Llama};
#[cfg(feature = "mpt")]
pub use llm_mpt::{self as mpt, Mpt};
#[cfg(feature = "falcon")]
pub use llm_falcon::{self as falcon, Falcon};
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
Expand Down Expand Up @@ -276,8 +276,9 @@ pub fn load_dynamic(
#[cfg(feature = "mpt")]
Mpt => load_model::<models::Mpt>(path, vocabulary_source, params, load_progress_callback)?,
#[cfg(feature = "falcon")]
Falcon => load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?,

Falcon => {
load_model::<models::Falcon>(path, vocabulary_source, params, load_progress_callback)?
}
};

Ok(model)
Expand Down
18 changes: 7 additions & 11 deletions crates/models/falcon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ impl KnownModel for Falcon {
// store key and value to memory

let k = ctx0.op_view_1d(
&memory_k,
memory_k,
n * head_dim,
(memory_k_size * head_dim) * (il * ctx_size + session_len),
);
let v = ctx0.op_view_1d(
&memory_v,
memory_v,
n * head_dim,
(memory_k_size * head_dim) * (il * ctx_size + session_len),
(memory_v_size * head_dim) * (il * ctx_size + session_len),
);

gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
Expand All @@ -216,7 +216,7 @@ impl KnownModel for Falcon {
let mut bigk = ctx0.op_permute(
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
&memory_k,
memory_k,
(session_len + n) * head_dim,
il * ctx_size * memory_k_size * head_dim,
),
Expand All @@ -228,7 +228,7 @@ impl KnownModel for Falcon {
);
// K * Q
bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
let big_kq = ctx0.op_mul(&bigk, &bigq);
let big_kq = ctx0.op_mul_mat(&bigk, &bigq);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
let big_kq_scaled = ctx0.op_scale_inplace(
Expand All @@ -243,7 +243,7 @@ impl KnownModel for Falcon {
let mut bigv = ctx0.op_permute(
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
&memory_v,
memory_v,
(session_len + n) * head_dim,
il * ctx_size * memory_v_size * head_dim,
),
Expand Down Expand Up @@ -326,7 +326,7 @@ impl KnownModel for Falcon {
}

fn bot_token_id(&self) -> Option<TokenId> {
self.vocabulary.id("<|padding|>".as_bytes())
None
}

fn eot_token_id(&self) -> TokenId {
Expand All @@ -347,8 +347,6 @@ impl KnownModel for Falcon {
pub struct Hyperparameters {
/// Size of the model's vocabulary
n_vocab: usize,
/// Maximum sequence length
n_ctx: usize,
/// Size of the model's embedding layer
n_embd: usize,
/// n_heads
Expand All @@ -363,7 +361,6 @@ impl llm_base::Hyperparameters for Hyperparameters {
fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
let hyperparameters = Hyperparameters {
n_vocab: util::read_i32(reader)?.try_into()?,
n_ctx: util::read_i32(reader)?.try_into()?,
n_embd: util::read_i32(reader)?.try_into()?,
n_head: util::read_i32(reader)?.try_into()?,
n_layer: util::read_i32(reader)?.try_into()?,
Expand All @@ -376,7 +373,6 @@ impl llm_base::Hyperparameters for Hyperparameters {
fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
util::write_i32(writer, self.n_vocab.try_into()?)?;
util::write_i32(writer, self.n_embd.try_into()?)?;
util::write_i32(writer, self.n_embd.try_into()?)?;
util::write_i32(writer, self.n_head.try_into()?)?;
util::write_i32(writer, self.n_layer.try_into()?)?;
util::write_i32(writer, self.file_type.into())?;
Expand Down

0 comments on commit d67148e

Please sign in to comment.