@@ -355,6 +355,7 @@ def __init__(self, local_interval: float) -> None:
355355 self .num_generation_tokens : List [int ] = []
356356 self .last_local_log = time .time ()
357357 self .local_interval = local_interval
358+ self .spec_decode_metrics : Optional ["SpecDecodeWorkerMetrics" ] = None
358359
359360 @abstractmethod
360361 def info (self , type : str , obj : SupportsMetricsInfo ) -> None :
@@ -364,6 +365,12 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
364365 def log (self , stats : Stats ) -> None :
365366 raise NotImplementedError
366367
368+ def maybe_update_spec_decode_metrics (self , stats : Stats ):
369+ """Save spec decode metrics (since they are unlikely
370+ to be emitted at the same time as the log interval)."""
371+ if stats .spec_decode_metrics is not None :
372+ self .spec_decode_metrics = stats .spec_decode_metrics
373+
367374
368375class LoggingStatLogger (StatLoggerBase ):
369376 """LoggingStatLogger is used in LLMEngine to log to Stdout."""
@@ -379,6 +386,9 @@ def log(self, stats: Stats) -> None:
379386 self .num_prompt_tokens .append (stats .num_prompt_tokens_iter )
380387 self .num_generation_tokens .append (stats .num_generation_tokens_iter )
381388
389+ # Update spec decode metrics
390+ self .maybe_update_spec_decode_metrics (stats )
391+
382392 # Log locally every local_interval seconds.
383393 if local_interval_elapsed (stats .now , self .last_local_log ,
384394 self .local_interval ):
@@ -408,15 +418,16 @@ def log(self, stats: Stats) -> None:
408418 stats .cpu_cache_usage_sys * 100 ,
409419 )
410420
421+ if self .spec_decode_metrics is not None :
422+ logger .info (
423+ self ._format_spec_decode_metrics_str (
424+ self .spec_decode_metrics ))
425+
411426 # Reset tracked stats for next interval.
412427 self .num_prompt_tokens = []
413428 self .num_generation_tokens = []
414429 self .last_local_log = stats .now
415-
416- if stats .spec_decode_metrics is not None :
417- logger .info (
418- self ._format_spec_decode_metrics_str (
419- stats .spec_decode_metrics ))
430+ self .spec_decode_metrics = None
420431
421432 def _format_spec_decode_metrics_str (
422433 self , metrics : "SpecDecodeWorkerMetrics" ) -> str :
@@ -533,6 +544,9 @@ def log(self, stats: Stats):
533544 self .num_prompt_tokens .append (stats .num_prompt_tokens_iter )
534545 self .num_generation_tokens .append (stats .num_generation_tokens_iter )
535546
547+ # Update spec decode metrics
548+ self .maybe_update_spec_decode_metrics (stats )
549+
536550 # Log locally every local_interval seconds.
537551 if local_interval_elapsed (stats .now , self .last_local_log ,
538552 self .local_interval ):
@@ -550,26 +564,27 @@ def log(self, stats: Stats):
550564 prompt_throughput = prompt_throughput ,
551565 generation_throughput = generation_throughput )
552566
553- # Reset tracked stats for next interval.
554- self .num_prompt_tokens = []
555- self .num_generation_tokens = []
556- self .last_local_log = stats .now
557-
558- if stats .spec_decode_metrics is not None :
567+ if self .spec_decode_metrics is not None :
559568 self ._log_gauge (
560569 self .metrics .gauge_spec_decode_draft_acceptance_rate ,
561- stats .spec_decode_metrics .draft_acceptance_rate )
570+ self .spec_decode_metrics .draft_acceptance_rate )
562571 self ._log_gauge (self .metrics .gauge_spec_decode_efficiency ,
563- stats .spec_decode_metrics .system_efficiency )
572+ self .spec_decode_metrics .system_efficiency )
564573 self ._log_counter (
565574 self .metrics .counter_spec_decode_num_accepted_tokens ,
566- stats .spec_decode_metrics .accepted_tokens )
575+ self .spec_decode_metrics .accepted_tokens )
567576 self ._log_counter (
568577 self .metrics .counter_spec_decode_num_draft_tokens ,
569- stats .spec_decode_metrics .draft_tokens )
578+ self .spec_decode_metrics .draft_tokens )
570579 self ._log_counter (
571580 self .metrics .counter_spec_decode_num_emitted_tokens ,
572- stats .spec_decode_metrics .emitted_tokens )
581+ self .spec_decode_metrics .emitted_tokens )
582+
583+ # Reset tracked stats for next interval.
584+ self .num_prompt_tokens = []
585+ self .num_generation_tokens = []
586+ self .last_local_log = stats .now
587+ self .spec_decode_metrics = None
573588
574589
575590class RayPrometheusStatLogger (PrometheusStatLogger ):
0 commit comments