@@ -29,18 +29,38 @@ pub(super) struct Model {
2929}
3030
3131pub ( super ) enum Output {
32- Text { think : String , content : String } ,
33- Finish ( FinishReason ) ,
32+ Text {
33+ think : String ,
34+ content : String ,
35+ } ,
36+ Finish {
37+ reason : FinishReason ,
38+ num_tokens : [ usize ; 2 ] ,
39+ } ,
3440}
3541
3642struct SessionInfo {
3743 sender : UnboundedSender < Output > ,
3844 buf : TextBuf ,
3945 think : bool ,
4046 tokens : Vec < utok > ,
47+ prompt_tokens : usize ,
4148 accumulated_content : String , // Track all generated content for blacklist detection
4249}
4350
51+ impl SessionInfo {
52+ fn new ( sender : UnboundedSender < Output > , tokens : Vec < utok > ) -> Self {
53+ Self {
54+ buf : TextBuf :: new ( ) ,
55+ think : false ,
56+ prompt_tokens : tokens. len ( ) ,
57+ accumulated_content : String :: new ( ) ,
58+ sender,
59+ tokens,
60+ }
61+ }
62+ }
63+
4464impl Model {
4565 pub fn new ( config : ModelConfig , use_cuda_graph : bool ) -> ( Self , Service ) {
4666 let ModelConfig {
@@ -181,7 +201,10 @@ impl Model {
181201 // Send finish signal
182202 if session_info
183203 . sender
184- . send ( Output :: Finish ( FinishReason :: Stop ) )
204+ . send ( Output :: Finish {
205+ reason : FinishReason :: Stop ,
206+ num_tokens : [ session_info. prompt_tokens , session_info. tokens . len ( ) ] ,
207+ } )
185208 . is_err ( )
186209 {
187210 info ! ( "{session_id:?} 客户端连接已关闭" ) ;
@@ -202,8 +225,13 @@ impl Model {
202225 // 处理会话结束
203226 if !sessions. is_empty ( ) {
204227 for ( session, reason) in sessions {
205- let SessionInfo { tokens, sender, .. } =
206- sessions_guard. remove ( & session. id ) . unwrap ( ) ;
228+ let SessionInfo {
229+ tokens,
230+ sender,
231+ prompt_tokens,
232+ ..
233+ } = sessions_guard. remove ( & session. id ) . unwrap ( ) ;
234+ let num_tokens = [ prompt_tokens, tokens. len ( ) ] ;
207235 let reason = match reason {
208236 ReturnReason :: Finish => {
209237 // 正常完成,插回 cache
@@ -221,7 +249,7 @@ impl Model {
221249 } ;
222250
223251 sender
224- . send ( Output :: Finish ( reason) )
252+ . send ( Output :: Finish { reason, num_tokens } )
225253 . unwrap_or_else ( |_| info ! ( "{:?} 发送正常完成失败" , session. id) ) ;
226254 }
227255 }
@@ -298,18 +326,12 @@ impl Model {
298326 max_tokens,
299327 ) ;
300328
301- let session_info = SessionInfo {
302- sender,
303- tokens,
304- buf : TextBuf :: new ( ) ,
305- think : false ,
306- accumulated_content : String :: new ( ) ,
307- } ;
329+ let session_info = SessionInfo :: new ( sender, tokens) ;
308330 assert ! (
309331 self . sessions
310332 . lock( )
311333 . unwrap( )
312- . insert( id, session_info, )
334+ . insert( id, session_info)
313335 . is_none( )
314336 ) ;
315337
@@ -360,18 +382,12 @@ impl Model {
360382 max_tokens,
361383 ) ;
362384
363- let session_info = SessionInfo {
364- sender,
365- tokens,
366- buf : TextBuf :: new ( ) ,
367- think : false ,
368- accumulated_content : String :: new ( ) ,
369- } ;
385+ let session_info = SessionInfo :: new ( sender, tokens) ;
370386 assert ! (
371387 self . sessions
372388 . lock( )
373389 . unwrap( )
374- . insert( id, session_info, )
390+ . insert( id, session_info)
375391 . is_none( )
376392 ) ;
377393
0 commit comments