Skip to content

Commit

Permalink
chore: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelspark committed Jun 18, 2023
1 parent 4ace7bb commit bccb99e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 18 deletions.
74 changes: 59 additions & 15 deletions crates/ggml/src/metal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Metal support.
use std::{ptr::NonNull, sync::Arc, ffi::CString, borrow::BorrowMut};
use crate::{sys::metal, ComputationGraph, Tensor, Context, Buffer};
use crate::{sys::metal, Buffer, ComputationGraph, Context, Tensor};
use std::{ffi::CString, ptr::NonNull, sync::Arc};

/// Acts as a RAII-guard over a `sys::metal::ggml_metal_context`, allocating via
/// `ggml_metal_init` and dropping via `ggml_metal_free`.
Expand All @@ -16,44 +16,88 @@ impl Default for MetalContext {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
}
}
}


impl MetalContext {
/// Initializes the buffers needed for a metal forward pass.
pub fn initialize_buffers(&self,
pub fn initialize_buffers(
&self,
context: &Context,
eval_context: &Context,
memory_k:&mut Tensor,
memory_v:&mut Tensor,
scratch: &mut [Buffer]){
unsafe{
memory_k: &mut Tensor,
memory_v: &mut Tensor,
scratch: &mut [Buffer],
) {
unsafe {
let raw_metal_context = self.ptr.as_ptr();

//TODO check if this works with mmap
let raw_context = context.ptr.as_ptr();
let data_ptr = ggml_sys::ggml_get_mem_buffer(raw_context);
let data_size = ggml_sys::ggml_get_mem_size(raw_context);
let data_name = CString::new("data").unwrap();
assert!(metal::ggml_metal_add_buffer(raw_metal_context, data_name.as_ptr() ,data_ptr, data_size),"Could not add data buffer to metal context");
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
data_name.as_ptr(),
data_ptr,
data_size
),
"Could not add data buffer to metal context"
);

// in our implementation this should be the `ctx0` buffer
// Original code: ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size)
let raw_eval_context = eval_context.ptr.as_ptr();
let eval_ptr = ggml_sys::ggml_get_mem_buffer(raw_eval_context);
let eval_size = ggml_sys::ggml_get_mem_size(raw_eval_context);
let eval_name = CString::new("eval").unwrap();
assert!(metal::ggml_metal_add_buffer(raw_metal_context, eval_name.as_ptr() ,eval_ptr, eval_size),"Could not add eval buffer to metal context");
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
eval_name.as_ptr(),
eval_ptr,
eval_size
),
"Could not add eval buffer to metal context"
);

//This is the `kv` section from the original code, we dont have a joined kv buffer, so we need to add them seperately
let k_name = CString::new("k").unwrap();
assert!(metal::ggml_metal_add_buffer(raw_metal_context, k_name.as_ptr() ,memory_k.data(), memory_k.element_size()), "Could not add k buffer to metal context");
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
k_name.as_ptr(),
memory_k.data(),
memory_k.element_size()
),
"Could not add k buffer to metal context"
);

let v_name = CString::new("v").unwrap();
assert!(metal::ggml_metal_add_buffer(raw_metal_context, v_name.as_ptr() ,memory_v.data(), memory_v.element_size()), "Could not add v buffer to metal context");
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
v_name.as_ptr(),
memory_v.data(),
memory_v.element_size()
),
"Could not add v buffer to metal context"
);

//Last we need to add the scratch buffers to the buffers
for (i,buf) in scratch.iter().enumerate(){
let name = CString::new(format!("scr{}",i)).unwrap();
assert!(metal::ggml_metal_add_buffer(raw_metal_context, name.as_ptr() ,buf.data.as_ptr() as *mut core::ffi::c_void, buf.data.len()), "{}", format!("Could not add scratch buffer {} to metal context",i));
for (i, buf) in scratch.iter().enumerate() {
let name = CString::new(format!("scr{}", i)).unwrap();
assert!(
metal::ggml_metal_add_buffer(
raw_metal_context,
name.as_ptr(),
buf.data.as_ptr() as *mut core::ffi::c_void,
buf.data.len()
),
"{}",
format!("Could not add scratch buffer {} to metal context", i)
);
}
}
}
Expand Down
10 changes: 7 additions & 3 deletions crates/models/mpt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,16 @@ impl KnownModel for Mpt {

let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);


//ctx0 is the context used for the computation graph => its memory footprint is the `eval` buffer
//We probably should move this to the session
let metal_context = self._context.metal_context.as_ref().unwrap();
metal_context.initialize_buffers( &self._context,&ctx0,&mut session.memory_k, &mut session.memory_v, &mut session.scratch);

metal_context.initialize_buffers(
&self._context,
&ctx0,
&mut session.memory_k,
&mut session.memory_v,
&mut session.scratch,
);

let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);

Expand Down

0 comments on commit bccb99e

Please sign in to comment.