Skip to content

Commit

Permalink
metal: graph building
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelspark committed Jun 18, 2023
1 parent e7a5416 commit e979913
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
2 changes: 1 addition & 1 deletion crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct Context {

/// Metal context for optional acceleration through MPS.
#[cfg(feature = "metal")]
metal_context: Option<MetalContext>,
pub metal_context: Option<MetalContext>,
}

impl Context {
Expand Down
23 changes: 22 additions & 1 deletion crates/ggml/src/metal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Metal support.
use std::{ptr::NonNull, sync::Arc};

use crate::sys::metal;
use crate::{sys::metal, ComputationGraph, Tensor};

/// Acts as a RAII-guard over a `sys::metal::ggml_metal_context`, allocating via
/// `ggml_metal_init` and dropping via `ggml_metal_free`.
Expand All @@ -18,10 +18,31 @@ impl MetalContext {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
}
}

/// Computes the specified graph using Metal.
pub fn graph_compute(&self, graph: &mut ComputationGraph) {
unsafe {
metal::ggml_metal_graph_compute(
self.ptr.as_ptr(),
&mut graph.inner as *mut ggml_sys::ggml_cgraph as *mut metal::ggml_cgraph,
);
}
}

/// Reads a tensor from Metal
pub fn get_tensor(&self, tensor: &Tensor) {
unsafe {
metal::ggml_metal_get_tensor(
self.ptr.as_ptr(),
tensor.ptr.as_ptr() as *mut metal::ggml_tensor,
)
}
}
}

impl Drop for MetalContext {
fn drop(&mut self) {
panic!();
// SAFETY: The only non-weak copy of ptr is no longer accessible after
// this drop call.
unsafe { metal::ggml_metal_free(self.ptr.as_ptr()) }
Expand Down
12 changes: 11 additions & 1 deletion crates/models/mpt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,17 @@ impl KnownModel for Mpt {

// run the computation
gf.build_forward_expand(&input_layer);
ctx0.graph_compute(&mut gf);

if cfg!(feature = "metal") {
if let Some(ref metal_context) = ctx0.metal_context {
metal_context.graph_compute(&mut gf);
metal_context.get_tensor(&input_layer);
} else {
ctx0.graph_compute(&mut gf);
}
} else {
ctx0.graph_compute(&mut gf);
}

// finish evaluation
common::read_last_token(session, &input_layer, n_vocab, input_len);
Expand Down

0 comments on commit e979913

Please sign in to comment.