Skip to content

Commit

Permalink
Merge pull request rustformers#367 from radu-matei/feat/logging-tracing
Browse files (browse the repository at this point in the history)
feat(tracing): add tracing to `llm` and `llm-base` crates
  • Loading branch information
philpax authored Jul 16, 2023
2 parents 693e6c9 + c344592 commit 0269796
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 43 deletions.
130 changes: 130 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ serde_json = { version = "1.0" }
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1", features = ["log"] }

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand All @@ -50,4 +52,4 @@ inherits = "release"
lto = "thin"

[workspace.metadata.release]
tag-prefix = ""
tag-prefix = ""
3 changes: 3 additions & 0 deletions binaries/llm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ num_cpus = "1.15.0"

color-eyre = { version = "0.6.2", default-features = false }
zstd = { version = "0.12", default-features = false }
tracing-subscriber = {workspace = true }
tracing = { workspace = true}
tracing-appender = "0.2.2"

[dev-dependencies]
rusty-hook = "^0.11.2"
Expand Down
94 changes: 53 additions & 41 deletions binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
convert::Infallible,
fs::File,
io::{BufReader, BufWriter},
io::{BufReader, BufWriter, IsTerminal},
};

use clap::Parser;
Expand All @@ -14,10 +14,12 @@ mod snapshot;
mod util;

fn main() -> eyre::Result<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.parse_default_env()
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_ansi(std::io::stderr().is_terminal())
.init();

color_eyre::install()?;

let args = Args::parse();
Expand All @@ -32,6 +34,7 @@ fn main() -> eyre::Result<()> {
}
}

#[tracing::instrument(skip_all)]
fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let inference_session_config = args.generate.inference_session_config();
Expand All @@ -46,46 +49,55 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
let parameters = args.generate.inference_parameters(model.eot_token_id());

let mut rng = args.generate.rng();
let res = session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: prompt.as_str().into(),
parameters: &parameters,
play_back_previous_tokens: session_loaded,
maximum_token_count: args.generate.num_predict,
},
// OutputRequest
&mut Default::default(),
|r| {
match r {
llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => util::print_token(t),
llm::InferenceResponse::InferredToken(t) => util::print_token(t),
_ => {}

let span = tracing::trace_span!("infer");

span.in_scope(|| {
// Run inference and report its result inside the `infer` trace span.
let res = session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: prompt.as_str().into(),
parameters: &parameters,
play_back_previous_tokens: session_loaded,
maximum_token_count: args.generate.num_predict,
},
// OutputRequest
&mut Default::default(),
|r| {
match r {
llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => {
util::print_token(t)
}
llm::InferenceResponse::InferredToken(t) => util::print_token(t),
_ => {}
}
Ok(llm::InferenceFeedback::Continue)
},
);

println!();

match res {
Ok(stats) => {
if args.stats {
println!();
println!("{}", stats);
println!();
}
}
Ok(llm::InferenceFeedback::Continue)
},
);
println!();

match res {
Ok(stats) => {
if args.stats {
println!();
println!("{}", stats);
println!();
Err(llm::InferenceError::ContextFull) => {
log::warn!("Context window full, stopping inference.")
}
Err(llm::InferenceError::TokenizationFailed(err)) => {
log::error!("A tokenization-related failure occurred: {}", err);
}
Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
unreachable!("cannot fail")
}
}
Err(llm::InferenceError::ContextFull) => {
log::warn!("Context window full, stopping inference.")
}
Err(llm::InferenceError::TokenizationFailed(err)) => {
log::error!("A tokenization-related failure occurred: {}", err);
}
Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
unreachable!("cannot fail")
}
}
});

if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
// Write the memory to the cache file
Expand Down
1 change: 1 addition & 0 deletions crates/llm-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ memmap2 = { workspace = true }
half = "2.2.1"
tokenizers = {version="0.13.3", default-features=false, features=["onig"]}
regex = "1.8"
tracing = { workspace = true }

[features]
tokenizers-remote = ["tokenizers/http"]
Expand Down
Loading

0 comments on commit 0269796

Please sign in to comment.