metal: make Metal inference work
pixelspark committed Jun 18, 2023
1 parent cb507c4 commit eba54c3
Showing 7 changed files with 68 additions and 112 deletions.
13 changes: 0 additions & 13 deletions crates/ggml/src/context.rs
@@ -8,8 +8,6 @@ use std::{

use crate::{sys, usize_to_i32, usize_to_i64, Buffer, ComputationGraph, Tensor, Type};

const ALIGN_OWNED_TENSORS_TO_PAGE_SIZE: usize = 16384;

/// Acts as a RAII-guard over a `sys::ggml_context`, allocating via
/// `ggml_init` and dropping via `ggml_free`.
pub struct Context {
@@ -42,17 +40,6 @@ impl Context {
}
}

/// Allocates aligned memory associated with this context (meaning it will be deallocated when the context is dropped)
pub fn alloc_owned_aligned(&self, size: usize) -> *mut u8 {
let size_bytes =
(size & (!ALIGN_OWNED_TENSORS_TO_PAGE_SIZE)) + ALIGN_OWNED_TENSORS_TO_PAGE_SIZE;
let layout = Layout::from_size_align(size_bytes, ALIGN_OWNED_TENSORS_TO_PAGE_SIZE).unwrap();
let ptr = unsafe { std::alloc::alloc(layout).cast() };
let om = self.owned_memory.lock().unwrap();
om.borrow_mut().push((ptr, layout));
ptr
}

/// Wraps a raw tensor with a weak pointer to the context.
fn new_tensor_raw(&self, raw: *mut sys::ggml_tensor) -> Tensor {
Tensor {
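The removed `alloc_owned_aligned` deserves a note: its size computation `(size & !ALIGN) + ALIGN` clears only bit 14 of `size` and then adds a full page, so the result is never smaller than `size` but is not actually rounded up to a multiple of the page size. For reference, a minimal sketch of the conventional power-of-two round-up — an illustration using the removed constant's value, not code from this commit:

```rust
use std::alloc::Layout;

// Same value as the removed ALIGN_OWNED_TENSORS_TO_PAGE_SIZE; must be a power of two.
const PAGE: usize = 16384;

/// Rounds `size` up to the next multiple of `PAGE`.
fn round_up_to_page(size: usize) -> usize {
    (size + PAGE - 1) & !(PAGE - 1)
}

/// Allocates `size` bytes aligned to a 16 KiB boundary. The caller must
/// deallocate with the same `Layout`, as the dropped `owned_memory`
/// bookkeeping in `Context` did.
fn alloc_page_aligned(size: usize) -> *mut u8 {
    assert!(size > 0, "zero-sized allocations need special handling");
    let layout = Layout::from_size_align(round_up_to_page(size), PAGE).unwrap();
    unsafe { std::alloc::alloc(layout) }
}
```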
73 changes: 20 additions & 53 deletions crates/ggml/src/metal.rs
@@ -23,6 +23,26 @@ impl MetalContext {
}
}

/// Registers a scratch buffer with this Metal context
pub fn add_scratch_buffer(&mut self, buf: &Buffer) {
unsafe {
let raw_metal_context = self.ptr.as_ptr();

// Add this scratch buffer to the Metal context's buffer list
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
"scratch\0".as_ptr().cast(), // FIXME: allocate string and insert number in name
buf.data.as_ptr() as *mut core::ffi::c_void,
buf.data.len(),
buf.data.len()
),
"{}",
format!("Could not add scratch buffer to metal context")
);
}
}

/// Adds a context's memory as a buffer to this Metal context
pub fn add_context(&mut self, from_context: Arc<Context>) {
self.ref_context(from_context.clone());
@@ -57,59 +77,6 @@ impl MetalContext {
self.contexts.push(context);
}

/// Initializes the buffers needed for a metal forward pass.
pub fn initialize_buffers(
&mut self,
context: Arc<Context>,
memory_k: &mut Tensor,
memory_v: &mut Tensor,
scratch: &mut [Buffer],
) {
unsafe {
let raw_metal_context = self.ptr.as_ptr();

// This is the `kv` section from the original code; we don't have a joined kv buffer, so we need to add them separately
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
"k\0".as_ptr().cast(),
memory_k.data(),
memory_k.element_size(),
0
),
"Could not add k buffer to metal context"
);

assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
"v\0".as_ptr().cast(),
memory_v.data(),
memory_v.element_size(),
0
),
"Could not add v buffer to metal context"
);

// Finally, add the scratch buffers
for (i, buf) in scratch.iter().enumerate() {
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
"scrN\0".as_ptr().cast(), // FIXME: allocate string and insert number in name
buf.data.as_ptr() as *mut core::ffi::c_void,
buf.data.len(),
buf.data.len()
),
"{}",
format!("Could not add scratch buffer {} to metal context", i)
);
}
}

self.ref_context(context);
}

/// Computes the specified graph using Metal.
pub fn graph_compute(&self, graph: &mut ComputationGraph) {
unsafe {
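Taken together, the metal.rs changes replace the monolithic `initialize_buffers` with piecewise registration: scratch buffers via the new `add_scratch_buffer`, and each ggml context via `add_context`. A minimal sketch of the resulting call sequence — the method names come from this diff, while the `ggml::metal` module path and the free-standing function are assumptions; the real wiring is in `common::prepare_for_evaluate_v2` below:

```rust
use std::sync::Arc;

use ggml::metal::MetalContext; // module path assumed
use ggml::{Buffer, ComputationGraph, Context};

/// Sketch: run one forward pass on Metal, mirroring the order used in
/// common.rs (session plumbing elided).
fn compute_on_metal(
    scratch: &[Buffer],
    ctx0: Arc<Context>,
    session_ctx: Arc<Context>,
    model_ctx: Arc<Context>,
    graph: &mut ComputationGraph,
) {
    let mut metal = MetalContext::new();
    // Scratch buffers are registered one by one now that
    // initialize_buffers() is gone.
    for buf in scratch {
        metal.add_scratch_buffer(buf);
    }
    // Every context whose tensors the graph touches must be mapped,
    // including the session context that now owns memory_k / memory_v.
    metal.add_context(ctx0);
    metal.add_context(session_ctx);
    metal.add_context(model_ctx);
    metal.graph_compute(graph);
}
```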
27 changes: 9 additions & 18 deletions crates/llm-base/src/inference_session.rs
@@ -18,7 +18,7 @@ use crate::{
/// to use it from multiple threads.
pub struct InferenceSession {
// Must be kept alive for the model
pub(crate) _session_ctx: Arc<ggml::Context>,
pub(crate) session_ctx: Arc<ggml::Context>,

// Original size of the memory used to create this context.
pub(crate) memory_size: usize,
@@ -405,16 +405,13 @@ impl InferenceSession {
// Initialize key + value memory tensors
let n_mem = n_layer * n_ctx;
let n_elements = n_embd * n_mem;
let mut memory_k = session_ctx.new_tensor_1d(config.memory_k_type.into(), n_elements);
let mut memory_v = session_ctx.new_tensor_1d(config.memory_v_type.into(), n_elements);

unsafe {
memory_k.set_data(session_ctx.alloc_owned_aligned(memory_k.nbytes()).cast());
memory_v.set_data(session_ctx.alloc_owned_aligned(memory_v.nbytes()).cast());
}
let memory_k = session_ctx.new_tensor_1d(config.memory_k_type.into(), n_elements);
let memory_v = session_ctx.new_tensor_1d(config.memory_v_type.into(), n_elements);
ggml::set_name(&memory_k, "memory_k");
ggml::set_name(&memory_v, "memory_v");

InferenceSession {
_session_ctx: session_ctx,
session_ctx,
memory_size: ctx_size,
config,
memory_k,
@@ -430,17 +427,11 @@
impl Clone for InferenceSession {
fn clone(&self) -> Self {
let context = Arc::new(ggml::Context::init(self.memory_size, false));
let mut memory_k =
context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements());
let mut memory_v =
context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements());
unsafe {
memory_k.set_data(context.alloc_owned_aligned(memory_k.nbytes()).cast());
memory_v.set_data(context.alloc_owned_aligned(memory_v.nbytes()).cast());
}
let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements());
let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements());

Self {
_session_ctx: context,
session_ctx: context,
memory_size: self.memory_size,
config: self.config,
memory_k,
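With the manual aligned allocation gone, the K/V cache tensors are ordinary context-owned tensors, sized by the `n_mem` and `n_elements` formulas in the hunk above. A quick numeric illustration with hypothetical hyperparameters (not values taken from this repository):

```rust
fn main() {
    // Hypothetical 7B-class configuration, for sizing intuition only.
    let (n_layer, n_ctx, n_embd) = (32usize, 2048usize, 4096usize);
    let n_mem = n_layer * n_ctx; // 65_536 cached positions across all layers
    let n_elements = n_embd * n_mem; // 268_435_456 elements per cache tensor
    assert_eq!(n_elements, 268_435_456);
    // At f16 (2 bytes per element), memory_k and memory_v are 512 MiB each.
    assert_eq!(n_elements * 2, 512 * 1024 * 1024);
}
```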
16 changes: 7 additions & 9 deletions crates/llm-base/src/model/common.rs
@@ -79,23 +79,21 @@ pub fn prepare_for_evaluate_v2(
) -> EvaluationContext {
let (ctx0, embd) = prepare_for_evaluate(n_layer, session, input_tokens);

let mut scratch = scratch_buffers();
let scratch = scratch_buffers();

#[cfg(feature = "metal")]
{
// FIXME can only process one token at a time currently
// See https://github.com/ggerganov/llama.cpp/blob/e1886cf4fe0d0f31661dda52a4a9f34bd9b9009a/llama.cpp#L1692
let metal_context = if session.config.use_gpu && input_tokens.len() == 1 {
let mut metal_context = MetalContext::new();
metal_context.initialize_buffers(
session._session_ctx.clone(),
&mut session.memory_k,
&mut session.memory_v,
&mut scratch,
);

for buf in scratch.iter() {
metal_context.add_scratch_buffer(buf);
}

metal_context.add_context(ctx0.clone());
metal_context.add_context(session._session_ctx.clone());
metal_context.add_context(session.session_ctx.clone());
metal_context.add_context(model_context);

Some(metal_context)
@@ -121,7 +119,7 @@
}

/// Common code to prepare a model to evaluate input
pub fn prepare_for_evaluate(
fn prepare_for_evaluate(
n_layer: usize,
session: &mut InferenceSession,
input_tokens: &[TokenId],
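`EvaluationContext` itself is not defined in this diff, only used; from the call sites here and in the model crates below, it evidently bundles the per-evaluation compute context with the embedding input. A hedged reconstruction — only the `ctx0` and `embd` field names are confirmed by the diff, the rest is a guess:

```rust
use std::sync::Arc;

/// Reconstructed from usage in this commit; not the actual definition.
pub struct EvaluationContext {
    /// Per-evaluation ggml context; cloned into the MetalContext when GPU
    /// inference is enabled.
    pub ctx0: Arc<ggml::Context>,
    /// Input token-embedding tensor, borrowed by the model's first op.
    pub embd: ggml::Tensor,
    /// Populated only when `use_gpu` is set and exactly one token is being
    /// evaluated (see the FIXME above). Field name is an assumption.
    #[cfg(feature = "metal")]
    pub metal_context: Option<ggml::metal::MetalContext>,
}
```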
17 changes: 11 additions & 6 deletions crates/models/bloom/src/lib.rs
@@ -2,6 +2,8 @@
//! for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
@@ -36,7 +38,7 @@ pub struct Bloom {
layers: Vec<Layer>,

// must be kept alive for the model
_context: ggml::Context,
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

@@ -88,7 +90,7 @@ impl KnownModel for Bloom {
layers.push(layer);
}

let (_context, _, _mmap) = tl.finish();
let (context, _, _mmap) = tl.finish();

let ModelParameters { context_size, .. } = params;

@@ -103,7 +105,7 @@
output_norm_bias,
output,
layers,
_context,
context: Arc::new(context),
_mmap,
})
}
@@ -139,9 +141,12 @@ impl KnownModel for Bloom {
file_type: _,
} = self.hyperparameters;

let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);
let evaluation_ctx =
common::prepare_for_evaluate_v2(n_layer, session, self.context.clone(), input_tokens);
let ctx0 = &evaluation_ctx.ctx0;
let embd = &evaluation_ctx.embd;

let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);
let mut input_layer = ctx0.op_get_rows(&self.wte, embd);

// normalize embeddings
input_layer = ctx0.op_norm(&input_layer);
@@ -357,7 +362,7 @@ impl KnownModel for Bloom {
common::read_last_token(session, &input_layer, n_vocab, input_len);
common::extract_logits(output_request, &input_layer, n_vocab, input_len);
common::extract_embeddings(output_request, &embeddings_tensor, n_embd, input_len);
common::update_session(session, &ctx0, input_tokens.len(), input_len);
common::update_session(session, ctx0, input_tokens.len(), input_len);
}

fn vocabulary(&self) -> &Vocabulary {
17 changes: 11 additions & 6 deletions crates/models/gpt2/src/lib.rs
@@ -1,6 +1,8 @@
//! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml,
@@ -35,7 +37,7 @@ pub struct Gpt2 {
layers: Vec<Layer>,

// must be kept alive for the model
_context: ggml::Context,
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

@@ -80,7 +82,7 @@ impl KnownModel for Gpt2 {
layers.push(layer);
}

let (_context, _, _mmap) = tl.finish();
let (context, _, _mmap) = tl.finish();

let ModelParameters { context_size, .. } = params;

@@ -94,7 +96,7 @@
wte,
wpe,
lm_head,
_context,
context: Arc::new(context),
_mmap,
})
}
@@ -129,15 +131,18 @@ impl KnownModel for Gpt2 {
..
} = self.hyperparameters;

let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);
let evaluation_ctx =
common::prepare_for_evaluate_v2(n_layer, session, self.context.clone(), input_tokens);
let ctx0 = &evaluation_ctx.ctx0;
let embd = &evaluation_ctx.embd;

let position_buf: Vec<usize> = (0..input_len).map(|i| session_len + i).collect();

let mut position = ctx0.new_tensor_1d(ggml::Type::I32, input_len);
unsafe { position.write_data(bytemuck::cast_slice(&position_buf)) };

let mut input_layer = ctx0.op_add(
&ctx0.op_get_rows(&self.wte, &embd),
&ctx0.op_get_rows(&self.wte, embd),
&ctx0.op_get_rows(&self.wpe, &position),
);

@@ -307,7 +312,7 @@ impl KnownModel for Gpt2 {
common::read_last_token(session, &input_layer, n_vocab, input_len);
common::extract_logits(output_request, &input_layer, n_vocab, input_len);
common::extract_embeddings(output_request, &embeddings_tensor, n_embd, input_len);
common::update_session(session, &ctx0, input_tokens.len(), input_len);
common::update_session(session, ctx0, input_tokens.len(), input_len);
}

fn vocabulary(&self) -> &Vocabulary {
17 changes: 10 additions & 7 deletions crates/models/gptj/src/lib.rs
@@ -1,7 +1,7 @@
//! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::error::Error;
use std::{error::Error, sync::Arc};

use ggml::Tensor;
use llm_base::{
@@ -36,7 +36,7 @@ pub struct GptJ {
layers: Vec<Layer>,

// must be kept alive for the model
_context: ggml::Context,
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

@@ -82,7 +82,7 @@ impl KnownModel for GptJ {
layers.push(layer);
}

let (_context, _, _mmap) = tl.finish();
let (context, _, _mmap) = tl.finish();

let ModelParameters { context_size, .. } = params;

@@ -97,7 +97,7 @@
lmh_b,
layers,
_mmap,
_context,
context: Arc::new(context),
})
}

@@ -132,9 +132,12 @@ impl KnownModel for GptJ {
..
} = self.hyperparameters;

let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);
let evaluation_ctx =
common::prepare_for_evaluate_v2(n_layer, session, self.context.clone(), input_tokens);
let ctx0 = &evaluation_ctx.ctx0;
let embd = &evaluation_ctx.embd;

let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);
let mut input_layer = ctx0.op_get_rows(&self.wte, embd);

let memory_k = &session.memory_k;
let memory_k_size = memory_k.element_size();
@@ -286,7 +289,7 @@ impl KnownModel for GptJ {
common::read_last_token(session, &input_layer, n_vocab, input_len);
common::extract_logits(output_request, &input_layer, n_vocab, input_len);
common::extract_embeddings(output_request, &embeddings_tensor, n_embd, input_len);
common::update_session(session, &ctx0, input_tokens.len(), input_len);
common::update_session(session, ctx0, input_tokens.len(), input_len);
}

fn vocabulary(&self) -> &Vocabulary {
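All three model crates make the same two-part change: the context field loses its underscore prefix (it is now actually read) and becomes an `Arc<ggml::Context>`, so `evaluate` can hand the weights' context to the Metal context without moving it. A minimal sketch of the ownership pattern, with a hypothetical `Model` standing in for Bloom, Gpt2, and GptJ:

```rust
use std::sync::Arc;

struct Model {
    // Before: `_context: ggml::Context` — owned and kept alive, never read.
    // After: shared, so evaluate() can clone a handle for the GPU path.
    context: Arc<ggml::Context>,
}

impl Model {
    fn evaluate(&self /* , session, input_tokens, output_request */) {
        // Cloning an Arc bumps a reference count; it does not copy the
        // context. The tensors backing the model weights therefore stay
        // alive for as long as the MetalContext holds its handle.
        let model_ctx: Arc<ggml::Context> = Arc::clone(&self.context);
        // ...passed as the third argument to common::prepare_for_evaluate_v2.
        let _ = model_ctx;
    }
}
```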
