From 8eb9db7ca32fd89c348af2bfcc95a1c7eec86a22 Mon Sep 17 00:00:00 2001
From: Tommy van der Vorst
Date: Sun, 18 Jun 2023 13:39:14 +0200
Subject: [PATCH] metal: move scratch buffers to EvaluationContext

---
 crates/llm-base/src/inference_session.rs | 25 +------------------------
 crates/llm-base/src/model/common.rs      | 26 +++++++++++++++++++++++++-
 crates/models/gptneox/src/lib.rs         | 18 ++++++++++--------
 crates/models/llama/src/lib.rs           |  8 ++++----
 crates/models/mpt/src/lib.rs             |  8 ++++----
 5 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 5314a0b4..f81bc864 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -6,12 +6,6 @@ use crate::{
     TokenizationError,
 };
 
-// The size of a scratch buffer used for inference. This is used for temporary
-// storage of intermediate results during inference.
-//
-// The specific value was copied from `llama.cpp`.
-const SCRATCH_SIZE: usize = 512 * 1024 * 1024;
-
 /// An inference session represents the state of the text generation. This holds
 /// the full context window, as well as several additional parameters used
 /// during sampling.
@@ -58,14 +52,8 @@ pub struct InferenceSession {
     /// The logits that were last predicted by the network. Zeroed out otherwise.
     #[doc(hidden)]
     pub last_logits: Vec<f32>,
-
-    /// Scratch buffers used during inference.
-    ///
-    /// The number of scratch buffers was copied from `llama.cpp`.
-    /// There is no specific reason for this number, but one is insufficient.
-    #[doc(hidden)]
-    pub scratch: [ggml::Buffer; 2],
 }
+
 unsafe impl Send for InferenceSession {}
 impl InferenceSession {
     /// Feed a prompt to the model for this session.
@@ -424,8 +412,6 @@ impl InferenceSession {
             memory_v.set_data(session_ctx.alloc_owned_aligned(memory_v.nbytes()).cast());
         }
 
-        let scratch = scratch_buffers();
-
         InferenceSession {
             _session_ctx: session_ctx,
             memory_size: ctx_size,
@@ -437,7 +423,6 @@ impl InferenceSession {
             tokens: vec![],
             decoded_tokens: vec![],
             last_logits: vec![0.0; n_vocab],
-            scratch,
         }
     }
 }
@@ -464,7 +449,6 @@ impl Clone for InferenceSession {
             tokens: self.tokens.clone(),
             decoded_tokens: self.decoded_tokens.clone(),
             last_logits: self.last_logits.clone(),
-            scratch: scratch_buffers(),
         }
     }
 }
@@ -704,10 +688,3 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
         None => Ok(InferenceFeedback::Continue),
     }
 }
-
-fn scratch_buffers() -> [ggml::Buffer; 2] {
-    [
-        ggml::Buffer::new(SCRATCH_SIZE),
-        ggml::Buffer::new(SCRATCH_SIZE),
-    ]
-}
diff --git a/crates/llm-base/src/model/common.rs b/crates/llm-base/src/model/common.rs
index 5e32161c..2b85fd38 100644
--- a/crates/llm-base/src/model/common.rs
+++ b/crates/llm-base/src/model/common.rs
@@ -4,6 +4,12 @@ use ggml::{metal::MetalContext, ComputationGraph, Context, Tensor};
 
 use crate::{InferenceSession, OutputRequest, TokenId};
 
+// The size of a scratch buffer used for inference. This is used for temporary
+// storage of intermediate results during inference.
+//
+// The specific value was copied from `llama.cpp`.
+const SCRATCH_SIZE: usize = 512 * 1024 * 1024;
+
 /// Holds context and tensors used during a single evaluation
 pub struct EvaluationContext {
     /// The context that holds data
@@ -15,6 +21,13 @@ pub struct EvaluationContext {
     /// When Metal is available: None if Metal is disabled, Some(MetalContext) when Metal acceleration is enabled
     #[cfg(feature = "metal")]
     pub metal_context: Option<MetalContext>,
+
+    /// Scratch buffers used during inference.
+    ///
+    /// The number of scratch buffers was copied from `llama.cpp`.
+    /// There is no specific reason for this number, but one is insufficient.
+    #[doc(hidden)]
+    pub scratch: [ggml::Buffer; 2],
 }
 
 impl EvaluationContext {
@@ -34,6 +47,13 @@ impl EvaluationContext {
     }
 }
 
+fn scratch_buffers() -> [ggml::Buffer; 2] {
+    [
+        ggml::Buffer::new(SCRATCH_SIZE),
+        ggml::Buffer::new(SCRATCH_SIZE),
+    ]
+}
+
 /// Common code to prepare a model to evaluate input
 pub fn prepare_for_evaluate_v2(
     n_layer: usize,
@@ -41,6 +61,9 @@ pub fn prepare_for_evaluate_v2(
     input_tokens: &[TokenId],
 ) -> EvaluationContext {
     let (ctx0, embd) = prepare_for_evaluate(n_layer, session, input_tokens);
+
+    let mut scratch = scratch_buffers();
+
     #[cfg(feature = "metal")]
     {
         // FIXME can only process one token at a time currently
@@ -51,7 +74,7 @@ pub fn prepare_for_evaluate_v2(
             session._session_ctx.clone(),
             &mut session.memory_k,
             &mut session.memory_v,
-            &mut session.scratch,
+            &mut scratch,
         );
         metal_context.initialize_eval_buffers(ctx0.clone());
         Some(metal_context)
@@ -62,6 +85,7 @@ pub fn prepare_for_evaluate_v2(
         metal_context,
         embd,
         ctx0,
+        scratch,
     }
 }
diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs
index 1aa66de7..99db83e0 100644
--- a/crates/models/gptneox/src/lib.rs
+++ b/crates/models/gptneox/src/lib.rs
@@ -148,9 +148,11 @@ impl KnownModel for GptNeoX {
             ..
         } = self.hyperparameters;
 
-        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let ctx0 = &evaluation_ctx.ctx0;
+        let embd = &evaluation_ctx.embd;
 
-        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);
+        let mut input_layer = ctx0.op_get_rows(&self.wte, embd);
 
         let memory_k = &session.memory_k;
         let memory_k_size = memory_k.element_size();
@@ -162,7 +164,7 @@ impl KnownModel for GptNeoX {
 
         for il in 0..n_layer {
             // attention uses first scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             // self-attention
             let mut current = ctx0.op_norm(&input_layer);
@@ -282,12 +284,12 @@ impl KnownModel for GptNeoX {
             );
 
             // use the second scratch for the feed forward
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             let feedforward_input: Tensor;
             if !use_parallel_residual {
                 feedforward_input = ctx0.op_add(&current, &input_layer);
-                current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input);
+                current = feed_forward_network(ctx0, &self.layers[il], &feedforward_input);
                 // input for next layer
                 input_layer = ctx0.op_add(&current, &feedforward_input);
             } else {
@@ -296,7 +298,7 @@ impl KnownModel for GptNeoX {
 
                 // this is independent of the self-attention result, so it could be done in parallel to the self-attention
                 // note here we pass inpL instead of cur
-                current = feed_forward_network(&ctx0, &self.layers[il], &input_layer);
+                current = feed_forward_network(ctx0, &self.layers[il], &input_layer);
 
                 // layer input + FF
                 current = ctx0.op_add(&current, &feedforward_input);
@@ -307,7 +309,7 @@ impl KnownModel for GptNeoX {
         }
 
         // use the second scratch for the norm
-        ctx0.use_scratch(Some(&mut session.scratch[1]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
         // normalize the output
         input_layer = ctx0.op_norm(&input_layer);
@@ -333,7 +335,7 @@ impl KnownModel for GptNeoX {
         common::read_last_token(session, &input_layer, n_vocab, n);
         common::extract_logits(output_request, &input_layer, n_vocab, n);
         common::extract_embeddings(output_request, &embeddings_tensor, n_embd, n);
-        common::update_session(session, &ctx0, input_tokens.len(), n);
+        common::update_session(session, ctx0, input_tokens.len(), n);
     }
 
     fn vocabulary(&self) -> &Vocabulary {
diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs
index 81126a07..1e903a09 100644
--- a/crates/models/llama/src/lib.rs
+++ b/crates/models/llama/src/lib.rs
@@ -123,7 +123,7 @@ impl KnownModel for Llama {
             file_type: _,
         } = self.hyperparameters;
 
-        let evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
         let ctx0 = &evaluation_ctx.ctx0;
         let embd = &evaluation_ctx.embd;
 
@@ -145,7 +145,7 @@ impl KnownModel for Llama {
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
 
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             // norm
             current = ctx0.op_rms_norm(&input_layer);
@@ -270,7 +270,7 @@ impl KnownModel for Llama {
             // projection (no bias)
             current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
 
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             let input_feed_forward = ctx0.op_add(&current, &input_self_attention);
 
@@ -298,7 +298,7 @@ impl KnownModel for Llama {
             input_layer = current;
         }
 
-        ctx0.use_scratch(Some(&mut session.scratch[0]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
         // norm
         input_layer = ctx0.op_rms_norm(&input_layer);
diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs
index ef9376d4..aa8aa500 100644
--- a/crates/models/mpt/src/lib.rs
+++ b/crates/models/mpt/src/lib.rs
@@ -116,7 +116,7 @@ impl KnownModel for Mpt {
             ..
         } = self.hyperparameters;
 
-        let evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
+        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
         let ctx0 = &evaluation_ctx.ctx0;
         let embd = &evaluation_ctx.embd;
 
@@ -133,7 +133,7 @@ impl KnownModel for Mpt {
         let mut gf = ggml::ComputationGraph::new(num_threads);
         for il in 0..n_layer {
             // attention uses first scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[0]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
             let mut current = ctx0.op_norm(&input_layer);
             current = ctx0.op_mul(
@@ -228,7 +228,7 @@ impl KnownModel for Mpt {
             input_layer = ctx0.op_add(&input_layer, &current);
 
             // feed forward uses second scratch buffer
-            ctx0.use_scratch(Some(&mut session.scratch[1]));
+            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
 
             current = ctx0.op_norm(&input_layer);
             current = ctx0.op_mul(
@@ -247,7 +247,7 @@ impl KnownModel for Mpt {
         }
 
         // use scratch buffer 0 for the rest
-        ctx0.use_scratch(Some(&mut session.scratch[0]));
+        ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
 
         // norm
         input_layer = ctx0.op_norm(&input_layer);
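
A minimal sketch of the calling pattern this patch produces inside a model's
`evaluate` (the wrapper function and the import path below are assumptions for
illustration; only `prepare_for_evaluate_v2`, `EvaluationContext`, and
`use_scratch` come from the patch itself):

    use llm_base::{model::common, InferenceSession, TokenId}; // path assumed

    fn evaluate_sketch(n_layer: usize, session: &mut InferenceSession, input_tokens: &[TokenId]) {
        // The scratch buffers now live on the returned EvaluationContext, so the
        // binding must be `mut` for the `&mut evaluation_ctx.scratch[i]` borrows below.
        let mut evaluation_ctx = common::prepare_for_evaluate_v2(n_layer, session, input_tokens);
        let ctx0 = &evaluation_ctx.ctx0; // disjoint field borrow, legal next to &mut scratch

        for _il in 0..n_layer {
            // attention writes its intermediates into the first scratch buffer
            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[0]));
            // ... build attention ops on ctx0 ...

            // the feed-forward block uses the second buffer
            ctx0.use_scratch(Some(&mut evaluation_ctx.scratch[1]));
            // ... build feed-forward ops on ctx0 ...
        }
    }

Note that the two 512 MiB buffers are now allocated on every call to
`prepare_for_evaluate_v2` (that is, once per evaluation) rather than once per
`InferenceSession`; their size and count are unchanged from `llama.cpp`.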