Merge pull request rustformers#409 from AmineDiro/gqa
Support for Llama-2 70B
LLukas22 authored Aug 19, 2023
2 parents 56e4a35 + a16caba commit 2f6ffd4
Showing 4 changed files with 75 additions and 20 deletions.
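In brief: grouped-query attention (GQA) lets a group of query heads share a single key/value head. For Llama-2 70B the grouping factor is 8, so its 64 query heads are served by n_head_kv = 64 / 8 = 8 key/value heads, and every per-layer K/V buffer shrinks by the same factor. The diffs below thread a new n_head_kv hyperparameter (plus an n_gqa override on ModelParameters) from the loader down into the inference graph.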
1 change: 1 addition & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -519,6 +519,7 @@ impl ModelLoad {
use_gpu,
gpu_layers: self.gpu_layers,
rope_overrides: self.rope_scaling.to_rope_arguments(),
+ n_gqa: None,
};

let mut sp = Some(spinoff::Spinner::new(
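Note that the CLI passes n_gqa: None unconditionally; this commit adds no command-line flag for it, so the override is only reachable by constructing ModelParameters directly (see the sketch after the model/mod.rs diff below).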
2 changes: 2 additions & 0 deletions crates/ggml/sys/build.rs
@@ -332,6 +332,7 @@ fn enable_cublas(build: &mut cc::Build, out_dir: &Path) {
.arg("static")
.arg("--generate-code=arch=compute_52,code=[compute_52,sm_52]")
.arg("--generate-code=arch=compute_61,code=[compute_61,sm_61]")
.arg("--generate-code=arch=compute_75,code=[compute_75,sm_75]")
.arg("-D_WINDOWS")
.arg("-DNDEBUG")
.arg("-DGGML_USE_CUBLAS")
@@ -363,6 +364,7 @@ fn enable_cublas(build: &mut cc::Build, out_dir: &Path) {
.arg("-pthread")
.arg("--generate-code=arch=compute_52,code=[compute_52,sm_52]")
.arg("--generate-code=arch=compute_61,code=[compute_61,sm_61]")
.arg("--generate-code=arch=compute_75,code=[compute_75,sm_75]")
.arg("-DGGML_USE_CUBLAS")
.arg("-I/usr/local/cuda/include")
.arg("-I/opt/cuda/include")
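For reference, compute capability 7.5 is NVIDIA's Turing generation (GTX 16-series, RTX 20-series, T4). In each --generate-code flag, compute_75 embeds PTX that the driver can JIT-compile and sm_75 embeds prebuilt SASS for those cards, matching the existing 5.2 and 6.1 targets.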
3 changes: 3 additions & 0 deletions crates/llm-base/src/model/mod.rs
@@ -209,6 +209,8 @@ pub struct ModelParameters {
pub gpu_layers: Option<usize>,
/// The arguments/overrides to pass to the [custom RoPE](https://arxiv.org/pdf/2306.15595.pdf) function, if it is used by the model.
pub rope_overrides: Option<ggml::RoPEOverrides>,
+ /// Enables grouped-query attention for Llama-2 70B model
+ pub n_gqa: Option<usize>,
}

impl Default for ModelParameters {
@@ -220,6 +222,7 @@ impl Default for ModelParameters {
use_gpu: false,
gpu_layers: None,
rope_overrides: None,
+ n_gqa: None,
}
}
}
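How a downstream user opts in, as a minimal sketch; it assumes ModelParameters is re-exported at the llm-base crate root, and every field other than n_gqa keeps its default:

```rust
use llm_base::ModelParameters;

fn main() {
    // Ask the loader to derive n_head_kv = n_head / 8, as required
    // for Llama-2 70B. All other fields keep their Default values
    // (no GPU offload, no RoPE overrides).
    let params = ModelParameters {
        n_gqa: Some(8),
        ..Default::default()
    };
    assert_eq!(params.n_gqa, Some(8));
}
```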
89 changes: 69 additions & 20 deletions crates/models/llama/src/lib.rs
@@ -18,7 +18,7 @@ pub struct Llama {
params: ModelParameters,
hyperparameters: Hyperparameters,
tokenizer: Tokenizer,
-
+ _version: LlamaModelType,
// model-global weights
// weighted token embeddings
wte: ggml::Tensor,
@@ -41,7 +41,7 @@ impl KnownModel for Llama {
type Hyperparameters = Hyperparameters;

fn new<E: Error>(
- hyperparameters: Self::Hyperparameters,
+ mut hyperparameters: Self::Hyperparameters,
params: ModelParameters,
tokenizer: Tokenizer,
tensor_loader: impl TensorLoader<E>,
@@ -94,9 +94,32 @@ impl KnownModel for Llama {
}
let context = tl.finish();

+ // TODO: read from file
+ let mut version = match hyperparameters.n_layer {
+ 26 => LlamaModelType::Model3b,
+ 32 => LlamaModelType::Model7b,
+ 40 => LlamaModelType::Model13b,
+ 60 => LlamaModelType::Model30b,
+ 80 => LlamaModelType::Model65b,
+ _ => LlamaModelType::Model7b, // anything < 32
+ };
+ // TODO: temporary fix for 70B models
+ if let Some(n_gqa) = params.n_gqa {
+ if hyperparameters.n_layer >= 80 {
+ assert_eq!(
+ hyperparameters.n_head % n_gqa,
+ 0,
+ "assuming 70B Llama2 model based on GQA == 8"
+ );
+ hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa;
+ version = LlamaModelType::Model70b;
+ }
+ }
+
Ok(Self {
hyperparameters,
params,
+ _version: version,
tokenizer,
wte,
norm,
@@ -133,10 +156,12 @@ impl KnownModel for Llama {
n_embd,
n_mult: _,
n_head,
+ n_head_kv,
n_layer,
n_rot,
file_type: _,
} = self.hyperparameters;
+ let n_embd_gqa = n_embd / (n_head / n_head_kv);

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let mut ctx0 = builder.ctx0.borrow_mut();
@@ -182,7 +207,7 @@ impl KnownModel for Llama {
&ctx0.op_reshape_3d(
&ctx0.op_mul_mat(&self.layers[il].wk, &current),
n_embd / n_head,
- n_head,
+ n_head_kv,
input_len,
),
session_len,
@@ -196,21 +221,21 @@
// compute the transposed [N, n_embd] V matrix
let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d(
&ctx0.op_mul_mat(&self.layers[il].wv, &current),
- n_embd,
+ n_embd_gqa,
input_len,
));

let k = ctx0.op_view_1d(
builder.memory_k,
- input_len * n_embd,
- (builder.memory_k.element_size() * n_embd) * (il * ctx_size + session_len),
+ input_len * n_embd_gqa,
+ (builder.memory_k.element_size() * n_embd_gqa) * (il * ctx_size + session_len),
);

let v = ctx0.op_view_2d(
builder.memory_v,
- (input_len, n_embd),
+ (input_len, n_embd_gqa),
ctx_size * builder.memory_v.element_size(),
- (il * ctx_size) * builder.memory_v.element_size() * n_embd
+ (il * ctx_size) * builder.memory_v.element_size() * n_embd_gqa
+ session_len * builder.memory_v.element_size(),
);

@@ -225,11 +250,11 @@
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
builder.memory_k,
- (session_len + input_len) * n_embd,
- il * ctx_size * builder.memory_k.element_size() * n_embd,
+ (session_len + input_len) * n_embd_gqa,
+ il * ctx_size * builder.memory_k.element_size() * n_embd_gqa,
),
n_embd / n_head,
- n_head,
+ n_head_kv,
session_len + input_len,
),
(0, 2, 1, 3),
@@ -259,12 +284,12 @@ impl KnownModel for Llama {
let v = ctx0
.op_view_3d(
builder.memory_v,
- (session_len + input_len, n_embd / n_head, n_head),
+ (session_len + input_len, n_embd / n_head, n_head_kv),
(
ctx_size * builder.memory_v.element_size(),
ctx_size * builder.memory_v.element_size() * n_embd / n_head,
),
- il * ctx_size * builder.memory_v.element_size() * n_embd,
+ il * ctx_size * builder.memory_v.element_size() * n_embd_gqa,
)
.set_name("V");

@@ -386,6 +411,8 @@ pub struct Hyperparameters {
pub n_mult: usize,
/// n_head
pub n_head: usize,
+ /// grouped-query attention
+ pub n_head_kv: usize,
/// Number of layers in the model
pub n_layer: usize,
/// n_rot
@@ -396,14 +423,26 @@

impl llm_base::Hyperparameters for Hyperparameters {
fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
+ let n_vocab = util::read_i32(reader)?.try_into()?;
+ let n_embd = util::read_i32(reader)?.try_into()?;
+ let n_mult = util::read_i32(reader)?.try_into()?;
+ let n_head = util::read_i32(reader)?.try_into()?;
+ let n_layer = util::read_i32(reader)?.try_into()?;
+ let n_rot = util::read_i32(reader)?.try_into()?;
+ let file_type = util::read_filetype(reader)?;
+
+ // Defaults to multi-head attention where n_head_kv == n_heads
+ let n_head_kv = n_head;
+
Ok(Hyperparameters {
- n_vocab: util::read_i32(reader)?.try_into()?,
- n_embd: util::read_i32(reader)?.try_into()?,
- n_mult: util::read_i32(reader)?.try_into()?,
- n_head: util::read_i32(reader)?.try_into()?,
- n_layer: util::read_i32(reader)?.try_into()?,
- n_rot: util::read_i32(reader)?.try_into()?,
- file_type: util::read_filetype(reader)?,
+ n_head,
+ n_head_kv,
+ n_vocab,
+ n_embd,
+ n_mult,
+ n_layer,
+ n_rot,
+ file_type,
})
}

@@ -447,3 +486,13 @@ struct Layer {
w2: ggml::Tensor,
w3: ggml::Tensor,
}
+
+ /// Available Llama models
+ enum LlamaModelType {
+ Model3b,
+ Model7b,
+ Model13b,
+ Model30b,
+ Model65b,
+ Model70b,
+ }
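To make the new bookkeeping concrete, here is the head arithmetic as a self-contained sketch; the dimensions are the published Llama-2 70B values, hard-coded for illustration rather than read from a GGML file:

```rust
fn main() {
    // Published Llama-2 70B dimensions (hard-coded for illustration).
    let n_embd: usize = 8192;
    let n_head: usize = 64;
    let n_gqa: usize = 8;

    // The loader's override: n_head_kv = n_head / n_gqa.
    assert_eq!(n_head % n_gqa, 0, "grouping factor must divide n_head");
    let n_head_kv = n_head / n_gqa; // 8 key/value heads

    // Same formula as the inference code above: width of one K/V row.
    let n_embd_gqa = n_embd / (n_head / n_head_kv); // 8192 / 8 = 1024

    // K and V cache entries are sized by n_embd_gqa instead of n_embd,
    // an 8x reduction versus standard multi-head attention.
    println!("n_head_kv = {n_head_kv}, n_embd_gqa = {n_embd_gqa}");
}
```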
