
Commit

Merge pull request rustformers#47 from bcho/stats
feat: record inference stats
setzer22 authored Mar 20, 2023
2 parents a9b289c + 3ff2bf7 commit d8ca18d
Showing 2 changed files with 46 additions and 3 deletions.
llama-cli/src/main.rs (4 changes: 3 additions & 1 deletion)
@@ -169,7 +169,9 @@ fn main() {
     println!();

     match res {
-        Ok(_) => (),
+        Ok(stats) => {
+            println!("{}", stats);
+        }
         Err(llama_rs::InferenceError::ContextFull) => {
             log::warn!("Context window full, stopping inference.")
         }
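For illustration only (not part of this commit): assuming the llama_rs crate with the InferenceStats type and Display impl added in llama-rs/src/lib.rs below, the println!("{}", stats) call above produces output along these lines. The field values here are made up.

// Hypothetical illustration; assumes the patched llama_rs crate is available.
use std::time::Duration;

fn main() {
    let stats = llama_rs::InferenceStats {
        feed_prompt_duration: Duration::from_millis(420),
        prompt_tokens: 12,
        predict_duration: Duration::from_millis(3600),
        predict_tokens: 48,
    };
    // Prints:
    // feed_prompt_duration: 420ms
    // prompt_tokens: 12
    // predict_duration: 3600ms
    // predict_tokens: 48
    // per_token_duration: 75.000ms
    println!("{}", stats);
}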
llama-rs/src/lib.rs (45 changes: 43 additions & 2 deletions)
@@ -5,6 +5,7 @@ use std::{
     fmt::Display,
     io::{BufRead, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
+    time,
 };

 use thiserror::Error;
@@ -107,6 +108,38 @@ impl Default for InferenceParameters {
     }
 }

+pub struct InferenceStats {
+    pub feed_prompt_duration: std::time::Duration,
+    pub prompt_tokens: usize,
+    pub predict_duration: std::time::Duration,
+    pub predict_tokens: usize,
+}
+
+impl Default for InferenceStats {
+    fn default() -> Self {
+        Self {
+            feed_prompt_duration: std::time::Duration::from_secs(0),
+            prompt_tokens: 0,
+            predict_duration: std::time::Duration::from_secs(0),
+            predict_tokens: 0,
+        }
+    }
+}
+
+impl Display for InferenceStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
+            self.feed_prompt_duration.as_millis(),
+            self.prompt_tokens,
+            self.predict_duration.as_millis(),
+            self.predict_tokens,
+            (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64),
+        )
+    }
+}
+
 type TokenId = i32;
 type Token = String;

@@ -1236,10 +1269,16 @@ impl InferenceSession {
         maximum_token_count: Option<usize>,
         rng: &mut impl rand::Rng,
         callback: impl Fn(OutputToken) -> Result<(), E>,
-    ) -> Result<(), InferenceError> {
+    ) -> Result<InferenceStats, InferenceError> {
+        let mut stats = InferenceStats::default();
+
+        let start_at = time::SystemTime::now();
+
         // Feed the initial prompt through the transformer, to update its
         // context window with new data.
         self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?;
+        stats.feed_prompt_duration = start_at.elapsed().unwrap();
+        stats.prompt_tokens = self.n_past;

         // After the prompt is consumed, sample tokens by repeatedly calling
         // `infer_next_token`. We generate tokens until the model returns an
@@ -1261,8 +1300,10 @@ impl InferenceSession {
                 break;
             }
         }
+        stats.predict_duration = start_at.elapsed().unwrap();
+        stats.predict_tokens = self.n_past;

-        Ok(())
+        Ok(stats)
     }

     /// Obtains a serializable snapshot of the current inference status. This
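A minimal caller-side sketch (an assumption for illustration, not part of this commit) of deriving throughput from the InferenceStats value now returned by inference_with_prompt:

// Hypothetical helper; assumes the llama_rs crate as patched above.
fn tokens_per_second(stats: &llama_rs::InferenceStats) -> f64 {
    let secs = stats.predict_duration.as_secs_f64();
    if secs > 0.0 {
        // As set by this commit, predict_tokens is the session's n_past,
        // which also counts the prompt tokens.
        stats.predict_tokens as f64 / secs
    } else {
        0.0
    }
}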
