Skip to content

Commit

Permalink
metal: make mmap work
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelspark committed Jun 18, 2023
1 parent eba54c3 commit 55a8e55
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
1 change: 1 addition & 0 deletions crates/ggml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license = "MIT"
[dependencies]
thiserror = { workspace = true }
ggml-sys = { path = "sys", version = "0.2.0-dev" }
memmap2 = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
25 changes: 25 additions & 0 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::{
sync::{Arc, Mutex},
};

use memmap2::Mmap;

use crate::{sys, usize_to_i32, usize_to_i64, Buffer, ComputationGraph, Tensor, Type};

/// Acts as a RAII-guard over a `sys::ggml_context`, allocating via
Expand All @@ -19,9 +21,31 @@ pub struct Context {

/// Memory allocated and owned by this context
pub owned_memory: Mutex<RefCell<Vec<(*mut u8, Layout)>>>,

/// Memory mapping information
pub mmap: Option<Mmap>,
}

impl Context {
/// Creates a new [Context] backed by the provided memory-mapped file.
///
/// The context is sized to the mapping and initialised with `no_alloc`,
/// and the [Mmap] is stored on the returned [Context] so the mapping
/// outlives every tensor created in it.
pub fn init_mmap(mmap: Mmap) -> Self {
    let raw = unsafe {
        sys::ggml_init(sys::ggml_init_params {
            // The working area is exactly as large as the mapped file.
            mem_size: mmap.len(),
            // Null here means we want ggml to own this memory. We don't
            // support passing an owned buffer from the Rust side.
            mem_buffer: std::ptr::null_mut(),
            // NOTE(review): `no_alloc` is set, so ggml is expected not to
            // allocate tensor data inside this context — presumably tensor
            // data comes from the mapping instead; confirm against the
            // ggml_init documentation.
            no_alloc: true,
        })
    };

    Self {
        // ggml_init returning null is treated as an unrecoverable bug.
        ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
        // No Rust-side allocations have been made for this context yet.
        owned_memory: Mutex::new(RefCell::new(vec![])),
        // Keep the mapping alive for the lifetime of the context.
        mmap: Some(mmap),
    }
}

/// Creates a new [Context] with the specified `mem_size` as a working area.
pub fn init(mem_size: usize, alloc: bool) -> Self {
let raw = unsafe {
Expand All @@ -37,6 +61,7 @@ impl Context {
Self {
ptr: Arc::new(NonNull::new(raw).expect("Should not be null")),
owned_memory: Mutex::new(RefCell::new(vec![])),
mmap: None,
}
}

Expand Down
16 changes: 13 additions & 3 deletions crates/ggml/src/metal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Metal support.
use crate::{sys::metal, Buffer, ComputationGraph, Context, Tensor};
use std::{ptr::NonNull, sync::Arc};
use std::{ffi::c_void, ptr::NonNull, sync::Arc};

/// Acts as a RAII-guard over a `sys::metal::ggml_metal_context`, allocating via
/// `ggml_metal_init` and dropping via `ggml_metal_free`.
Expand Down Expand Up @@ -49,8 +49,18 @@ impl MetalContext {

unsafe {
let raw_context = from_context.ptr.as_ptr();
let data_ptr = ggml_sys::ggml_get_mem_buffer(raw_context);
let data_size = ggml_sys::ggml_get_mem_size(raw_context);

let (data_ptr, data_size): (*mut c_void, usize) =
if let Some(ref mmap) = from_context.mmap {
// This is a bit naughty...
(mmap.as_ptr().cast_mut().cast(), mmap.len())
} else {
(
ggml_sys::ggml_get_mem_buffer(raw_context),
ggml_sys::ggml_get_mem_size(raw_context),
)
};

let max_size = ggml_sys::ggml_get_max_tensor_size(raw_context);
assert!(
metal::ggml_metal_add_buffer(
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ thiserror = { workspace = true }

partial_sort = "0.2.0"
serde_bytes = "0.11"
memmap2 = "0.5.10"
memmap2 = { workspace = true }
half = "2.2.1"
tokenizers = "0.13.3"
regex = "1.8"
Expand Down
27 changes: 12 additions & 15 deletions crates/llm-base/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ pub trait TensorLoader<E: std::error::Error> {
/// Gets a tensor from the loader.
fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
/// Finish loading the model, and extract all of the state from the loader.
fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
fn finish(self) -> (Context, HashMap<String, ggml::Tensor>);
}

/// Load a GGML model from the `path` and configure it per the `params`. The status
Expand Down Expand Up @@ -467,16 +467,15 @@ pub fn load<M: KnownModel>(
}

(load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size });
let context = Context::init(ctx_size, !use_mmap);

let (mmap, file_size) = {
let (context, file_size) = if use_mmap {
let file = File::open(path)?;
let mmap = if use_mmap {
Some(unsafe { Mmap::map(&file)? })
} else {
None
};
(mmap, file.metadata()?.len())
unsafe {
let mmap = Mmap::map(&file)?;
let file_size = mmap.len() as u64;
(Context::init_mmap(mmap), file_size)
}
} else {
(Context::init(ctx_size, true), file.metadata()?.len())
};

let tensors_len = tensors.len();
Expand All @@ -485,7 +484,6 @@ pub fn load<M: KnownModel>(
file,
tensors,
context,
mmap,
lora_adapters,
load_progress_callback: &mut load_progress_callback,
loaded_tensors: Default::default(),
Expand Down Expand Up @@ -578,7 +576,6 @@ struct MmapCompatibleLoader<'a> {
file: File,
tensors: HashMap<String, TensorLoadInfo>,
context: Context,
mmap: Option<Mmap>,
lora_adapters: Option<Vec<LoraAdapter>>,
load_progress_callback: &'a mut dyn FnMut(LoadProgress),
loaded_tensors: HashMap<String, ggml::Tensor>,
Expand All @@ -594,7 +591,7 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
&self.context,
&mut self.file,
&self.path,
self.mmap.as_ref(),
self.context.mmap.as_ref(),
);

let mut tensor = main_context.get_tensor(info)?;
Expand All @@ -618,8 +615,8 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
Ok(tensor)
}

fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
(self.context, self.loaded_tensors, self.mmap)
fn finish(self) -> (Context, HashMap<String, ggml::Tensor>) {
(self.context, self.loaded_tensors)
}
}

Expand Down
6 changes: 2 additions & 4 deletions crates/models/bloom/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
Mmap, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
};

/// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
Expand Down Expand Up @@ -39,7 +39,6 @@ pub struct Bloom {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for Bloom {}
Expand Down Expand Up @@ -90,7 +89,7 @@ impl KnownModel for Bloom {
layers.push(layer);
}

let (context, _, _mmap) = tl.finish();
let (context, _) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -106,7 +105,6 @@ impl KnownModel for Bloom {
output,
layers,
context: Arc::new(context),
_mmap,
})
}

Expand Down
6 changes: 2 additions & 4 deletions crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
LoadError, Mmap, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
LoadError, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
};

/// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
Expand Down Expand Up @@ -38,7 +38,6 @@ pub struct Gpt2 {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for Gpt2 {}
Expand Down Expand Up @@ -82,7 +81,7 @@ impl KnownModel for Gpt2 {
layers.push(layer);
}

let (context, _, _mmap) = tl.finish();
let (context, _) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -97,7 +96,6 @@ impl KnownModel for Gpt2 {
wpe,
lm_head,
context: Arc::new(context),
_mmap,
})
}

Expand Down
6 changes: 2 additions & 4 deletions crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
LoadError, Mmap, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
LoadError, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
};

/// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
Expand Down Expand Up @@ -37,7 +37,6 @@ pub struct GptJ {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for GptJ {}
Expand Down Expand Up @@ -82,7 +81,7 @@ impl KnownModel for GptJ {
layers.push(layer);
}

let (context, _, _mmap) = tl.finish();
let (context, _) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -96,7 +95,6 @@ impl KnownModel for GptJ {
lmh_g,
lmh_b,
layers,
_mmap,
context: Arc::new(context),
})
}
Expand Down
6 changes: 2 additions & 4 deletions crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
LoadError, Mmap, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
LoadError, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
};

/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
Expand Down Expand Up @@ -37,7 +37,6 @@ pub struct GptNeoX {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for GptNeoX {}
Expand Down Expand Up @@ -96,7 +95,7 @@ impl KnownModel for GptNeoX {
layers.push(layer);
}

let (context, _, _mmap) = tl.finish();
let (context, _) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -110,7 +109,6 @@ impl KnownModel for GptNeoX {
lmh_g,
layers,
context: Arc::new(context),
_mmap,
})
}

Expand Down
6 changes: 2 additions & 4 deletions crates/models/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
LoadError, Mmap, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
LoadError, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Vocabulary,
};

/// The LLaMA model. Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
Expand All @@ -34,7 +34,6 @@ pub struct Llama {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for Llama {}
Expand Down Expand Up @@ -73,7 +72,7 @@ impl KnownModel for Llama {
layers.push(layer);
}

let (context, _tensors, _mmap) = tl.finish();
let (context, _tensors) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -86,7 +85,6 @@ impl KnownModel for Llama {
output,
layers,
context: Arc::new(context),
_mmap,
})
}

Expand Down
6 changes: 2 additions & 4 deletions crates/models/mpt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use llm_base::{
ggml::{self},
model::{common, HyperparametersWriteError},
util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
LoadError, Mmap, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
LoadError, ModelParameters, OutputRequest, Regex, TokenId, Vocabulary,
};

/// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b)
Expand All @@ -33,7 +33,6 @@ pub struct Mpt {

// must be kept alive for the model
context: Arc<ggml::Context>,
_mmap: Option<Mmap>,
}

unsafe impl Send for Mpt {}
Expand Down Expand Up @@ -71,7 +70,7 @@ impl KnownModel for Mpt {
layers.push(layer);
}

let (context, _, _mmap) = tl.finish();
let (context, _) = tl.finish();

let ModelParameters { context_size, .. } = params;

Expand All @@ -83,7 +82,6 @@ impl KnownModel for Mpt {
norm,
layers,
context: Arc::new(context),
_mmap,
})
}

Expand Down

0 comments on commit 55a8e55

Please sign in to comment.