Skip to content

Commit

Permalink
metal: move evaluation context to InferenceSession and simplify model…
Browse files Browse the repository at this point in the history
… building interface
  • Loading branch information
pixelspark committed Jun 19, 2023
1 parent e049f9a commit cbbe41c
Show file tree
Hide file tree
Showing 13 changed files with 1,283 additions and 1,263 deletions.
47 changes: 26 additions & 21 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
use std::{
alloc::Layout,
cell::RefCell,
os::raw::{c_int, c_void},
ptr::NonNull,
sync::{Arc, Mutex},
};
use std::{os::raw::c_int, ptr::NonNull, sync::Arc};

use memmap2::Mmap;

Expand All @@ -19,30 +13,45 @@ pub struct Context {
/// with it if the underlying context has been deallocated.
pub ptr: Arc<NonNull<sys::ggml_context>>,

/// Memory allocated and owned by this context
pub owned_memory: Mutex<RefCell<Vec<(*mut u8, Layout)>>>,

/// Memory mapping information
pub mmap: Option<Mmap>,

/// Backing buffer (in case we own it)
pub buffer: Option<Buffer>,
}

impl Context {
/// Creates a new [Context] using the buffer provided as memory
pub fn init_buffer(buffer: Buffer) -> Self {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
mem_size: buffer.size(),
mem_buffer: buffer.data,
no_alloc: false,
})
};

Self {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
mmap: None,
buffer: Some(buffer),
}
}

/// Creates a new [Context] with the memory mapped file provided
pub fn init_mmap(mmap: Mmap) -> Self {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
mem_size: mmap.len(),
// Null here means we want ggml to own this memory. We don't
// support passing an owned buffer from the Rust side.
mem_buffer: std::ptr::null_mut(),
no_alloc: true,
no_alloc: true, // We are mmapping so ggml does not need to allocate any memory for us
})
};

Self {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
owned_memory: Mutex::new(RefCell::new(vec![])),
mmap: Some(mmap),
buffer: None,
}
}

Expand All @@ -51,17 +60,16 @@ impl Context {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
mem_size,
// Null here means we want ggml to own this memory. We don't
// support passing an owned buffer from the Rust side.
// Null here means we want ggml to own this memory.
mem_buffer: std::ptr::null_mut(),
no_alloc: !alloc,
})
};

Self {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
owned_memory: Mutex::new(RefCell::new(vec![])),
mmap: None,
buffer: None,
}
}

Expand Down Expand Up @@ -423,7 +431,7 @@ impl Context {
/// If `scratch_buffer` is `None`, the scratch buffer will be disabled.
pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a mut Buffer>) {
let (size, data) = if let Some(buffer) = scratch_buffer {
(buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
(buffer.size(), buffer.data)
} else {
(0, std::ptr::null_mut())
};
Expand Down Expand Up @@ -467,9 +475,6 @@ impl Drop for Context {
// SAFETY: The only non-weak copy of ptr is no longer accessible after this drop call.
unsafe {
sys::ggml_free(self.ptr.as_ptr());
for (ptr, layout) in self.owned_memory.lock().unwrap().borrow_mut().drain(..) {
std::alloc::dealloc(ptr, layout);
}
}
}
}
27 changes: 17 additions & 10 deletions crates/ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
//! All [Tensor]s are nodes in this computational graph, and values cannot be retrieved until computation is completed.
#![deny(missing_docs)]

use std::os::raw::{c_int, c_void};
use std::{
alloc::Layout,
os::raw::{c_int, c_void},
};

mod context;
mod tensor;
Expand Down Expand Up @@ -221,24 +224,28 @@ impl Type {
///
/// See [Context::use_scratch].
pub struct Buffer {
data: Box<[u8]>,
data: *mut c_void,
layout: Layout,
}

const BUFFER_ALIGN: usize = 16384;

impl Buffer {
/// Creates a new buffer of the specified size.
pub fn new(size: usize) -> Self {
let mut data: Vec<u8> = Vec::with_capacity(size);
let layout = Layout::from_size_align(size, BUFFER_ALIGN).unwrap();

// SAFETY: The contents are intentionally uninitialized, as they will be passed to
// the ggml C API which will fill them with data.
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(size);
Buffer {
data: std::alloc::alloc(layout).cast(),
layout,
}
}
}

Buffer {
data: data.into_boxed_slice(),
}
/// Returns the size of the buffer in bytes
pub fn size(&self) -> usize {
self.layout.size()
}
}

Expand Down
66 changes: 36 additions & 30 deletions crates/ggml/src/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ impl MetalContext {
metal::ggml_metal_add_buffer(
raw_metal_context,
"scratch\0".as_ptr().cast(), // FIXME: allocate string and insert number in name
buf.data.as_ptr() as *mut core::ffi::c_void,
buf.data.len(),
buf.data.len()
buf.data,
buf.size(),
buf.size()
),
"{}",
format!("Could not add scratch buffer to metal context")
Expand All @@ -45,33 +45,33 @@ impl MetalContext {

/// Add a context's memory as buffer to this Metal context
pub fn add_context(&mut self, from_context: Arc<Context>) {
self.ref_context(from_context.clone());
if self.ref_context(from_context.clone()) {
unsafe {
let raw_context = from_context.ptr.as_ptr();

unsafe {
let raw_context = from_context.ptr.as_ptr();

let (data_ptr, data_size): (*mut c_void, usize) =
if let Some(ref mmap) = from_context.mmap {
// This is a bit naughty...
(mmap.as_ptr().cast_mut().cast(), mmap.len())
} else {
(
ggml_sys::ggml_get_mem_buffer(raw_context),
ggml_sys::ggml_get_mem_size(raw_context),
)
};
let (data_ptr, data_size): (*mut c_void, usize) =
if let Some(ref mmap) = from_context.mmap {
// This is a bit naughty...
(mmap.as_ptr().cast_mut().cast(), mmap.len())
} else {
(
ggml_sys::ggml_get_mem_buffer(raw_context),
ggml_sys::ggml_get_mem_size(raw_context),
)
};

let max_size = ggml_sys::ggml_get_max_tensor_size(raw_context);
assert!(
metal::ggml_metal_add_buffer(
self.ptr.as_ptr(),
"wt\0".as_ptr().cast(), // FIXME provide an actual name
data_ptr,
data_size,
max_size
),
"Could not add weight buffer to metal context"
);
let max_size = ggml_sys::ggml_get_max_tensor_size(raw_context);
assert!(
metal::ggml_metal_add_buffer(
self.ptr.as_ptr(),
"wt\0".as_ptr().cast(), // FIXME provide an actual name
data_ptr,
data_size,
max_size
),
"Could not add weight buffer to metal context"
);
}
}
}
}
Expand All @@ -83,8 +83,14 @@ impl Default for MetalContext {
}

impl MetalContext {
fn ref_context(&mut self, context: Arc<Context>) {
self.contexts.push(context);
/// Registers a context as a context that provides Metal buffers. Returns true if the context was not registered before.
fn ref_context(&mut self, context: Arc<Context>) -> bool {
if self.contexts.iter().any(|c| c.ptr == context.ptr) {
false
} else {
self.contexts.push(context);
true
}
}

/// Computes the specified graph using Metal.
Expand Down
2 changes: 2 additions & 0 deletions crates/ggml/sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ fn enable_metal(build: &mut cc::Build) {

build.file("llama-cpp/ggml-metal.m");
build.flag("-DGGML_USE_METAL");

#[cfg(debug_assertions)]
build.flag("-DGGML_METAL_NDEBUG");
}

Expand Down
Loading

0 comments on commit cbbe41c

Please sign in to comment.