@@ -512,14 +512,140 @@ def _load_pretrained(self, pretrained_model: str):
         from text_generation_server.models import get_model

         pretrained_model, revision = parse_revision(pretrained_model)
-        return TextGenModelWrapper(get_model(pretrained_model, revision, False, False))
+        return get_model(pretrained_model, revision, False, False)

     def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
         raise NotImplementedError()

     def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
         raise NotImplementedError()

+    def _generate_textgen(
+        self,
+        batch,
+        max_new_tokens: int,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency: bool = False,
+        key_length_step: int = 1,
+        ignore_oom: bool = False,
+        pad_generated_tokens: float = 0,
+    ):
+        t0 = self._get_time(breakdown_latency)
+        # TODO: Implement
+        assert do_prefill
+        assert key_length_step == 1
+        assert pad_generated_tokens == 0
+
+        batch_size = len(batch)
+
+        input_length = max(batch.input_lengths)
+        output_length = input_length + max_new_tokens
+
+        t1 = self._get_time(breakdown_latency)
+        last_time = t1
+        generate_times = {}
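+        # Each generate_token call decodes one token per sequence in the batch;
+        # record the per-step latency keyed by the current key (context) length.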
+        with torch.inference_mode():
+            for key_length in range(input_length, output_length, key_length_step):
+                try:
+                    generated, batch = self.model.generate_token(batch)
+                    t2 = self._get_time(breakdown_latency)
+                    generate_times[key_length] = t2 - last_time
+                    last_time = t2
+                except torch.cuda.OutOfMemoryError:
+                    if ignore_oom:
+                        logger.warning(f"Out of memory at key length {key_length}")
+                        break
+                    else:
+                        raise
+        output_text = [g.text for g in generated]
+
+        metrics = {}
+        if breakdown_latency:
+            metrics[Metrics.LATENCY_GENERATE_START] = t1 - t0
+            metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = generate_times
+
+        return output_text, metrics
+
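+    # Build a text_generation_server Batch from the raw prompts, run the
+    # token-by-token decode loop above, and collect latency metrics.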
+    def __call__(
+        self,
+        text: List[str],
+        max_new_tokens: int,
+        custom_generate: bool = False,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency=False,
+        key_length_step: int = 1,
+        ignore_oom: bool = False,
+        pad_generated_tokens: float = 0,
+    ) -> Tuple[List[str], Dict[str, Any]]:
+        t0 = self._get_time()
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+
+        from text_generation_server.pb import generate_pb2
+        from text_generation_server.models.model import Model
+
+        model: Model = self.model
+
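+        # Deterministic greedy decoding parameters; ignore_eos_token keeps each
+        # request generating for the full max_new_tokens.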
+        batch_pb = generate_pb2.Batch(
+            id=0,
+            requests=[
+                generate_pb2.Request(
+                    id=i,
+                    inputs=input_,
+                    truncate=99999,
+                    parameters=generate_pb2.NextTokenChooserParameters(
+                        temperature=1.0,
+                        top_k=1,
+                        top_p=1,
+                        typical_p=1,
+                        do_sample=False,
+                        seed=0,
+                        repetition_penalty=1.0,
+                        watermark=False,
+                    ),
+                    stopping_parameters=generate_pb2.StoppingCriteriaParameters(
+                        max_new_tokens=max_new_tokens,
+                        stop_sequences=None,
+                        ignore_eos_token=True,
+                    ),
+                )
+                for i, input_ in enumerate(text)
+            ],
+            size=len(text),
+            max_tokens=0,  # Ignored
+        )
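+        # from_pb tokenizes the prompts with self.tokenizer and materializes the
+        # model-specific batch on the target device.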
+        batch = model.batch_type.from_pb(batch_pb, self.tokenizer, self.device)
+        batch_size = len(batch)
+
+        # TODO: Implement
+        input_length = max(batch.input_lengths)
+        output_length = input_length + max_new_tokens
+
+        output_text, generate_metrics = self._generate_textgen(
+            batch,
+            max_new_tokens,
+            use_cache,
+            do_prefill,
+            breakdown_latency,
+            key_length_step,
+            ignore_oom,
+            pad_generated_tokens,
+        )
+        t1 = self._get_time(True)
+
+        metrics = {
+            **generate_metrics,
+            Metrics.BATCH_SIZE: batch_size,
+            Metrics.INPUT_LENGTH: input_length,
+            Metrics.OUTPUT_LENGTH: output_length,
+            Metrics.TOKENS_SAMPLE: output_length - input_length,
+            Metrics.TOKENS_BATCH: batch_size * (output_length - input_length),
+            Metrics.LATENCY_E2E: t1 - t0,
+        }
+
+        return output_text, metrics
+

 _PIPELINE_CLASS_MAP = {
     "HF_Pipeline": HF_Pipeline,