metal: move scratch buffers to EvaluationContext

pixelspark committed Jun 18, 2023
1 parent 828d9b9 commit 8eb9db7

Showing 5 changed files with 44 additions and 41 deletions.
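
For orientation, a minimal sketch of the call pattern after this change, assuming the crate-internal APIs exactly as they appear in the diffs below (this is not a standalone program): the scratch buffers move off the long-lived InferenceSession and onto the per-evaluation EvaluationContext returned by common::prepare_for_evaluate_v2, so model code now borrows them from the evaluation context instead of the session.

    // Before this commit (sketch): scratch lived on the session itself.
    //     ctx0.use_scratch(Some(&mut session.scratch[0]));

    // After this commit (sketch): scratch travels with the per-evaluation context.
    // Assumed bindings: `n_layer: usize`, `session: &mut InferenceSession`,
    // `input_tokens: &[TokenId]`.
    let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
    let ctx0 = &evaluation_ctx.ctx0;
    ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0])); // attention ops
    // ... build the layer's attention graph ...
    ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1])); // feed-forward ops
    // ... build the layer's feed-forward graph ...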

crates/llm-base/src/inference_session.rs (25 changes: 1 addition & 24 deletions)

@@ -6,12 +6,6 @@ use crate::{
     TokenizationError,
 };
 
-// The size of a scratch buffer used for inference. This is used for temporary
-// storage of intermediate results during inference.
-//
-// The specific value was copied from `llama.cpp`.
-const SCRATCH_SIZE: usize = 512 * 1024 * 1024;
-
 /// An inference session represents the state of the text generation. This holds
 /// the full context window, as well as several additional parameters used
 /// during sampling.
@@ -58,14 +52,8 @@ pub struct InferenceSession {
     /// The logits that were last predicted by the network. Zeroed out otherwise.
     #[doc(hidden)]
     pub last_logits: Vec<f32>,
-
-    /// Scratch buffers used during inference.
-    ///
-    /// The number of scratch buffers was copied from `llama.cpp`.
-    /// There is no specific reason for this number, but one is insufficient.
-    #[doc(hidden)]
-    pub scratch: [ggml::Buffer; 2],
 }
 
 unsafe impl Send for InferenceSession {}
 impl InferenceSession {
     /// Feed a prompt to the model for this session.
@@ -424,8 +412,6 @@ impl InferenceSession {
             memory_v.set_data(session_ctx.alloc_owned_aligned(memory_v.nbytes()).cast());
         }
 
-        let scratch = scratch_buffers();
-
         InferenceSession {
             _session_ctx: session_ctx,
             memory_size: ctx_size,
@@ -437,7 +423,6 @@
             tokens: vec![],
             decoded_tokens: vec![],
             last_logits: vec![0.0; n_vocab],
-            scratch,
         }
     }
 }
@@ -464,7 +449,6 @@ impl Clone for InferenceSession {
             tokens: self.tokens.clone(),
             decoded_tokens: self.decoded_tokens.clone(),
             last_logits: self.last_logits.clone(),
-            scratch: scratch_buffers(),
         }
     }
 }
@@ -704,10 +688,3 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
         None => Ok(InferenceFeedback::Continue),
     }
 }
-
-fn scratch_buffers() -> [ggml::Buffer; 2] {
-    [
-        ggml::Buffer::new(SCRATCH_SIZE),
-        ggml::Buffer::new(SCRATCH_SIZE),
-    ]
-}

crates/llm-base/src/model/common.rs (26 changes: 25 additions & 1 deletion)

@@ -4,6 +4,12 @@ use ggml::{metal::MetalContext, ComputationGraph, Context, Tensor};
 
 use crate::{InferenceSession, OutputRequest, TokenId};
 
+// The size of a scratch buffer used for inference. This is used for temporary
+// storage of intermediate results during inference.
+//
+// The specific value was copied from `llama.cpp`.
+const SCRATCH_SIZE: usize = 512 * 1024 * 1024;
+
 /// Holds context and tensors used during a single evaluation
 pub struct EvaluationContext {
     /// The context that holds data
@@ -15,6 +21,13 @@ pub struct EvaluationContext {
     /// When Metal is available: None if Metal is disabled, Some(MetalContext) when Metal acceleration is enabled
     #[cfg(feature = "metal")]
     pub metal_context: Option<MetalContext>,
+
+    /// Scratch buffers used during inference.
+    ///
+    /// The number of scratch buffers was copied from `llama.cpp`.
+    /// There is no specific reason for this number, but one is insufficient.
+    #[doc(hidden)]
+    pub scratch: [ggml::Buffer; 2],
 }
 
 impl EvaluationContext {
@@ -34,13 +47,23 @@ impl EvaluationContext {
     }
 }
 
+fn scratch_buffers() -> [ggml::Buffer; 2] {
+    [
+        ggml::Buffer::new(SCRATCH_SIZE),
+        ggml::Buffer::new(SCRATCH_SIZE),
+    ]
+}
+
 /// Common code to prepare a model to evaluate input
 pub fn prepare_for_evaluate_v2(
     n_layer: usize,
     session: &mut InferenceSession,
     input_tokens: &[TokenId],
 ) -> EvaluationContext {
     let (ctx0, embd) = prepare_for_evaluate(n_layer, session, input_tokens);
 
+    let mut scratch = scratch_buffers();
+
     #[cfg(feature = "metal")]
     {
         // FIXME can only process one token at a time currently
@@ -51,7 +74,7 @@
                 session._session_ctx.clone(),
                 &mut session.memory_k,
                 &mut session.memory_v,
-                &mut session.scratch,
+                &mut scratch,
             );
             metal_context.initialize_eval_buffers(ctx0.clone());
             Some(metal_context)
@@ -62,6 +85,7 @@
         metal_context,
         embd,
         ctx0,
+        scratch,
     }
 }
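
A note on footprint, not stated in the diff itself: each scratch buffer is 512 MiB, so the pair adds roughly 1 GiB per EvaluationContext, and because scratch_buffers() is now called inside prepare_for_evaluate_v2, that cost is paid per evaluation rather than once per session as before (assuming ggml::Buffer::new allocates eagerly). A back-of-envelope check, with the constant copied from the diff above:

    // Illustrative only; SCRATCH_SIZE is the value added to common.rs above.
    const SCRATCH_SIZE: usize = 512 * 1024 * 1024; // 536_870_912 bytes = 512 MiB

    fn main() {
        let total = 2 * SCRATCH_SIZE; // two buffers per EvaluationContext
        assert_eq!(total, 1_073_741_824); // exactly 1 GiB
        println!("scratch total: {} MiB", total / (1024 * 1024)); // prints 1024
    }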

crates/models/gptneox/src/lib.rs (18 changes: 10 additions & 8 deletions)

@@ -148,9 +148,11 @@ impl KnownModel for GptNeoX {
             ..
         } = self.hyperparameters;
 
-        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let ctx0 = &evaluation_ctx.ctx0;
+        let embd = &evaluation_ctx.embd;
 
-        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);
+        let mut input_layer = ctx0.op_get_rows(&self.wte, embd);
 
         let memory_k = &session.memory_k;
         let memory_k_size = memory_k.element_size();
@@ -162,7 +164,7 @@
 
         for il in 0..n_layer {
             // attention uses first scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             // self-attention
             let mut current = ctx0.op_norm(&input_layer);
@@ -282,12 +284,12 @@
             );
 
             // use the second scratch for the feed forward
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             let feedforward_input: Tensor;
             if !use_parallel_residual {
                 feedforward_input = ctx0.op_add(&current, &input_layer);
-                current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input);
+                current = feed_forward_network(ctx0, &self.layers[il], &feedforward_input);
                 // input for next layer
                 input_layer = ctx0.op_add(&current, &feedforward_input);
             } else {
@@ -296,7 +298,7 @@
 
                 // this is independent of the self-attention result, so it could be done in parallel to the self-attention
                 // note here we pass inpL instead of cur
-                current = feed_forward_network(&ctx0, &self.layers[il], &input_layer);
+                current = feed_forward_network(ctx0, &self.layers[il], &input_layer);
 
                 // layer input + FF
                 current = ctx0.op_add(&current, &feedforward_input);
@@ -307,7 +309,7 @@
         }
 
         // use the first scratch for the norm
-        ctx0.use_scratch(Some(&mut session.scratch[1]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
         // normalize the output
         input_layer = ctx0.op_norm(&input_layer);
@@ -333,7 +335,7 @@
         common::read_last_token(session, &input_layer, n_vocab, n);
         common::extract_logits(output_request, &input_layer, n_vocab, n);
         common::extract_embeddings(output_request, &embeddings_tensor, n_embd, n);
-        common::update_session(session, &ctx0, input_tokens.len(), n);
+        common::update_session(session, ctx0, input_tokens.len(), n);
     }
 
     fn vocabulary(&self) -> &Vocabulary {

crates/models/llama/src/lib.rs (8 changes: 4 additions & 4 deletions)

@@ -123,7 +123,7 @@ impl KnownModel for Llama {
             file_type: _,
         } = self.hyperparameters;
 
-        let evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
         let ctx0 = &evaluation_ctx.ctx0;
         let embd = &evaluation_ctx.embd;
 
@@ -145,7 +145,7 @@
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
 
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             // norm
             current = ctx0.op_rms_norm(&input_layer);
@@ -270,7 +270,7 @@
             // projection (no bias)
             current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
 
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             let input_feed_forward = ctx0.op_add(&current, &input_self_attention);
 
@@ -298,7 +298,7 @@
             input_layer = current;
         }
 
-        ctx0.use_scratch(Some(&mut session.scratch[0]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
         // norm
         input_layer = ctx0.op_rms_norm(&input_layer);

crates/models/mpt/src/lib.rs (8 changes: 4 additions & 4 deletions)

@@ -116,7 +116,7 @@ impl KnownModel for Mpt {
             ..
         } = self.hyperparameters;
 
-        let evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
         let ctx0 = &evaluation_ctx.ctx0;
         let embd = &evaluation_ctx.embd;
 
@@ -133,7 +133,7 @@
         let mut gf = ggml::ComputationGraph::new(num_threads);
         for il in 0..n_layer {
             // attention uses first scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             let mut current = ctx0.op_norm(&input_layer);
             current = ctx0.op_mul(
@@ -228,7 +228,7 @@
             input_layer = ctx0.op_add(&input_layer, &current);
 
             // feed forward uses second scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             current = ctx0.op_norm(&input_layer);
             current = ctx0.op_mul(
@@ -247,7 +247,7 @@
         }
 
         //use scratch buffer 0 for the rest
-        ctx0.use_scratch(Some(&mut session.scratch[0]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
         // norm
        input_layer = ctx0.op_norm(&input_layer);
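
One pattern visible in all three model files above: within a layer, attention ops run with scratch[0] active and feed-forward ops with scratch[1]. Alternating buffers keeps the tensors produced by one phase intact while the next phase allocates its own intermediates in the other buffer, which is presumably why the doc comment calls a single buffer insufficient. A schematic loop body (hypothetical; the real op sequences are in the diffs above):

    for il in 0..n_layer {
        // Attention intermediates are allocated in the first scratch buffer.
        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
        // ... attention ops for layer `il` ...

        // Feed-forward intermediates go to the second buffer, so the
        // attention results they read are not overwritten mid-phase.
        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
        // ... feed-forward ops for layer `il` ...
    }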
