vicuna-chat.rs (forked from rustformers/llm)

use clap::Parser;
use llm_base::conversation_inference_callback;
use rustyline::error::ReadlineError;
use std::{convert::Infallible, io::Write, path::PathBuf};
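
// Command-line arguments: the model architecture and model path are positional; a
// tokenizer may optionally be supplied either as a local HuggingFace tokenizer file
// (`--tokenizer-path`/`-v`) or as a remote HuggingFace repository (`--tokenizer-repository`/`-r`).
// Invocation sketch, assuming the repository's usual `cargo run --example` layout:
//   cargo run --release --example vicuna-chat -- <ARCHITECTURE> <MODEL_PATH>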
#[derive(Parser)]
struct Args {
    model_architecture: llm::ModelArchitecture,
    model_path: PathBuf,
    #[arg(long, short = 'v')]
    pub tokenizer_path: Option<PathBuf>,
    #[arg(long, short = 'r')]
    pub tokenizer_repository: Option<String>,
}

impl Args {
    pub fn to_tokenizer_source(&self) -> llm::TokenizerSource {
        match (&self.tokenizer_path, &self.tokenizer_repository) {
            (Some(_), Some(_)) => {
                panic!("Cannot specify both --tokenizer-path and --tokenizer-repository");
            }
            (Some(path), None) => llm::TokenizerSource::HuggingFaceTokenizerFile(path.to_owned()),
            (None, Some(repo)) => llm::TokenizerSource::HuggingFaceRemote(repo.to_owned()),
            (None, None) => llm::TokenizerSource::Embedded,
        }
    }
}
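
// Load the model, prime a session with a persona and example exchange, then run an
// interactive chat loop that streams the assistant's replies to stdout.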
fn main() {
    let args = Args::parse();

    let tokenizer_source = args.to_tokenizer_source();
    let model_architecture = args.model_architecture;
    let model_path = args.model_path;
    let model = llm::load_dynamic(
        Some(model_architecture),
        &model_path,
        tokenizer_source,
        Default::default(),
        llm::load_progress_callback_stdout,
    )
    .unwrap_or_else(|err| {
        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
    });
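
    // Start a fresh inference session and seed it with a short persona plus an example
    // exchange in the Vicuna-style "### Human" / "### Assistant" conversation format.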
    let mut session = model.start_session(Default::default());

    let character_name = "### Assistant";
    let user_name = "### Human";
    let persona = "A chat between a human and an assistant.";
    let history = format!(
        "{character_name}: Hello - How may I help you today?\n\
         {user_name}: What is the capital of France?\n\
         {character_name}: Paris is the capital of France."
    );

    let inference_parameters = llm::InferenceParameters::default();

    session
        .feed_prompt(
            model.as_ref(),
            format!("{persona}\n{history}").as_str(),
            &mut Default::default(),
            llm::feed_prompt_callback(|resp| match resp {
                llm::InferenceResponse::PromptToken(t)
                | llm::InferenceResponse::InferredToken(t) => {
                    print_token(t);

                    Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue)
                }
                _ => Ok(llm::InferenceFeedback::Continue),
            }),
        )
        .expect("Failed to ingest initial prompt.");
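
    // Set up a line editor for user input, an RNG for sampling, and an accumulator for
    // inference statistics, then run the interactive chat loop: each user line is fed to
    // the model as the next "### Human" turn and the assistant's reply is streamed back.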
    let mut rl = rustyline::DefaultEditor::new().expect("Failed to create input reader");

    let mut rng = rand::thread_rng();
    let mut res = llm::InferenceStats::default();

    loop {
        println!();
        let readline = rl.readline(format!("{user_name}: ").as_str());
        print!("{character_name}:");
        match readline {
            Ok(line) => {
                let stats = session
                    .infer::<Infallible>(
                        model.as_ref(),
                        &mut rng,
                        &llm::InferenceRequest {
                            prompt: format!("{user_name}: {line}\n{character_name}:")
                                .as_str()
                                .into(),
                            parameters: &inference_parameters,
                            play_back_previous_tokens: false,
                            maximum_token_count: None,
                        },
                        &mut Default::default(),
                        conversation_inference_callback(&format!("{character_name}:"), print_token),
                    )
                    .unwrap_or_else(|e| panic!("{e}"));

                // Accumulate this turn's statistics into the running totals.
                res.feed_prompt_duration = res
                    .feed_prompt_duration
                    .saturating_add(stats.feed_prompt_duration);
                res.prompt_tokens += stats.prompt_tokens;
                res.predict_duration = res.predict_duration.saturating_add(stats.predict_duration);
                res.predict_tokens += stats.predict_tokens;
            }
            // Ctrl-D or Ctrl-C ends the conversation.
            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
                break;
            }
            Err(err) => {
                println!("{err}");
            }
        }
    }

    println!("\n\nInference stats:\n{res}");
}
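
// Print a token to stdout and flush immediately so replies appear as they stream.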
fn print_token(t: String) {
    print!("{t}");
    std::io::stdout().flush().unwrap();
}