forked from rustformers/llm
inference.rs
66 lines (57 loc) · 1.98 KB
use llm::{
    load_progress_callback_stdout as load_callback, InferenceFeedback, InferenceRequest,
    InferenceResponse, ModelArchitecture,
};
use std::{convert::Infallible, io::Write, path::Path};

fn main() {
    // CLI arguments: <model_architecture> <model_path> [prompt] [overrides, json]
    let raw_args: Vec<String> = std::env::args().skip(1).collect();
    if raw_args.len() < 2 {
        println!("Usage: cargo run --release --example inference <model_architecture> <model_path> [prompt] [overrides, json]");
        std::process::exit(1);
    }

    let model_architecture: ModelArchitecture = raw_args[0].parse().unwrap();
    let model_path = Path::new(&raw_args[1]);
    let prompt = raw_args
        .get(2)
        .map(|s| s.as_str())
        .unwrap_or("Rust is a cool programming language because");
    let overrides = raw_args.get(3).map(|s| serde_json::from_str(s).unwrap());

    let now = std::time::Instant::now();

    // Load the model for the requested architecture, reporting load progress to stdout.
    let model = llm::load_dynamic(
        model_architecture,
        model_path,
        Default::default(),
        overrides,
        load_callback,
    )
    .unwrap_or_else(|err| {
        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
    });

    println!(
        "Model fully loaded! Elapsed: {}ms",
        now.elapsed().as_millis()
    );

    // Start an inference session with default parameters and feed it the prompt.
    let mut session = model.start_session(Default::default());

    let res = session.infer::<Infallible>(
        model.as_ref(),
        &mut rand::thread_rng(),
        &InferenceRequest {
            prompt: prompt.into(),
            ..Default::default()
        },
        // OutputRequest
        &mut Default::default(),
        // Stream both prompt tokens and newly inferred tokens straight to stdout.
        |r| match r {
            InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
                print!("{t}");
                std::io::stdout().flush().unwrap();

                Ok(InferenceFeedback::Continue)
            }
            _ => Ok(InferenceFeedback::Continue),
        },
    );

    match res {
        Ok(result) => println!("\n\nInference stats:\n{result}"),
        Err(err) => println!("\n{err}"),
    }
}
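
// Example invocation (sketch only: the model path is a placeholder, and the
// architecture and prompt are illustrative):
//
//   cargo run --release --example inference llama ./models/open-llama.ggml "Rust is great because"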