added n_gqa and n_head_kv fields
AmineDiro committed Aug 17, 2023
1 parent 15520ea commit d8e83e3
Showing 3 changed files with 61 additions and 10 deletions.
1 change: 1 addition & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -519,6 +519,7 @@ impl ModelLoad {
use_gpu,
gpu_layers: self.gpu_layers,
rope_overrides: self.rope_scaling.to_rope_arguments(),
n_gqa: None,
};

let mut sp = Some(spinoff::Spinner::new(
3 changes: 3 additions & 0 deletions crates/llm-base/src/model/mod.rs
@@ -209,6 +209,8 @@ pub struct ModelParameters {
pub gpu_layers: Option<usize>,
/// The arguments/overrides to pass to the [custom RoPE](https://arxiv.org/pdf/2306.15595.pdf) function, if it is used by the model.
pub rope_overrides: Option<ggml::RoPEOverrides>,
/// Enables grouped-query attention for the Llama 2 70B model
pub n_gqa: Option<usize>,
}

impl Default for ModelParameters {
@@ -220,6 +222,7 @@
use_gpu: false,
gpu_layers: None,
rope_overrides: None,
n_gqa: None,
}
}
}
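
For context, a minimal sketch (not part of this diff) of how a caller might opt into grouped-query attention via the new field; it assumes the usual re-export of `ModelParameters` from the `llm_base` crate root and relies on the `Default` impl above:

    // Request grouped-query attention with 8 query heads per
    // key-value head, the layout used by Llama 2 70B.
    let params = llm_base::ModelParameters {
        n_gqa: Some(8),
        ..Default::default()
    };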
67 changes: 57 additions & 10 deletions crates/models/llama/src/lib.rs
@@ -1,7 +1,7 @@
//! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::{cell::Cell, error::Error, sync::Arc};

use llm_base::{
ggml::{self},
@@ -18,7 +18,7 @@ pub struct Llama {
params: ModelParameters,
hyperparameters: Hyperparameters,
tokenizer: Tokenizer,

version: LlamaModelVersion,
// model-global weights
// weighted token embeddings
wte: ggml::Tensor,
@@ -41,7 +41,7 @@ impl KnownModel for Llama {
type Hyperparameters = Hyperparameters;

fn new<E: Error>(
hyperparameters: Self::Hyperparameters,
mut hyperparameters: Self::Hyperparameters,
params: ModelParameters,
tokenizer: Tokenizer,
tensor_loader: impl TensorLoader<E>,
@@ -94,9 +94,31 @@
}
let context = tl.finish();

// TODO: read from file
let mut version = match hyperparameters.n_layer {
26 => LlamaModelVersion::Model3b,
32 => LlamaModelVersion::Model7b,
40 => LlamaModelVersion::Model13b,
60 => LlamaModelVersion::Model30b,
80 => LlamaModelVersion::Model65b,
_ => LlamaModelVersion::Model7b, // fallback for unrecognized layer counts
};
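// LLaMA 65B and Llama 2 70B both have 80 layers, so the layer count
// alone cannot tell them apart; the `n_gqa` override below is what
// identifies a 70B checkpoint.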
// TODO: temporary fix for 70B models
if let Some(n_gqa) = params.n_gqa {
if hyperparameters.n_layer >= 80 {
assert_eq!(
hyperparameters.n_head % n_gqa,
0,
"assuming 70B Llama2 model based on GQA == 8"
);
hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa;
version = LlamaModelVersion::Model70b;
}
}
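// Worked example: Llama 2 70B ships with n_head = 64, so n_gqa = 8
// yields n_head_kv = 64 / 8 = 8 key-value heads shared across the
// 64 query heads.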

Ok(Self {
hyperparameters,
params,
version,
tokenizer,
wte,
norm,
@@ -133,6 +155,7 @@
n_embd,
n_mult: _,
n_head,
n_head_kv,
n_layer,
n_rot,
file_type: _,
@@ -386,6 +409,8 @@ pub struct Hyperparameters {
pub n_mult: usize,
/// n_head
pub n_head: usize,
/// Number of key-value heads (grouped-query attention)
pub n_head_kv: usize,
/// Number of layers in the model
pub n_layer: usize,
/// n_rot
@@ -396,14 +421,26 @@

impl llm_base::Hyperparameters for Hyperparameters {
fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
let n_vocab = util::read_i32(reader)?.try_into()?;
let n_embd = util::read_i32(reader)?.try_into()?;
let n_mult = util::read_i32(reader)?.try_into()?;
let n_head = util::read_i32(reader)?.try_into()?;
let n_layer = util::read_i32(reader)?.try_into()?;
let n_rot = util::read_i32(reader)?.try_into()?;
let file_type = util::read_filetype(reader)?;

// n_head_kv is not serialized in the GGML format: default to n_head
// (standard multi-head attention); `new` overrides it via `n_gqa`.
let n_head_kv = n_head;

Ok(Hyperparameters {
n_vocab: util::read_i32(reader)?.try_into()?,
n_embd: util::read_i32(reader)?.try_into()?,
n_mult: util::read_i32(reader)?.try_into()?,
n_head: util::read_i32(reader)?.try_into()?,
n_layer: util::read_i32(reader)?.try_into()?,
n_rot: util::read_i32(reader)?.try_into()?,
file_type: util::read_filetype(reader)?,
n_head,
n_head_kv,
n_vocab,
n_embd,
n_mult,
n_layer,
n_rot,
file_type,
})
}

@@ -447,3 +484,13 @@ struct Layer {
w2: ggml::Tensor,
w3: ggml::Tensor,
}

/// Available Llama models
enum LlamaModelVersion {
Model3b,
Model7b,
Model13b,
Model30b,
Model65b,
Model70b,
}
