Commit

Use HF tokenizer when vocab file is provided

RedBoxing committed May 23, 2023
1 parent 861d694 commit 5a423b1
Showing 13 changed files with 1,614 additions and 88 deletions.
1,461 changes: 1,435 additions & 26 deletions Cargo.lock

Large diffs are not rendered by default.
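
In short: the commit threads an optional vocab path from the CLI through `llm::load`, `Loader::new`, and `Vocabulary`. When a path is given, tokenization and decoding go through a Hugging Face `tokenizers::Tokenizer`; otherwise the vocabulary embedded in the GGML file is used as before. A minimal sketch of the updated call site, assuming the feature-gated `llm::models::Llama` re-export and placeholder paths:

use std::path::Path;

fn main() {
    // `Some(...)` switches token handling to the HF tokenizer; `None`
    // keeps the vocabulary embedded in the GGML file.
    let _model = llm::load::<llm::models::Llama>(
        Path::new("./model.bin"),            // placeholder model path
        Some(Path::new("./tokenizer.json")), // placeholder vocab path
        Default::default(),                  // default ModelParameters
        None,                                // no overrides
        |_progress| {},                      // quiet progress callback
    )
    .expect("failed to load model");
}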

5 changes: 5 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
@@ -332,6 +332,10 @@ pub struct ModelLoad {
#[arg(long, short = 'm')]
pub model_path: PathBuf,

/// Where to load the vocab file from; if set, a Hugging Face tokenizer is used instead of the GGML vocabulary
#[arg(long, short = 'v')]
pub vocab_path: Option<PathBuf>,

/// Sets the size of the context (in tokens). Allows feeding longer prompts.
/// Note that this affects memory.
///
@@ -376,6 +380,7 @@ impl ModelLoad {

let model = llm::load::<M>(
&self.model_path,
self.vocab_path.as_deref(),
params,
overrides,
|progress| match progress {
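Since clap derives kebab-case long flags from field names, the new option should be usable as, hypothetically, `--vocab-path ./tokenizer.json` (or `-v`) alongside the existing `--model-path`/`-m`; only the short flags and field names above appear in the diff, the long forms are assumed.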
2 changes: 1 addition & 1 deletion binaries/llm-cli/src/main.rs
@@ -149,7 +149,7 @@ fn perplexity<M: llm::KnownModel + 'static>(
fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
let file = File::open(&args.model_path)?;
let mut reader = BufReader::new(&file);
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(|_| {
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(None, |_| {
// We purposely do not print progress here, as we are only interested in the metadata
});

1 change: 1 addition & 0 deletions crates/llm-base/Cargo.toml
@@ -22,3 +22,4 @@ partial_sort = "0.2.0"
serde_bytes = "0.11"
memmap2 = "0.5.10"
half = "2.2.1"
tokenizers = "0.13.3"
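The new `tokenizers` dependency (the Hugging Face tokenizer implementation) is what accounts for the large Cargo.lock delta above.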
8 changes: 4 additions & 4 deletions crates/llm-base/src/inference_session.rs
@@ -94,7 +94,7 @@ impl InferenceSession {
if should_call_callback {
// NOTE: No string ever tokenizes to the end of sentence. So we
// can just return the id here.
match callback(vocab.token(tk as usize)) {
match callback(&vocab.token(tk as usize)) {
Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
Ok(f) => match f {
InferenceFeedback::Continue => (),
@@ -118,7 +118,7 @@
params: &InferenceParameters,
output_request: &mut OutputRequest,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
) -> Result<Vec<u8>, InferenceError> {
if self.n_past + 1 >= model.context_size() {
return Err(InferenceError::ContextFull);
}
@@ -163,7 +163,7 @@
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
token_utf8_buf.push(&model.vocabulary().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Some(Box::new(e))));
@@ -204,7 +204,7 @@
};

// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(token) {
if let Some(tokens) = token_utf8_buf.push(&token) {
match callback(InferenceResponse::InferredToken(tokens)) {
Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))),
Ok(f) => match f {
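These signature changes follow from the tokenizer change: with an HF tokenizer, token text is produced on the fly by `Tokenizer::decode` rather than stored in the vocabulary table, so these call sites take and return owned `Vec<u8>` values instead of slices borrowed from the vocabulary.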
42 changes: 38 additions & 4 deletions crates/llm-base/src/loader.rs
@@ -18,6 +18,8 @@ use ggml::{
use memmap2::Mmap;
use thiserror::Error;

use tokenizers::Tokenizer;

#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
/// Information about the file.
pub struct FileType {
@@ -280,6 +282,15 @@ pub enum LoadError {
/// The paths that were found.
paths: Vec<PathBuf>,
},

/// The vocab file for the tokenizer could not be loaded.
#[error("could not load vocab file {path:?}")]
VocabLoadError {
/// The path that failed.
path: PathBuf,
},
}
impl From<util::FindAllModelFilesError> for LoadError {
fn from(value: util::FindAllModelFilesError) -> Self {
@@ -343,6 +354,7 @@ pub trait TensorLoader<E: std::error::Error> {
/// store any information about the architecture.
pub fn load<M: KnownModel>(
path: &Path,
vocab_path: Option<&Path>,
params: ModelParameters,
overrides: Option<M::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
@@ -364,7 +376,29 @@
})?;
let mut reader = BufReader::new(&file);

let mut loader = Loader::new(load_progress_callback);
let tokenizer = if let Some(path) = vocab_path {
// If the path is not an existing local file but looks like a Hugging Face
// repository id (exactly one '/'), fetch the tokenizer from the Hub;
// otherwise load it from disk.
let tok = if !path.exists() && path.to_str().unwrap().matches('/').count() == 1 {
Tokenizer::from_pretrained(path.to_str().unwrap(), None)
} else if path.exists() && path.is_file() {
Tokenizer::from_file(path)
} else {
return Err(LoadError::VocabLoadError {
path: path.to_owned(),
});
};

Some(tok.map_err(|_| LoadError::VocabLoadError {
path: path.to_owned(),
})?)
} else {
None
};

let mut loader = Loader::new(tokenizer, load_progress_callback);

ggml::format::load(&mut reader, &mut loader)
.map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;
@@ -422,7 +456,7 @@
let mut lora_reader = BufReader::new(&lora_file);
// TODO: Consider updating the progress callback to report the progress of the LoRA file.
// Most LoRAs are small enough that this is not necessary, but it would be nice to have.
let mut lora_loader: Loader<LoraParameters, _> = Loader::new(|_| {});
let mut lora_loader: Loader<LoraParameters, _> = Loader::new(None, |_| {});
ggml::format::load(&mut lora_reader, &mut lora_loader)
.map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?;

@@ -498,13 +532,13 @@ pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
/// Creates a new loader.
pub fn new(load_progress_callback: F) -> Self {
pub fn new(tokenizer: Option<Tokenizer>, load_progress_callback: F) -> Self {
Self {
load_progress_callback,

container_type: ContainerType::Ggml,
hyperparameters: Hp::default(),
vocabulary: Vocabulary::default(),
vocabulary: Vocabulary::new(tokenizer),
tensors: HashMap::default(),
}
}
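The branch above accepts either an on-disk tokenizer file or something shaped like a Hub repository id (no local file, exactly one `/`). A sketch of the two sources, using the same `tokenizers` calls as the loader (identifiers and paths are placeholders):

use tokenizers::Tokenizer;

// 1) A local tokenizer.json, e.g. one exported by `transformers`:
let from_file = Tokenizer::from_file("./tokenizer.json")
    .expect("failed to read local tokenizer file");

// 2) A Hugging Face Hub repository id, fetched over HTTP:
let from_hub = Tokenizer::from_pretrained("some-org/some-model", None)
    .expect("failed to fetch tokenizer from the Hub");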
7 changes: 4 additions & 3 deletions crates/llm-base/src/model/mod.rs
@@ -114,14 +114,15 @@ pub trait KnownModel: Send + Sync {
/// is a helper function on top of [llm_base::load](crate::load).
fn load(
path: &Path,
vocab_path: Option<&Path>,
params: ModelParameters,
overrides: Option<Self::Overrides>,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<Self, LoadError>
where
Self: Sized,
{
crate::load(path, params, overrides, load_progress_callback)
crate::load(path, vocab_path, params, overrides, load_progress_callback)
}

/// Creates a new model from the provided [ModelParameters] hyperparameters.
@@ -151,7 +152,7 @@ pub trait KnownModel: Send + Sync {
output_request: &mut OutputRequest,
);

/// Get the vocabulary (loaded from the GGML file) for this model.
/// Get the vocabulary for this model.
fn vocabulary(&self) -> &Vocabulary;

/// Get the context size (configured with [ModelParameters::context_size]) used by
@@ -188,7 +189,7 @@ pub trait Model: Send + Sync {
output_request: &mut OutputRequest,
);

/// Get the vocabulary (loaded from the GGML file) for this model.
/// Get the vocabulary for this model.
fn vocabulary(&self) -> &Vocabulary;

/// Get the context size (configured with [ModelParameters::context_size]) used by
2 changes: 1 addition & 1 deletion crates/llm-base/src/quantize.rs
@@ -151,7 +151,7 @@ pub fn quantize<M: KnownModel, R: BufRead + Seek, W: Write + Seek>(
// Load the model
let progress_callback = Arc::new(progress_callback);

let mut loader = Loader::<M::Hyperparameters, _>::new({
let mut loader = Loader::<M::Hyperparameters, _>::new(None, {
let progress_callback = progress_callback.clone();
move |p| {
if let LoadProgress::HyperparametersLoaded = p {
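As with the LoRA loader in loader.rs, quantization only needs the GGML-side metadata, so it passes `None` rather than an external tokenizer.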
130 changes: 88 additions & 42 deletions crates/llm-base/src/vocabulary.rs
@@ -1,9 +1,10 @@
use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr};

use thiserror::Error;
use tokenizers::Tokenizer;

/// The identifier of a token in a vocabulary.
pub type TokenId = i32;
pub type TokenId = u32;
pub(crate) type Token = Vec<u8>;
pub(crate) type TokenScore = f32;

@@ -34,9 +35,19 @@ pub struct Vocabulary {

/// The longest token in this vocabulary.
pub max_token_length: usize,

/// The optional Hugging Face tokenizer used in place of the GGML vocabulary.
pub tokenizer: Option<Tokenizer>,
}

impl Vocabulary {
/// Initialize a new vocabulary.
pub fn new(tokenizer: Option<Tokenizer>) -> Vocabulary {
Vocabulary {
tokenizer,
..Vocabulary::default()
}
}

/// Add a token to the vocabulary.
///
/// The token added must have `id` directly after the last token in the vocabulary.
@@ -45,6 +56,10 @@
/// - This function can panic if `id` does not correspond to the next token in the vocabulary.
/// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`.
pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
if self.tokenizer.is_some() {
// An external tokenizer provides the vocabulary; ignore the GGML entries.
return;
}

// These are loader invariants. If this is broken, then the loader is broken and this is a bug,
// not an issue with the model itself.
assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
@@ -60,17 +75,33 @@
}

/// Converts a token index to the token it represents in this vocabulary.
pub fn token(&self, idx: usize) -> &[u8] {
&self.id_to_token[idx]
pub fn token(&self, idx: usize) -> Vec<u8> {
if let Some(tokenizer) = &self.tokenizer {
return tokenizer
.decode(vec![idx as u32], true)
.unwrap()
.as_bytes()
.to_vec();
}

self.id_to_token[idx].clone()
}

/// Returns the number of tokens in the vocabulary.
pub fn len(&self) -> usize {
if let Some(tokenizer) = &self.tokenizer {
return tokenizer.get_vocab_size(false) as usize;
}

self.id_to_token.len()
}

/// Returns whether the vocabulary is empty.
pub fn is_empty(&self) -> bool {
if let Some(tokenizer) = &self.tokenizer {
return tokenizer.get_vocab_size(false) == 0;
}

self.id_to_token.is_empty()
}

@@ -82,53 +113,68 @@
&'a self,
text: &str,
bos: bool,
) -> Result<Vec<(&'a [u8], TokenId)>, TokenizationError> {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let token = self.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
) -> Result<Vec<(Vec<u8>, TokenId)>, TokenizationError> {
if let Some(tokenizer) = &self.tokenizer {
// Encode once, then map each id back through `token()` so callers get
// the same (bytes, id) pairs as the GGML path.
let encoding = tokenizer
.encode(text, bos)
.map_err(|_| TokenizationError::TokenizationFailed)?;

Ok(encoding
.get_ids()
.iter()
.map(|id| (self.token(*id as usize), *id))
.collect::<Vec<(Vec<u8>, TokenId)>>())
} else {
let len = text.len();

let mut score = vec![0usize; len + 1];
let mut prev = vec![TokenId::default(); len + 1];

for i in 0..len {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let token = self.token_to_id.get(sub);

if let Some(token) = token {
let token_score = sub.len() * sub.len();
let local_score = score[i] + token_score;
let next = i + sub_len;

if score[next] < local_score {
score[next] = local_score;
prev[next] = *token;
}
}
}
}
}

// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(TokenizationError::TokenizationFailed);
// Backward pass
let mut res = vec![];
let mut i = len;
while i > 0 {
let token_id = prev[i];
if token_id == 0 {
return Err(TokenizationError::TokenizationFailed);
}
let token = self.id_to_token[token_id as usize].as_slice();
res.push((token.to_vec(), token_id));
i -= token.len();
}
let token = self.id_to_token[token_id as usize].as_slice();
res.push((token, token_id));
i -= token.len();
}

if bos {
// TODO: replace with vocab.bos
res.push((&[], 1));
}
if bos {
// TODO: replace with vocab.bos
res.push((vec![], 1));
}

// Pieces are in reverse order so correct that
res.reverse();
// Pieces are in reverse order so correct that
res.reverse();

Ok(res)
Ok(res)
}
}
}

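Taken together, a tokenizer-backed vocabulary can be exercised directly. A minimal sketch, assuming the method whose body is shown above is `Vocabulary::tokenize` (its name falls outside the hunk) and a `tokenizer` obtained as in loader.rs:

let vocab = Vocabulary::new(Some(tokenizer));
let tokens = vocab
    .tokenize("Hello world", /* bos */ false)
    .expect("tokenization failed");
for (bytes, id) in &tokens {
    println!("{} -> {:?}", id, String::from_utf8_lossy(bytes));
}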
1 change: 1 addition & 0 deletions crates/llm/examples/inference.rs
@@ -24,6 +24,7 @@ fn main() {
let model = llm::load_dynamic(
model_architecture,
model_path,
None,
Default::default(),
overrides,
load_callback,
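Passing `Some(vocab_path)` here instead of `None` would route the example through the HF tokenizer as well.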
(Diffs for the remaining changed files are not shown.)
