forked from rustformers/llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request rustformers#123 from philpax/remove-bincode
refactor(llama): remove bincode
- Loading branch information
Showing
6 changed files
with
94 additions
and
90 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,68 @@ | ||
use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError}; | ||
use llama_rs::{InferenceSession, InferenceSessionParameters, Model}; | ||
use std::{ | ||
error::Error, | ||
fs::File, | ||
io::{BufReader, BufWriter}, | ||
path::Path, | ||
}; | ||
use zstd::zstd_safe::CompressionLevel; | ||
use zstd::{ | ||
stream::{read::Decoder, write::Encoder}, | ||
zstd_safe::CompressionLevel, | ||
}; | ||
|
||
const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1; | ||
|
||
pub fn load_from_disk(path: impl AsRef<Path>) -> Result<InferenceSnapshot, SnapshotError> { | ||
let mut reader = zstd::stream::read::Decoder::new(BufReader::new(File::open(path.as_ref())?))?; | ||
InferenceSnapshot::read(&mut reader) | ||
pub fn read_or_create_session( | ||
model: &Model, | ||
persist_session: Option<&Path>, | ||
load_session: Option<&Path>, | ||
inference_session_params: InferenceSessionParameters, | ||
) -> (InferenceSession, bool) { | ||
fn load(model: &Model, path: &Path) -> InferenceSession { | ||
let file = unwrap_or_exit(File::open(path), || format!("Could not open file {path:?}")); | ||
let decoder = unwrap_or_exit(Decoder::new(BufReader::new(file)), || { | ||
format!("Could not create decoder for {path:?}") | ||
}); | ||
let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || { | ||
format!("Could not deserialize inference session from {path:?}") | ||
}); | ||
let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || { | ||
format!("Could not convert snapshot from {path:?} to session") | ||
}); | ||
log::info!("Loaded inference session from {path:?}"); | ||
session | ||
} | ||
|
||
match (persist_session, load_session) { | ||
(Some(path), _) if path.exists() => (load(model, path), true), | ||
(_, Some(path)) => (load(model, path), true), | ||
_ => (model.start_session(inference_session_params), false), | ||
} | ||
} | ||
|
||
pub fn write_to_disk( | ||
snap: &InferenceSnapshotRef<'_>, | ||
path: impl AsRef<Path>, | ||
) -> Result<(), SnapshotError> { | ||
let mut writer = zstd::stream::write::Encoder::new( | ||
BufWriter::new(File::create(path.as_ref())?), | ||
SNAPSHOT_COMPRESSION_LEVEL, | ||
)? | ||
.auto_finish(); | ||
pub fn write_session(mut session: llama_rs::InferenceSession, path: &Path) { | ||
// SAFETY: the session is consumed here, so nothing else can access it. | ||
let snapshot = unsafe { session.get_snapshot() }; | ||
let file = unwrap_or_exit(File::create(path), || { | ||
format!("Could not create file {path:?}") | ||
}); | ||
let encoder = unwrap_or_exit( | ||
Encoder::new(BufWriter::new(file), SNAPSHOT_COMPRESSION_LEVEL), | ||
|| format!("Could not create encoder for {path:?}"), | ||
); | ||
unwrap_or_exit( | ||
bincode::serialize_into(encoder.auto_finish(), &snapshot), | ||
|| format!("Could not serialize inference session to {path:?}"), | ||
); | ||
log::info!("Successfully wrote session to {path:?}"); | ||
} | ||
|
||
snap.write(&mut writer) | ||
fn unwrap_or_exit<T, E: Error>(result: Result<T, E>, error_message: impl Fn() -> String) -> T { | ||
match result { | ||
Ok(t) => t, | ||
Err(err) => { | ||
log::error!("{}. Error: {err}", error_message()); | ||
std::process::exit(1); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters