From d8e83e3c3e71807adecee8e284cc82dcb57b4a79 Mon Sep 17 00:00:00 2001
From: aminediro
Date: Tue, 15 Aug 2023 15:47:58 +0200
Subject: [PATCH] added n_gqa and n_head_kv fields

---
 binaries/llm-cli/src/cli_args.rs |  1 +
 crates/llm-base/src/model/mod.rs |  3 ++
 crates/models/llama/src/lib.rs   | 67 +++++++++++++++++++++++++++-----
 3 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index 9beeaf22..21b4a897 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -519,6 +519,7 @@ impl ModelLoad {
             use_gpu,
             gpu_layers: self.gpu_layers,
             rope_overrides: self.rope_scaling.to_rope_arguments(),
+            n_gqa: None,
         };
 
         let mut sp = Some(spinoff::Spinner::new(
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index 3d5bc163..b31faf56 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -209,6 +209,8 @@ pub struct ModelParameters {
     pub gpu_layers: Option<usize>,
     /// The arguments/overrides to pass to the [custom RoPE](https://arxiv.org/pdf/2306.15595.pdf) function, if it is used by the model.
     pub rope_overrides: Option<RoPEOverrides>,
+    /// Enables grouped-query attention for the Llama 2 70B model.
+    pub n_gqa: Option<usize>,
 }
 
 impl Default for ModelParameters {
@@ -220,6 +222,7 @@ impl Default for ModelParameters {
             use_gpu: false,
             gpu_layers: None,
             rope_overrides: None,
+            n_gqa: None,
         }
     }
 }
diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs
index e701a42e..b321b489 100644
--- a/crates/models/llama/src/lib.rs
+++ b/crates/models/llama/src/lib.rs
@@ -1,7 +1,7 @@
 //! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::{error::Error, sync::Arc};
+use std::{cell::Cell, error::Error, sync::Arc};
 
 use llm_base::{
     ggml::{self},
@@ -18,7 +18,7 @@ pub struct Llama {
     params: ModelParameters,
     hyperparameters: Hyperparameters,
     tokenizer: Tokenizer,
-
+    version: LlamaModelVersion,
     // model-global weights
     // weighted token embeddings
     wte: ggml::Tensor,
@@ -41,7 +41,7 @@ impl KnownModel for Llama {
     type Hyperparameters = Hyperparameters;
 
     fn new<E: Error>(
-        hyperparameters: Self::Hyperparameters,
+        mut hyperparameters: Self::Hyperparameters,
         params: ModelParameters,
         tokenizer: Tokenizer,
         tensor_loader: impl TensorLoader<E>,
@@ -94,9 +94,31 @@ impl KnownModel for Llama {
         }
         let context = tl.finish();
 
+        // TODO: read the model version from the file instead of inferring it from n_layer
+        let version = match hyperparameters.n_layer {
+            26 => LlamaModelVersion::Model3b,
+            32 => LlamaModelVersion::Model7b,
+            40 => LlamaModelVersion::Model13b,
+            60 => LlamaModelVersion::Model30b,
+            80 => LlamaModelVersion::Model65b,
+            _ => LlamaModelVersion::Model7b, // fall back for unrecognized layer counts
+        };
+        // TODO: temporary fix for 70B models
+        if let Some(n_gqa) = params.n_gqa {
+            if hyperparameters.n_layer >= 80 {
+                assert_eq!(
+                    hyperparameters.n_head % n_gqa,
+                    0,
+                    "n_head must be divisible by n_gqa (8 for the 70B Llama 2 model)"
+                );
+                hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa;
+            }
+        }
+
         Ok(Self {
             hyperparameters,
             params,
+            version,
             tokenizer,
             wte,
             norm,
@@ -133,6 +155,7 @@ impl KnownModel for Llama {
             n_embd,
             n_mult: _,
             n_head,
+            n_head_kv,
             n_layer,
             n_rot,
             file_type: _,
@@ -386,6 +409,8 @@ pub struct Hyperparameters {
     pub n_mult: usize,
     /// n_head
    pub n_head: usize,
+    /// Number of key-value heads (grouped-query attention)
+    pub n_head_kv: usize,
     /// Number of layers in the model
     pub n_layer: usize,
     /// n_rot
     pub n_rot: usize,
@@ -396,14 +421,26 @@ impl llm_base::Hyperparameters for Hyperparameters {
     fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
+        // Read the fields in the order they are laid out in the GGML file.
+        let n_vocab = util::read_i32(reader)?.try_into()?;
+        let n_embd = util::read_i32(reader)?.try_into()?;
+        let n_mult = util::read_i32(reader)?.try_into()?;
+        let n_head = util::read_i32(reader)?.try_into()?;
+        let n_layer = util::read_i32(reader)?.try_into()?;
+        let n_rot = util::read_i32(reader)?.try_into()?;
+        let file_type = util::read_filetype(reader)?;
+
+        // n_head_kv == n_head for standard multi-head attention; it is
+        // overridden in `Llama::new` when grouped-query attention is enabled.
+        let n_head_kv = n_head;
+
         Ok(Hyperparameters {
-            n_vocab: util::read_i32(reader)?.try_into()?,
-            n_embd: util::read_i32(reader)?.try_into()?,
-            n_mult: util::read_i32(reader)?.try_into()?,
-            n_head: util::read_i32(reader)?.try_into()?,
-            n_layer: util::read_i32(reader)?.try_into()?,
-            n_rot: util::read_i32(reader)?.try_into()?,
-            file_type: util::read_filetype(reader)?,
+            n_vocab,
+            n_embd,
+            n_mult,
+            n_head,
+            n_head_kv,
+            n_layer,
+            n_rot,
+            file_type,
         })
     }
@@ -447,3 +484,13 @@ struct Layer {
     w2: ggml::Tensor,
     w3: ggml::Tensor,
 }
+
+/// Available Llama model versions
+enum LlamaModelVersion {
+    Model3b,
+    Model7b,
+    Model13b,
+    Model30b,
+    Model65b,
+    Model70b,
+}
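
Reviewer note (not part of the patch): with grouped-query attention, the
number of key/value heads is n_head_kv = n_head / n_gqa. For the Llama 2 70B
configuration (n_head = 64, n_gqa = 8) this gives 8 KV heads, each shared by
a group of 8 query heads. The sketch below mirrors the override performed in
`Llama::new`; the `Hp` struct is a simplified stand-in for the crate's
`Hyperparameters`, not its actual API.

/// Simplified stand-in for the crate's `Hyperparameters` (illustrative only).
#[derive(Debug)]
struct Hp {
    n_head: usize,
    n_head_kv: usize,
    n_layer: usize,
}

/// Same logic as the patch: if `n_gqa` is supplied and the layer count
/// indicates a 70B model, derive the KV-head count from the query-head count.
fn apply_gqa_override(hp: &mut Hp, n_gqa: Option<usize>) {
    if let Some(n_gqa) = n_gqa {
        if hp.n_layer >= 80 {
            assert_eq!(hp.n_head % n_gqa, 0, "n_head must be divisible by n_gqa");
            hp.n_head_kv = hp.n_head / n_gqa;
        }
    }
}

fn main() {
    // Llama 2 70B: 80 layers, 64 query heads, GQA group size 8.
    let mut hp = Hp { n_head: 64, n_head_kv: 64, n_layer: 80 };
    apply_gqa_override(&mut hp, Some(8));
    assert_eq!(hp.n_head_kv, 8); // 64 / 8 = 8 key/value heads
}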
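
A second reviewer note, on `read_ggml`: a GGML header is a plain sequence of
little-endian integers, so the reads must happen in the same order the fields
were written; naming the locals does not change which bytes each read
consumes. A minimal self-contained sketch of that constraint (the inline
read_i32 here is a hypothetical stand-in for the crate's `util::read_i32`):

use std::io::{Cursor, Read};

// Reads one little-endian i32, as the GGML loader does field by field.
fn read_i32(r: &mut impl Read) -> std::io::Result<i32> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    Ok(i32::from_le_bytes(buf))
}

fn main() -> std::io::Result<()> {
    // A header written as n_vocab = 32000, then n_embd = 4096.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&32000i32.to_le_bytes());
    bytes.extend_from_slice(&4096i32.to_le_bytes());

    let mut r = Cursor::new(bytes);
    // Reads must happen in file order: the stream carries no field names,
    // so swapping these two lines would silently swap the values.
    let n_vocab = read_i32(&mut r)?;
    let n_embd = read_i32(&mut r)?;
    assert_eq!((n_vocab, n_embd), (32000, 4096));
    Ok(())
}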