Commit d7dc4dc

feat(xtask): fill in the usage field in responses
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 83c51f4 commit d7dc4dc

5 files changed: +73 −55 lines changed

xtask/src/service/blacklist_integration_test.rs

Lines changed: 3 additions & 8 deletions
@@ -282,9 +282,7 @@ fn test_blacklist_configuration() {
         let lower_word = word.to_lowercase();
         assert!(
             blacklist.iter().any(|bw| bw == &lower_word),
-            "Blacklist should contain '{}' (lowercase: '{}')",
-            word,
-            lower_word
+            "Blacklist should contain '{word}' (lowercase: '{lower_word}')"
         );
     }

@@ -295,10 +293,7 @@ fn test_blacklist_configuration() {
     let max_length = word_lengths.iter().max().unwrap();
     let min_length = word_lengths.iter().min().unwrap();

-    info!(
-        "Blacklist word length range: {} to {} characters",
-        min_length, max_length
-    );
+    info!("Blacklist word length range: {min_length} to {max_length} characters");
     assert!(
         *max_length > 10,
         "Should have words longer than 10 characters for suffix optimization test"
@@ -378,7 +373,7 @@ blacklist = [
 "#;

     info!("Example TOML configuration:");
-    info!("{}", toml_config);
+    info!("{toml_config}");

     // In practice, you'd parse this with:
     // let config: ModelConfig = toml::from_str(toml_config).unwrap();
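
Note: the hunks above are a mechanical migration from positional {} placeholders with trailing arguments to inlined format arguments, stable since Rust 1.58. A minimal standalone sketch of the two equivalent forms, with an illustrative value not taken from the repository:

fn main() {
    let word = "Forbidden";
    let lower_word = word.to_lowercase();
    // Old style: positional placeholders with trailing arguments.
    println!("Blacklist should contain '{}' (lowercase: '{}')", word, lower_word);
    // New style (Rust 1.58+): identifiers are captured inside the braces.
    println!("Blacklist should contain '{word}' (lowercase: '{lower_word}')");
}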

xtask/src/service/client.rs

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ fn test_blacklisted_check() {
     let normal_prompt = "Tell me a story about a cat";
     let req_body_normal = requset_body_chat(normal_prompt);

-    info!("Sending normal request: {}", normal_prompt);
+    info!("Sending normal request: {normal_prompt}");
     let normal_result =
         send_single_request(port, &client, &headers, req_body_normal, Some(1)).await;

xtask/src/service/mod.rs

Lines changed: 17 additions & 19 deletions
@@ -8,9 +8,7 @@ mod response;
 use crate::{
     parse_gpus,
     service::{
-        openai::{
-            chat_completion_response, chat_completion_response_stream, create_completion_response,
-        },
+        openai::{chat_completion_response, chat_completion_response_stream, completion_response},
         response::text_stream,
     },
 };
@@ -226,16 +224,14 @@ impl HyperService<Request<Incoming>> for App {
                 return Ok(text_stream(UnboundedReceiverStream::new(receiver).map(
                     move |output| {
                         let response = match output {
-                            model::Output::Text { content, .. } => {
-                                create_completion_response(
-                                    id,
-                                    created,
-                                    model_name.clone(),
-                                    content,
-                                    None,
-                                )
-                            }
-                            model::Output::Finish(reason) => create_completion_response(
+                            model::Output::Text { content, .. } => completion_response(
+                                id,
+                                created,
+                                model_name.clone(),
+                                content,
+                                None,
+                            ),
+                            model::Output::Finish { reason, .. } => completion_response(
                                 id,
                                 created,
                                 model_name.clone(),
@@ -257,14 +253,13 @@ impl HyperService<Request<Incoming>> for App {
                         think_.push_str(&think);
                         content_.push_str(&content);
                     }
-                    model::Output::Finish(reason) => {
+                    model::Output::Finish { reason, .. } => {
                         assert!(reason_.replace(reason).is_none())
                     }
                 }
             }

-            let response =
-                create_completion_response(id, created, model_name, content_, reason_);
+            let response = completion_response(id, created, model_name, content_, reason_);
             Ok(json(response))
         })
     }
@@ -314,7 +309,7 @@ impl HyperService<Request<Incoming>> for App {
                                     None,
                                 )
                             }
-                            model::Output::Finish(reason) => {
+                            model::Output::Finish { reason, .. } => {
                                 chat_completion_response_stream(
                                     id,
                                     created,
@@ -333,14 +328,16 @@ impl HyperService<Request<Incoming>> for App {
             let mut think_ = String::new();
             let mut content_ = String::new();
             let mut reason_ = None;
+            let mut num_tokens_ = [0, 0];
             while let Some(output) = receiver.recv().await {
                 match output {
                     model::Output::Text { think, content } => {
                         think_.push_str(&think);
                         content_.push_str(&content);
                     }
-                    model::Output::Finish(reason) => {
-                        assert!(reason_.replace(reason).is_none())
+                    model::Output::Finish { reason, num_tokens } => {
+                        assert!(reason_.replace(reason).is_none());
+                        num_tokens_ = num_tokens
                     }
                 }
             }
@@ -351,6 +348,7 @@ impl HyperService<Request<Incoming>> for App {
                 model_name,
                 Some(think_).filter(|s| !s.is_empty()),
                 Some(content_).filter(|s| !s.is_empty()),
+                num_tokens_,
                 reason_,
             );
             Ok(json(response))
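
Note: the non-streaming chat path above ends with Some(think_).filter(|s| !s.is_empty()), which collapses an empty accumulator to None, presumably so empty think/content fields can be skipped when the response is serialized. The idiom in isolation, as a standalone sketch:

fn non_empty(s: String) -> Option<String> {
    // Wrap the string, then drop it again when the predicate fails.
    Some(s).filter(|s| !s.is_empty())
}

fn main() {
    assert_eq!(non_empty(String::new()), None);
    assert_eq!(non_empty("trace".to_string()), Some("trace".to_string()));
}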

xtask/src/service/model.rs

Lines changed: 38 additions & 22 deletions
@@ -29,18 +29,38 @@ pub(super) struct Model {
 }

 pub(super) enum Output {
-    Text { think: String, content: String },
-    Finish(FinishReason),
+    Text {
+        think: String,
+        content: String,
+    },
+    Finish {
+        reason: FinishReason,
+        num_tokens: [usize; 2],
+    },
 }

 struct SessionInfo {
     sender: UnboundedSender<Output>,
     buf: TextBuf,
     think: bool,
     tokens: Vec<utok>,
+    prompt_tokens: usize,
     accumulated_content: String, // Track all generated content for blacklist detection
 }

+impl SessionInfo {
+    fn new(sender: UnboundedSender<Output>, tokens: Vec<utok>) -> Self {
+        Self {
+            buf: TextBuf::new(),
+            think: false,
+            prompt_tokens: tokens.len(),
+            accumulated_content: String::new(),
+            sender,
+            tokens,
+        }
+    }
+}
+
 impl Model {
     pub fn new(config: ModelConfig, use_cuda_graph: bool) -> (Self, Service) {
         let ModelConfig {
@@ -181,7 +201,10 @@ impl Model {
                 // Send finish signal
                 if session_info
                     .sender
-                    .send(Output::Finish(FinishReason::Stop))
+                    .send(Output::Finish {
+                        reason: FinishReason::Stop,
+                        num_tokens: [session_info.prompt_tokens, session_info.tokens.len()],
+                    })
                     .is_err()
                 {
                     info!("{session_id:?} client connection closed");
@@ -202,8 +225,13 @@ impl Model {
                 // Handle session termination
                 if !sessions.is_empty() {
                     for (session, reason) in sessions {
-                        let SessionInfo { tokens, sender, .. } =
-                            sessions_guard.remove(&session.id).unwrap();
+                        let SessionInfo {
+                            tokens,
+                            sender,
+                            prompt_tokens,
+                            ..
+                        } = sessions_guard.remove(&session.id).unwrap();
+                        let num_tokens = [prompt_tokens, tokens.len()];
                         let reason = match reason {
                             ReturnReason::Finish => {
                                 // Completed normally; insert back into the cache
@@ -221,7 +249,7 @@ impl Model {
                         };

                         sender
-                            .send(Output::Finish(reason))
+                            .send(Output::Finish { reason, num_tokens })
                             .unwrap_or_else(|_| info!("{:?} failed to send normal completion", session.id));
                     }
                 }
@@ -298,18 +326,12 @@ impl Model {
            max_tokens,
        );

-        let session_info = SessionInfo {
-            sender,
-            tokens,
-            buf: TextBuf::new(),
-            think: false,
-            accumulated_content: String::new(),
-        };
+        let session_info = SessionInfo::new(sender, tokens);
         assert!(
             self.sessions
                 .lock()
                 .unwrap()
-                .insert(id, session_info,)
+                .insert(id, session_info)
                 .is_none()
         );

@@ -360,18 +382,12 @@ impl Model {
            max_tokens,
        );

-        let session_info = SessionInfo {
-            sender,
-            tokens,
-            buf: TextBuf::new(),
-            think: false,
-            accumulated_content: String::new(),
-        };
+        let session_info = SessionInfo::new(sender, tokens);
         assert!(
             self.sessions
                 .lock()
                 .unwrap()
-                .insert(id, session_info,)
+                .insert(id, session_info)
                 .is_none()
         );
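
Note: the accounting hinges on a detail visible in SessionInfo::new: generated tokens are appended to the same tokens vector that initially holds the prompt, so the prompt length must be snapshotted before generation starts, after which [prompt_tokens, tokens.len()] reads as [prompt, total]. A simplified stand-in (Session here is hypothetical, not the crate's type):

struct Session {
    tokens: Vec<u32>,     // prompt tokens, with generated tokens appended in place
    prompt_tokens: usize, // snapshot of tokens.len() taken at construction
}

impl Session {
    fn new(tokens: Vec<u32>) -> Self {
        // The length is read before `tokens` is moved into the struct.
        Self { prompt_tokens: tokens.len(), tokens }
    }

    fn finish(&self) -> [usize; 2] {
        // [prompt, total]; the completion count is the difference.
        [self.prompt_tokens, self.tokens.len()]
    }
}

fn main() {
    let mut s = Session::new(vec![1, 2, 3]); // a 3-token prompt
    s.tokens.extend([4, 5]);                 // the model generates 2 tokens
    let [prompt, total] = s.finish();
    assert_eq!((prompt, total, total - prompt), (3, 5, 2));
}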

xtask/src/service/openai.rs

Lines changed: 14 additions & 5 deletions
@@ -1,9 +1,10 @@
 use hyper::Method;
 use openai_struct::{
-    ChatCompletionResponseMessage, ChatCompletionStreamResponseDelta, CreateChatCompletionResponse,
-    CreateChatCompletionResponseChoices, CreateChatCompletionStreamResponse,
-    CreateChatCompletionStreamResponseChoices, CreateCompletionResponse,
-    CreateCompletionResponseChoices, CreateCompletionResponseLogprobs, FinishReason, Model,
+    ChatCompletionResponseMessage, ChatCompletionStreamResponseDelta, CompletionUsage,
+    CreateChatCompletionResponse, CreateChatCompletionResponseChoices,
+    CreateChatCompletionStreamResponse, CreateChatCompletionStreamResponseChoices,
+    CreateCompletionResponse, CreateCompletionResponseChoices, CreateCompletionResponseLogprobs,
+    FinishReason, Model,
 };
 use serde::Serialize;

@@ -42,6 +43,7 @@ pub(crate) fn chat_completion_response(
     model: String,
     think: Option<String>,
     answer: Option<String>,
+    [prompt_tokens, total_tokens]: [usize; 2],
     finish_reason: Option<FinishReason>,
 ) -> CreateChatCompletionResponse {
     let choices = vec![CreateChatCompletionResponseChoices {
@@ -59,6 +61,13 @@
         model,
         choices,
         created,
+        usage: Some(CompletionUsage {
+            completion_tokens: (total_tokens - prompt_tokens) as _,
+            prompt_tokens: prompt_tokens as _,
+            total_tokens: total_tokens as _,
+            completion_tokens_details: None,
+            prompt_tokens_details: None,
+        }),
         ..Default::default()
     }
 }
@@ -90,7 +99,7 @@ pub(crate) fn chat_completion_response_stream(
     }
 }

-pub(crate) fn create_completion_response(
+pub(crate) fn completion_response(
     id: usize,
     created: i32,
     model: String,
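
Note: two details above are worth calling out: the new parameter [prompt_tokens, total_tokens]: [usize; 2] destructures the pair directly in the function signature, and completion_tokens is derived as the difference rather than passed separately. A self-contained sketch of both, with a hypothetical Usage struct standing in for openai_struct::CompletionUsage:

#[derive(Debug, PartialEq)]
struct Usage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

// Array destructuring directly in the parameter list, as in the new
// chat_completion_response signature. Assumes total_tokens >= prompt_tokens,
// which the producer guarantees by construction.
fn usage([prompt_tokens, total_tokens]: [usize; 2]) -> Usage {
    Usage {
        prompt_tokens,
        completion_tokens: total_tokens - prompt_tokens,
        total_tokens,
    }
}

fn main() {
    // e.g. a 12-token prompt that grew to 57 tokens in total
    let u = usage([12, 57]);
    assert_eq!(
        u,
        Usage { prompt_tokens: 12, completion_tokens: 45, total_tokens: 57 }
    );
}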
