Skip to content

Commit

Permalink
Added untested support for Metal. The Metal bindings might need to be
Browse files Browse the repository at this point in the history
reworked. Use metal in features list to activate it.
  • Loading branch information
darxkies committed Jun 10, 2023
1 parent a0a4669 commit 022a075
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 16 deletions.
41 changes: 34 additions & 7 deletions binaries/generate-ggml-bindings/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use std::io::Write;
use std::path::PathBuf;

fn main() {
let sys_path = PathBuf::from("crates").join("ggml").join("sys");
let ggml_path = sys_path.join("llama-cpp");
let include_path = ggml_path.to_str().unwrap().to_string();

let bindings = bindgen::Builder::default()
.header("crates/ggml/sys/llama-cpp/ggml.h")
// Suppress some warnings
Expand All @@ -19,25 +23,41 @@ fn main() {
.generate()
.expect("Unable to generate bindings");

let cuda_header = ggml_path.join("ggml-cuda.h").to_str().unwrap().to_string();
let cuda_bindings = bindgen::Builder::default()
.header("crates/ggml/sys/ggml/src/ggml-cuda.h")
.allowlist_file("crates/ggml/sys/ggml/src/ggml-cuda.h")
.header(&cuda_header)
.allowlist_file(&cuda_header)
.allowlist_recursively(false)
.clang_arg("-I")
.clang_arg("crates/ggml/sys/ggml/include/ggml")
.clang_arg(&include_path)
.generate()
.expect("Unable to generate cuda bindings");

let opencl_header = ggml_path
.join("ggml-opencl.h")
.to_str()
.unwrap()
.to_string();
let opencl_bindings = bindgen::Builder::default()
.header("crates/ggml/sys/ggml/src/ggml-opencl.h")
.allowlist_file("crates/ggml/sys/ggml/src/ggml-opencl.h")
.header(&opencl_header)
.allowlist_file(&opencl_header)
.allowlist_recursively(false)
.clang_arg("-I")
.clang_arg("crates/ggml/sys/ggml/include/ggml")
.clang_arg(&include_path)
.generate()
.expect("Unable to generate opencl bindings");

let out_dir = PathBuf::from("crates").join("ggml").join("sys").join("src");
let metal_header = ggml_path.join("ggml-metal.h").to_str().unwrap().to_string();
let metal_bindings = bindgen::Builder::default()
.header(&metal_header)
.allowlist_file(&metal_header)
.allowlist_recursively(false)
.clang_arg("-I")
.clang_arg(&include_path)
.generate()
.expect("Unable to generate metal bindings");

let out_dir = sys_path.join("src");

cuda_bindings
.write_to_file(out_dir.join("lib_cuda.rs"))
Expand All @@ -47,6 +67,10 @@ fn main() {
.write_to_file(out_dir.join("lib_opencl.rs"))
.expect("Couldn't write opencl bindings");

metal_bindings
.write_to_file(out_dir.join("lib_metal.rs"))
.expect("Couldn't write metal bindings");

bindings
.write_to_file(out_dir.join("lib.rs"))
.expect("Couldn't write bindings");
Expand Down Expand Up @@ -91,5 +115,8 @@ fn main() {
writeln!(file, "#[cfg(feature = \"clblast\")]").expect("Couldn't write to bindings file");
writeln!(file, "include!(\"lib_opencl.rs\");").expect("Couldn't write to bindings file");

writeln!(file, "#[cfg(feature = \"metal\")]").expect("Couldn't write to bindings file");
writeln!(file, "pub mod lib_metal;").expect("Couldn't write to bindings file");

println!("Successfully updated bindings");
}
1 change: 1 addition & 0 deletions binaries/llm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ zstd = { version = "0.12", default-features = false }
[features]
cublas = ["llm/cublas"]
clblast = ["llm/clblast"]
metal = ["llm/metal"]
1 change: 1 addition & 0 deletions crates/ggml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ anyhow = { workspace = true }
[features]
cublas = ["ggml-sys/cublas"]
clblast = ["ggml-sys/clblast"]
metal = ["ggml-sys/metal"]
1 change: 1 addition & 0 deletions crates/ggml/sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ cc = "^1.0"
[features]
cublas = []
clblast = []
metal = []
13 changes: 12 additions & 1 deletion crates/ggml/sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ fn main() {
enable_cublas(build);
} else if cfg!(feature = "clblast") {
enable_clblast(build);
} else if cfg!(feature = "metal") && cfg!(macos) {
enable_metal(build);
}

match target_arch.as_str() {
Expand Down Expand Up @@ -96,12 +98,21 @@ fn main() {
/// Configures `build` for the CLBlast (OpenCL) backend.
///
/// Prints `cargo:rustc-link-lib` directives for `clblast`, `OpenCL` and
/// `openblas`, adds `llama-cpp/ggml-opencl.cpp` to the files compiled by
/// `build`, and passes `-DGGML_USE_CLBLAST` to the compiler.
fn enable_clblast(build: &mut cc::Build) {
// Link directives are read by Cargo from the build script's stdout.
println!("cargo:rustc-link-lib=clblast");
println!("cargo:rustc-link-lib=OpenCL");
println!("cargo:rustc-link-lib=openblas");

// Compile the OpenCL source with the backend's feature macro defined.
build.file("llama-cpp/ggml-opencl.cpp");
build.flag("-DGGML_USE_CLBLAST");
}

/// Configures `build` for ggml's Metal backend.
///
/// Prints one `cargo:rustc-link-lib=framework=...` directive each for
/// Foundation, Metal, MetalKit and MetalPerformanceShaders (in that order),
/// adds `llama-cpp/ggml-metal.m` to the files compiled by `build`, and passes
/// `-DGGML_USE_METAL` to the compiler.
///
/// NOTE(review): the visible call site guards this function with
/// `cfg!(feature = "metal") && cfg!(macos)`. `macos` is not a cfg name that
/// rustc sets on its own (the operating-system predicate is
/// `target_os = "macos"`), so that guard looks like it can never be true and
/// this function would never run - confirm and fix at the call site. Also
/// note that `cfg!` inside a build script describes the host, not the target;
/// `CARGO_CFG_TARGET_OS` is the variable Cargo provides for the target.
fn enable_metal(build: &mut cc::Build) {
    // Cargo reads these directives from the build script's stdout.
    for framework in ["Foundation", "Metal", "MetalKit", "MetalPerformanceShaders"] {
        println!("cargo:rustc-link-lib=framework={}", framework);
    }

    // Compile the Objective-C source with the backend's feature macro defined.
    build
        .file("llama-cpp/ggml-metal.m")
        .flag("-DGGML_USE_METAL");
}

fn enable_cublas(build: &mut cc::Build) {
let out_dir = env::var("OUT_DIR").unwrap();
let object_file = format!(r"{}\llama-cpp\ggml-cuda.o", &out_dir);
Expand Down
2 changes: 1 addition & 1 deletion crates/ggml/sys/llama-cpp
8 changes: 5 additions & 3 deletions crates/ggml/sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ fn bindgen_test_layout_ggml_init_params() {
pub const ggml_task_type_GGML_TASK_INIT: ggml_task_type = 0;
pub const ggml_task_type_GGML_TASK_COMPUTE: ggml_task_type = 1;
pub const ggml_task_type_GGML_TASK_FINALIZE: ggml_task_type = 2;
pub type ggml_task_type = ::std::os::raw::c_int;
pub type ggml_task_type = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_compute_params {
Expand Down Expand Up @@ -1449,12 +1449,12 @@ extern "C" {
}
pub const ggml_opt_type_GGML_OPT_ADAM: ggml_opt_type = 0;
pub const ggml_opt_type_GGML_OPT_LBFGS: ggml_opt_type = 1;
pub type ggml_opt_type = ::std::os::raw::c_int;
pub type ggml_opt_type = ::std::os::raw::c_uint;
pub const ggml_linesearch_GGML_LINESEARCH_DEFAULT: ggml_linesearch = 1;
pub const ggml_linesearch_GGML_LINESEARCH_BACKTRACKING_ARMIJO: ggml_linesearch = 0;
pub const ggml_linesearch_GGML_LINESEARCH_BACKTRACKING_WOLFE: ggml_linesearch = 1;
pub const ggml_linesearch_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE: ggml_linesearch = 2;
pub type ggml_linesearch = ::std::os::raw::c_int;
pub type ggml_linesearch = ::std::os::raw::c_uint;
pub const ggml_opt_result_GGML_OPT_OK: ggml_opt_result = 0;
pub const ggml_opt_result_GGML_OPT_DID_NOT_CONVERGE: ggml_opt_result = 1;
pub const ggml_opt_result_GGML_OPT_NO_CONTEXT: ggml_opt_result = 2;
Expand Down Expand Up @@ -2022,3 +2022,5 @@ extern "C" {
include!("lib_cuda.rs");
#[cfg(feature = "clblast")]
include!("lib_opencl.rs");
#[cfg(feature = "metal")]
pub mod lib_metal;
56 changes: 53 additions & 3 deletions crates/ggml/sys/src/lib_cuda.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,43 @@
/* automatically generated by rust-bindgen 0.65.1 */

pub const GGML_CUDA_MAX_DEVICES: u32 = 16;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_tensor_extra_gpu {
pub data_device: [*mut ::std::os::raw::c_void; 16usize],
}
#[test]
fn bindgen_test_layout_ggml_tensor_extra_gpu() {
const UNINIT: ::std::mem::MaybeUninit<ggml_tensor_extra_gpu> =
::std::mem::MaybeUninit::uninit();
let ptr = UNINIT.as_ptr();
assert_eq!(
::std::mem::size_of::<ggml_tensor_extra_gpu>(),
128usize,
concat!("Size of: ", stringify!(ggml_tensor_extra_gpu))
);
assert_eq!(
::std::mem::align_of::<ggml_tensor_extra_gpu>(),
8usize,
concat!("Alignment of ", stringify!(ggml_tensor_extra_gpu))
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).data_device) as usize - ptr as usize },
0usize,
concat!(
"Offset of field: ",
stringify!(ggml_tensor_extra_gpu),
"::",
stringify!(data_device)
)
);
}
extern "C" {
pub fn ggml_init_cublas();
}
extern "C" {
pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32);
}
extern "C" {
pub fn ggml_cuda_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor);
}
Expand Down Expand Up @@ -35,13 +70,28 @@ extern "C" {
extern "C" {
pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void);
}
extern "C" {
pub fn ggml_cuda_transform_tensor(tensor: *mut ggml_tensor);
}
extern "C" {
pub fn ggml_cuda_load_data(
fname: *const ::std::os::raw::c_char,
tensors: *mut ggml_tensor,
offset: usize,
);
}
extern "C" {
pub fn ggml_cuda_free_data(tensor: *mut ggml_tensor);
}
extern "C" {
pub fn ggml_cuda_assign_buffers(tensor: *mut ggml_tensor);
}
extern "C" {
pub fn ggml_cuda_set_main_device(main_device: ::std::os::raw::c_int);
}
extern "C" {
pub fn ggml_cuda_set_scratch_size(scratch_size: usize);
}
extern "C" {
pub fn ggml_cuda_compute_forward(
params: *mut ggml_compute_params,
tensor: *mut ggml_tensor,
) -> bool;
}
41 changes: 41 additions & 0 deletions crates/ggml/sys/src/lib_metal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* automatically generated by rust-bindgen 0.65.1 */
// NOTE(review): the header above marks this file as bindgen output, so comments
// written here are likely to be lost when the bindings are regenerated; keep
// any lasting notes with the generator instead.

/// Constant emitted by bindgen with the value 16.
/// NOTE(review): presumably mirrors a `GGML_METAL_MAX_BUFFERS` macro in the C
/// header - confirm.
pub const GGML_METAL_MAX_BUFFERS: u32 = 16;
/// Opaque placeholder named after the C `ggml_tensor`. Its only field is a
/// private zero-length array, so code outside this module cannot construct
/// it; the functions below only take it by raw pointer.
///
/// NOTE(review): this type is local to this module. If the crate root also
/// defines a `ggml_tensor`, the two are distinct Rust types and pointers must
/// be cast between them - confirm how callers are expected to bridge this.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_tensor {
_unused: [u8; 0],
}
/// Opaque placeholder named after the C `ggml_cgraph`, declared the same way
/// as `ggml_tensor` above (private zero-length field, used only by pointer).
///
/// NOTE(review): the same distinct-type concern applies if the crate root has
/// its own `ggml_cgraph` - confirm.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_cgraph {
_unused: [u8; 0],
}
/// Opaque handle type for the Metal backend's context; only ever used behind
/// a raw pointer by the functions below.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_metal_context {
_unused: [u8; 0],
}
// All functions below are foreign declarations and are therefore `unsafe` to call.
extern "C" {
/// Takes no arguments and returns a raw pointer to a `ggml_metal_context`.
/// NOTE(review): presumably creates a context that must later be passed to
/// `ggml_metal_free`; whether it can return null is not visible here - confirm.
pub fn ggml_metal_init() -> *mut ggml_metal_context;
}
extern "C" {
/// Takes a context pointer and returns nothing.
/// NOTE(review): presumably releases a context obtained from
/// `ggml_metal_init` - confirm against the C header.
pub fn ggml_metal_free(ctx: *mut ggml_metal_context);
}
extern "C" {
/// Takes a context pointer, a C string `name`, a `data` pointer and a `size`,
/// and returns a `bool`.
/// NOTE(review): presumably registers `size` bytes at `data` with the context
/// and returns `true` on success - confirm against the C header.
pub fn ggml_metal_add_buffer(
ctx: *mut ggml_metal_context,
name: *const ::std::os::raw::c_char,
data: *mut ::std::os::raw::c_void,
size: usize,
) -> bool;
}
extern "C" {
/// Takes a context pointer and a pointer to this module's `ggml_tensor`.
/// NOTE(review): direction of the data transfer is not visible here;
/// presumably host-to-device - confirm.
pub fn ggml_metal_set_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor);
}
extern "C" {
/// Takes a context pointer and a pointer to this module's `ggml_tensor`.
/// NOTE(review): presumably the counterpart of `ggml_metal_set_tensor`
/// (device-to-host) - confirm.
pub fn ggml_metal_get_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor);
}
extern "C" {
/// Takes a context pointer and a pointer to this module's `ggml_cgraph`;
/// returns nothing.
/// NOTE(review): presumably evaluates the graph `gf` on the Metal device -
/// confirm against the C header.
pub fn ggml_metal_graph_compute(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph);
}
13 changes: 13 additions & 0 deletions crates/ggml/sys/src/lib_opencl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
extern "C" {
pub fn ggml_cl_init();
}
extern "C" {
pub fn ggml_cl_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor);
}
extern "C" {
pub fn ggml_cl_can_mul_mat(
src0: *const ggml_tensor,
Expand Down Expand Up @@ -32,6 +35,16 @@ extern "C" {
extern "C" {
pub fn ggml_cl_host_free(ptr: *mut ::std::os::raw::c_void);
}
extern "C" {
pub fn ggml_cl_free_data(tensor: *const ggml_tensor);
}
extern "C" {
pub fn ggml_cl_transform_tensor(tensor: *mut ggml_tensor);
}
extern "C" {
pub fn ggml_cl_load_data(
fname: *const ::std::os::raw::c_char,
tensor: *mut ggml_tensor,
offset: usize,
);
}
3 changes: 2 additions & 1 deletion crates/llm-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ regex = "1.8"

[features]
cublas = ["ggml/cublas"]
clblast = ["ggml/clblast"]
clblast = ["ggml/clblast"]
metal = ["ggml/metal"]
1 change: 1 addition & 0 deletions crates/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ gptneox = ["dep:llm-gptneox"]
mpt = ["dep:llm-mpt"]
cublas = ["llm-base/cublas"]
clblast = ["llm-base/clblast"]
metal = ["llm-base/metal"]

0 comments on commit 022a075

Please sign in to comment.