Merge pull request rustformers#123 from philpax/remove-bincode
refactor(llama): remove bincode
philpax authored Apr 12, 2023
2 parents 7787170 + 0b4ab40 commit 0e553a0
Showing 6 changed files with 94 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions llama-cli/Cargo.toml
@@ -10,6 +10,7 @@ llama-rs = { path = "../llama-rs", features = ["convert"] }

rand = { workspace = true }

bincode = "1.3.3"
clap = { version = "4.1.8", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4"
51 changes: 5 additions & 46 deletions llama-cli/src/main.rs
@@ -1,11 +1,8 @@
use std::{convert::Infallible, io::Write, path::Path};
use std::{convert::Infallible, io::Write};

use clap::Parser;
use cli_args::Args;
use llama_rs::{
convert::convert_pth_to_ggml, InferenceError, InferenceSession, InferenceSessionParameters,
Model,
};
use llama_rs::{convert::convert_pth_to_ggml, InferenceError};
use rustyline::error::ReadlineError;

mod cli_args;
@@ -31,7 +28,7 @@ fn infer(args: &cli_args::Infer) {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let (mut session, session_loaded) = load_session_from_disk(
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
args.persist_session.as_deref(),
args.generate.load_session.as_deref(),
@@ -70,18 +67,7 @@ fn infer(args: &cli_args::Infer) {

if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
match snapshot::write_to_disk(&session.get_snapshot(), session_path) {
Ok(_) => {
log::info!("Successfully wrote session to {session_path:?}");
}
Err(err) => {
log::error!("Could not write session at {session_path:?}: {err}");
std::process::exit(1);
}
}
}
snapshot::write_session(session, session_path);
}
}

@@ -121,7 +107,7 @@ fn interactive(
let prompt_file = args.prompt_file.contents();
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let (mut session, session_loaded) = load_session_from_disk(
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
None,
args.generate.load_session.as_deref(),
@@ -209,33 +195,6 @@ fn load_prompt_file_with_prompt(
}
}

pub fn load_session_from_disk(
model: &Model,
persist_session: Option<&Path>,
load_session: Option<&Path>,
inference_session_params: InferenceSessionParameters,
) -> (InferenceSession, bool) {
fn load_snapshot_from_disk(model: &Model, path: &Path) -> InferenceSession {
let snapshot = snapshot::load_from_disk(path);
match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
Ok(session) => {
log::info!("Loaded inference session from {path:?}");
session
}
Err(err) => {
eprintln!("Could not load inference session. Error: {err}");
std::process::exit(1);
}
}
}

match (persist_session, load_session) {
(Some(path), _) if path.exists() => (load_snapshot_from_disk(model, path), true),
(_, Some(path)) => (load_snapshot_from_disk(model, path), true),
_ => (model.start_session(inference_session_params), false),
}
}

fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
raw_prompt.replace("{{PROMPT}}", prompt)
}
71 changes: 56 additions & 15 deletions llama-cli/src/snapshot.rs
@@ -1,27 +1,68 @@
use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError};
use llama_rs::{InferenceSession, InferenceSessionParameters, Model};
use std::{
error::Error,
fs::File,
io::{BufReader, BufWriter},
path::Path,
};
use zstd::zstd_safe::CompressionLevel;
use zstd::{
stream::{read::Decoder, write::Encoder},
zstd_safe::CompressionLevel,
};

const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1;

pub fn load_from_disk(path: impl AsRef<Path>) -> Result<InferenceSnapshot, SnapshotError> {
let mut reader = zstd::stream::read::Decoder::new(BufReader::new(File::open(path.as_ref())?))?;
InferenceSnapshot::read(&mut reader)
pub fn read_or_create_session(
model: &Model,
persist_session: Option<&Path>,
load_session: Option<&Path>,
inference_session_params: InferenceSessionParameters,
) -> (InferenceSession, bool) {
fn load(model: &Model, path: &Path) -> InferenceSession {
let file = unwrap_or_exit(File::open(path), || format!("Could not open file {path:?}"));
let decoder = unwrap_or_exit(Decoder::new(BufReader::new(file)), || {
format!("Could not create decoder for {path:?}")
});
let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || {
format!("Could not deserialize inference session from {path:?}")
});
let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || {
format!("Could not convert snapshot from {path:?} to session")
});
log::info!("Loaded inference session from {path:?}");
session
}

match (persist_session, load_session) {
(Some(path), _) if path.exists() => (load(model, path), true),
(_, Some(path)) => (load(model, path), true),
_ => (model.start_session(inference_session_params), false),
}
}

pub fn write_to_disk(
snap: &InferenceSnapshotRef<'_>,
path: impl AsRef<Path>,
) -> Result<(), SnapshotError> {
let mut writer = zstd::stream::write::Encoder::new(
BufWriter::new(File::create(path.as_ref())?),
SNAPSHOT_COMPRESSION_LEVEL,
)?
.auto_finish();
pub fn write_session(mut session: llama_rs::InferenceSession, path: &Path) {
// SAFETY: the session is consumed here, so nothing else can access it.
let snapshot = unsafe { session.get_snapshot() };
let file = unwrap_or_exit(File::create(path), || {
format!("Could not create file {path:?}")
});
let encoder = unwrap_or_exit(
Encoder::new(BufWriter::new(file), SNAPSHOT_COMPRESSION_LEVEL),
|| format!("Could not create encoder for {path:?}"),
);
unwrap_or_exit(
bincode::serialize_into(encoder.auto_finish(), &snapshot),
|| format!("Could not serialize inference session to {path:?}"),
);
log::info!("Successfully wrote session to {path:?}");
}

snap.write(&mut writer)
fn unwrap_or_exit<T, E: Error>(result: Result<T, E>, error_message: impl Fn() -> String) -> T {
match result {
Ok(t) => t,
Err(err) => {
log::error!("{}. Error: {err}", error_message());
std::process::exit(1);
}
}
}
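
For orientation, the two new helpers above boil down to streaming a serde value through zstd with bincode. Below is a minimal sketch of that round trip, assuming the `serde`, `serde_bytes`, `bincode` 1.x, and `zstd` crates; the `ToySnapshot` type and the file name are placeholders, not llama-rs types.

```rust
use std::{
    fs::File,
    io::{BufReader, BufWriter},
};

// Placeholder stand-in for an inference snapshot; not a llama-rs type.
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
struct ToySnapshot {
    npast: usize,
    #[serde(with = "serde_bytes")]
    memory: Vec<u8>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let snapshot = ToySnapshot { npast: 3, memory: vec![0u8; 1024] };

    // Write: bincode serializes straight into the zstd encoder; auto_finish()
    // finalizes the zstd frame when the writer is dropped.
    let encoder = zstd::stream::write::Encoder::new(
        BufWriter::new(File::create("session.zst")?),
        1, // compression level, mirroring SNAPSHOT_COMPRESSION_LEVEL above
    )?;
    bincode::serialize_into(encoder.auto_finish(), &snapshot)?;

    // Read: the decoder is just another `Read`, so bincode can stream from it.
    let decoder =
        zstd::stream::read::Decoder::new(BufReader::new(File::open("session.zst")?))?;
    let restored: ToySnapshot = bincode::deserialize_from(decoder)?;
    assert_eq!(snapshot, restored);
    Ok(())
}
```

The real code differs mainly in that it exits via `unwrap_or_exit` instead of propagating errors, and that `get_snapshot` is unsafe because the snapshot borrows the session's memory buffers, which is why `write_session` consumes the session.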
8 changes: 4 additions & 4 deletions llama-rs/Cargo.toml
@@ -9,16 +9,16 @@ rust-version = "1.65"
[dependencies]
ggml = { path = "../ggml" }

rand = { workspace = true }

bytemuck = "1.13.1"
partial_sort = "0.2.0"
thiserror = "1.0"
rand = { workspace = true }
serde = { version = "1.0.156", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
bincode = "1.3.3"

# Used for the `convert` feature
serde_json = { version = "1.0.94", optional = true }
serde_json = { version = "1.0", optional = true }
protobuf = { version = "= 2.14.0", optional = true }
rust_tokenizers = { version = "3.1.2", optional = true }

51 changes: 27 additions & 24 deletions llama-rs/src/lib.rs
@@ -133,8 +133,13 @@ impl Clone for InferenceSession {
}

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process. Can be saved to disk.
// Keep in sync with [InferenceSession] and [InferenceSnapshot]
/// A serializable snapshot of the inference process.
/// Can be created by calling [InferenceSession::get_snapshot].
///
/// If serializing, ensure that your serializer is binary-efficient.
/// This type contains a large array of bytes; traditional textual serializers
/// are likely to serialize this as an array of numbers at extreme cost.
// Keep in sync with [InferenceSession] and [InferenceSnapshot].
pub struct InferenceSnapshotRef<'a> {
/// How many tokens have been stored in the memory so far.
pub npast: usize,
@@ -151,11 +156,26 @@ pub struct InferenceSnapshotRef<'a> {
#[serde(with = "serde_bytes")]
pub memory_v: &'a [u8],
}
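
The new doc comment's warning about textual serializers is easy to see with a stand-in struct (not the real `InferenceSnapshotRef`): with `#[serde(with = "serde_bytes")]`, bincode writes the buffer as a length prefix plus raw bytes, while serde_json still has to spell every byte out as a decimal number in an array. A rough sketch:

```rust
// Stand-in only; the field mirrors the `serde_bytes`-annotated buffers above.
#[derive(serde::Serialize)]
struct FakeSnapshot<'a> {
    #[serde(with = "serde_bytes")]
    memory_k: &'a [u8],
}

fn main() {
    let buf = vec![0u8; 4096];
    let snap = FakeSnapshot { memory_k: buf.as_slice() };

    // Binary: a small length prefix followed by the 4096 raw bytes.
    let binary = bincode::serialize(&snap).unwrap();
    // Textual: a JSON array containing 4096 decimal numbers plus punctuation.
    let textual = serde_json::to_string(&snap).unwrap();

    println!("bincode: {} bytes, json: {} bytes", binary.len(), textual.len());
}
```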
impl InferenceSnapshotRef<'_> {
/// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef].
///
/// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types.
pub fn to_owned(&self) -> InferenceSnapshot {
InferenceSnapshot {
npast: self.npast,
session_params: self.session_params,
tokens: self.tokens.clone(),
last_logits: self.logits.clone(),
memory_k: self.memory_k.to_vec(),
memory_v: self.memory_v.to_vec(),
}
}
}
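
The blanket-impl remark deserves a word: `std::borrow::ToOwned` is implemented for every `Clone` type with `Owned = Self`, so a `Clone` borrowed view cannot also implement `ToOwned` with a different owned target. A hypothetical miniature of the same pattern (the names are made up, not llama-rs types):

```rust
// A Clone-able borrowed view whose owned counterpart is a *different* type.
// `impl ToOwned for ViewRef` would conflict with the standard library's
// blanket `impl<T: Clone> ToOwned for T`, which already covers `ViewRef`
// and fixes its `Owned` type to `ViewRef` itself.
#[derive(Clone)]
struct ViewRef<'a> {
    bytes: &'a [u8],
}

struct ViewOwned {
    bytes: Vec<u8>,
}

impl ViewRef<'_> {
    // Inherent method, so it takes precedence over the trait's `to_owned`.
    fn to_owned(&self) -> ViewOwned {
        ViewOwned { bytes: self.bytes.to_vec() }
    }
}

fn main() {
    let data = vec![1u8, 2, 3];
    let owned: ViewOwned = ViewRef { bytes: data.as_slice() }.to_owned();
    assert_eq!(owned.bytes, data);
}
```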

/// A serializable snapshot of the inference process. Can be restored by calling
/// `Model::restore_from_snapshot`.
/// [Model::session_from_snapshot].
#[derive(serde::Deserialize, Clone, PartialEq)]
// Keep in sync with [InferenceSession] and [InferenceSnapshotRef]
// Keep in sync with [InferenceSession] and [InferenceSnapshotRef].
pub struct InferenceSnapshot {
/// How many tokens have been stored in the memory so far.
pub npast: usize,
@@ -515,9 +535,6 @@ pub enum SnapshotError {
/// Arbitrary I/O error.
#[error("I/O error while reading or writing snapshot")]
IO(#[from] std::io::Error),
/// Error during the serialization process.
#[error("error during snapshot serialization")]
Serialization(#[from] bincode::Error),
/// Mismatch between the snapshotted memory and the in-memory memory.
#[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")]
MemorySizeMismatch {
@@ -551,10 +568,10 @@ pub enum InferenceError {
#[derive(Default, Debug, Clone)]
pub struct EvaluateOutputRequest {
/// Returns all the logits for the provided batch of tokens.
/// Output shape is n_batch * n_vocab
/// Output shape is `n_batch * n_vocab`.
pub all_logits: Option<Vec<f32>>,
/// Returns the embeddings for the provided batch of tokens
/// Output shape is n_batch * n_embd
/// Output shape is `n_batch * n_embd`.
pub embeddings: Option<Vec<f32>>,
}
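
A quick note on consuming `all_logits`: because it comes back flattened, indexing is done by offset. A sketch, assuming the row-major layout the shape comment suggests (token `i`'s logits at `[i * n_vocab .. (i + 1) * n_vocab]`); the helper name is made up:

```rust
// Hypothetical helper; assumes row-major layout of the flattened buffer.
fn logits_for_token(all_logits: &[f32], n_vocab: usize, token_index: usize) -> &[f32] {
    let start = token_index * n_vocab;
    &all_logits[start..start + n_vocab]
}

fn main() {
    let n_vocab = 4;
    // Two batch tokens' worth of fake logits.
    let all_logits = vec![0.1f32, 0.2, 0.3, 0.4, 1.1, 1.2, 1.3, 1.4];
    assert_eq!(logits_for_token(&all_logits, n_vocab, 1), &all_logits[4..8]);
}
```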

@@ -1387,7 +1404,7 @@ impl Model {
session.n_past += input_tokens.len();
}

/// Hydrates a previously obtained InferenceSnapshot for this model
/// Hydrates a previously obtained InferenceSnapshot for this model.
pub fn session_from_snapshot(
&self,
snapshot: InferenceSnapshot,
@@ -1665,20 +1682,6 @@ impl InferenceSession {
}
}

impl<'a> InferenceSnapshotRef<'a> {
/// Write this snapshot to the given writer.
pub fn write(&self, writer: &mut impl std::io::Write) -> Result<(), SnapshotError> {
Ok(bincode::serialize_into(writer, &self)?)
}
}

impl InferenceSnapshot {
/// Read a snapshot from the given reader.
pub fn read(reader: &mut impl std::io::Read) -> Result<Self, SnapshotError> {
Ok(bincode::deserialize_from(reader)?)
}
}

impl Vocabulary {
// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
/// Tokenize a `text` with this vocabulary.
