@@ -75,7 +75,6 @@ def __init__(
         else:
             self.model = self._load_pretrained(pretrained_model)

-        self.model.eval()
         t3 = self._get_time()
         self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0
         self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1
@@ -101,7 +100,7 @@ def _create_model(self) -> PreTrainedModel:
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
         self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2

-        return model
+        return model.eval()

     def _reload_model(self):
         self._save_pretrained("tmp")
@@ -136,7 +135,7 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
         model = model.to(self.device)
         t2 = self._get_time()
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
-        return model
+        return model.eval()

     def _get_config(
         self,
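A quick aside on the `return model.eval()` form: `torch.nn.Module.eval()` returns the module itself, so returning it directly is equivalent to calling `.eval()` and then returning, and keeping the call inside `_create_model` / `_load_pretrained` means every loading path hands back a model already in inference mode. A minimal standalone sketch, using a plain `torch.nn.Linear` as a stand-in model:

import torch

model = torch.nn.Linear(4, 4)     # stand-in for a PreTrainedModel
returned = model.eval()           # eval() returns the module, so it can be chained
assert returned is model
assert not model.training         # the module is now in inference (eval) mode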
@@ -386,8 +385,8 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]):
         breakdown = all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, [])

         mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0}
-        throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E]
-        model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL]
+        throughput = mean_metrics.get(Metrics.TOKENS_BATCH, 0) / mean_metrics.get(Metrics.LATENCY_E2E, 1)
+        model_throughput = mean_metrics.get(Metrics.TOKENS_BATCH, 0) / mean_metrics.get(Metrics.LATENCY_MODEL, 1)

         if len(breakdown) > 0:
             mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = {
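A note on the `.get(...)` fallback: if a metric key never made it into `mean_metrics` (for example when a backend does not report it), indexing would raise a `KeyError`; defaulting the numerator to 0 and the denominator to 1 keeps the aggregation from crashing, at the cost of reporting a throughput of 0 (or a meaningless value when only the latency is missing). A small standalone sketch with plain string keys standing in for the `Metrics` constants:

# Only the token count was recorded in this run; both latency keys are absent.
mean_metrics = {"tokens_batch": 512.0}

throughput = mean_metrics.get("tokens_batch", 0) / mean_metrics.get("latency_e2e", 1)
print(throughput)  # 512.0 -- no KeyError, but the value is only meaningful when both keys exist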
@@ -487,10 +486,13 @@ def __call__(
 class TG_Pipeline(Pipeline):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        # TODO: Ignoring dtype

         if self.device != torch.device("cuda:0"):
             raise ValueError(f"Textgen does not support device {self.device}")

+        self.config = self.model.model.transformer.config
+
     def _get_config(
         self,
         model_type: Optional[str],
@@ -512,14 +514,77 @@ def _load_pretrained(self, pretrained_model: str):
         from text_generation_server.models import get_model

         pretrained_model, revision = parse_revision(pretrained_model)
-        return get_model(pretrained_model, revision, False, False)
+
+        with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
+            return get_model(pretrained_model, revision, False, False)

     def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
         raise NotImplementedError()

     def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
         raise NotImplementedError()

+    def get_num_parameters(self) -> int:
+        return 0
+
+    def _update_generate_batch(self, batch, use_cache, do_prefill, key_length):
+        from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
+
+        assert do_prefill or use_cache
+
+        if isinstance(batch, FlashCausalLMBatch):
+            # Tested for flash santacoder only
+            assert max(batch.input_lengths) == batch.max_seqlen
+            seqlen_diff = key_length - batch.max_seqlen
+            assert seqlen_diff >= 0
+            if batch.past_key_values is None:
+                mock_cache = use_cache and not do_prefill
+            else:
+                if not use_cache:
+                    batch.past_key_values = None
+                mock_cache = use_cache and seqlen_diff > 0
+            if mock_cache:
+                batch.past_key_values = []
+
+            for i, old_length in enumerate(batch.input_lengths):
+                length = old_length + seqlen_diff
+                batch.input_lengths[i] = length
+                batch.max_seqlen = max(batch.max_seqlen, length)
+                add_tokens = [self.tokenizer.pad_token_id] * seqlen_diff
+                batch.all_input_ids[i].extend(add_tokens)
+                batch.all_input_ids_tensor[i][old_length:length] = torch.tensor(add_tokens)
+                batch.cu_seqlens[i + 1] = batch.cu_seqlens[i] + length
+
+                if use_cache and batch.past_key_values is not None:
+                    # Decode
+                    batch.input_ids[i] = batch.all_input_ids_tensor[i][length - 1 : length]
+                    batch.position_ids[i] = length - 1
+                    if mock_cache:
+                        batch.stopping_criterias[i].current_tokens = max(batch.stopping_criterias[i].current_tokens, 1)
+                        batch.past_key_values.append(
+                            torch.randn(
+                                [self.config.n_layer, length, 2, 1, self.config.n_embd // self.config.n_head],
+                                dtype=self.model.dtype,
+                                device=self.device,
+                            )
+                        )
+                        batch.past_key_values.append(
+                            torch.zeros(
+                                [self.config.n_layer, 1, 2, 1, self.config.n_embd // self.config.n_head],
+                                dtype=self.model.dtype,
+                                device=self.device,
+                            )
+                        )
+                else:
+                    # Prefill
+                    batch.input_ids[i] = batch.all_input_ids_tensor[i][:length]
+                    batch.position_ids[i] = torch.arange(0, length, dtype=torch.int32, device=self.device)
+
+            assert batch.max_seqlen == key_length
+
+        else:
+            raise NotImplementedError()
+
     def _generate_textgen(
         self,
         batch,
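The conditional context manager in `_load_pretrained` (`fast_init(self.device) if self.fast_init else contextlib.nullcontext()`) wraps `get_model` in the fast-init context only when the flag is set; `contextlib.nullcontext()` is the standard-library no-op stand-in. A minimal sketch of the pattern, with a dummy context manager and an `enabled` flag standing in for `fast_init` and `self.fast_init`:

import contextlib

@contextlib.contextmanager
def dummy_fast_init():            # stand-in for fast_init(device) in this sketch
    print("fast init enabled")
    yield

enabled = False                   # stand-in for self.fast_init
with dummy_fast_init() if enabled else contextlib.nullcontext():
    loaded = "model would be loaded here"   # get_model(...) runs under whichever context was chosen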
@@ -532,13 +597,10 @@ def _generate_textgen(
         pad_generated_tokens: float = 0,
     ):
         t0 = self._get_time(breakdown_latency)
-        # TODO: Implement
-        assert do_prefill
-        assert key_length_step == 1
+        assert do_prefill or use_cache
+        # TODO: Implement?
         assert pad_generated_tokens == 0

-        batch_size = len(batch)
-
         input_length = max(batch.input_lengths)
         output_length = input_length + max_new_tokens

@@ -548,6 +610,9 @@ def _generate_textgen(
         with torch.inference_mode():
             for key_length in range(input_length, output_length, key_length_step):
                 try:
+                    if key_length_step > 1 or not use_cache or not do_prefill:
+                        self._update_generate_batch(batch, use_cache, do_prefill, key_length)
+                    last_time = self._get_time(breakdown_latency)
                     generated, batch = self.model.generate_token(batch)
                     t2 = self._get_time(breakdown_latency)
                     generate_times[key_length] = t2 - last_time
@@ -558,7 +623,7 @@ def _generate_textgen(
                         break
                     else:
                         raise
-        output_text = [g.text for g in generated]
+        output_text = ["" if g.generated_text is None else g.generated_text.text for g in generated]

         metrics = {}
         if breakdown_latency:
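On the `output_text` change: `generate_token` yields one generation object per sequence, and its `generated_text` field is typically only filled in once that sequence has finished; on intermediate steps it is `None`, so reading `.text` directly can fail. A minimal sketch of the guard, with small stand-in classes instead of the real text-generation-server types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GeneratedTextStub:          # stand-in for the server's GeneratedText
    text: str

@dataclass
class GenerationStub:             # stand-in for the server's Generation
    generated_text: Optional[GeneratedTextStub]

generated = [GenerationStub(None), GenerationStub(GeneratedTextStub("hello world"))]
output_text = ["" if g.generated_text is None else g.generated_text.text for g in generated]
print(output_text)  # ['', 'hello world']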
@@ -580,7 +645,6 @@ def __call__(
         pad_generated_tokens: float = 0,
     ) -> Tuple[List[str], Dict[str, Any]]:
         t0 = self._get_time()
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True)

         from text_generation_server.pb import generate_pb2
         from text_generation_server.models.model import Model
@@ -592,7 +656,7 @@ def __call__(
             requests=[
                 generate_pb2.Request(
                     id=i,
-                    inputs=input_,
+                    inputs=t,
                     truncate=99999,
                     parameters=generate_pb2.NextTokenChooserParameters(
                         temperature=1.0,
@@ -610,9 +674,9 @@ def __call__(
                         ignore_eos_token=True,
                     ),
                 )
-                for i, input_ in enumerate(inputs)
+                for i, t in enumerate(text)
             ],
-            size=len(inputs),
+            size=len(text),
             max_tokens=0,  # Ignored
         )
         batch = model.batch_type.from_pb(batch_pb, self.tokenizer, self.device)
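On switching the requests from `inputs` to `text`: `batch_type.from_pb(batch_pb, self.tokenizer, self.device)` tokenizes on the server side, so the prompts belong in the request as raw strings. The removed `self.tokenizer(...)` call returned a `BatchEncoding`, which iterates over its keys, so `enumerate(inputs)` produced key names like `"input_ids"` rather than prompts and `len(inputs)` counted keys rather than the batch size. A standalone sketch of that pitfall, using the gpt2 tokenizer purely for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 has no pad token by default
text = ["hello", "a second prompt", "a third, longer prompt"]
inputs = tokenizer(text, return_tensors="pt", padding=True)

print(list(enumerate(inputs)))  # [(0, 'input_ids'), (1, 'attention_mask')] -- keys, not prompts
print(len(inputs))              # 2: the number of keys, not the batch size of 3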