
Commit 2918c96

New generate
1 parent 4e59ef2 commit 2918c96

File tree

1 file changed: +127 -1 lines changed

src/pipeline.py

Lines changed: 127 additions & 1 deletion
@@ -512,14 +512,140 @@ def _load_pretrained(self, pretrained_model: str):
         from text_generation_server.models import get_model
 
         pretrained_model, revision = parse_revision(pretrained_model)
-        return TextGenModelWrapper(get_model(pretrained_model, revision, False, False))
+        return get_model(pretrained_model, revision, False, False)
 
     def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
         raise NotImplementedError()
 
     def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
         raise NotImplementedError()
 
+    def _generate_textgen(
+        self,
+        batch,
+        max_new_tokens: int,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency: bool = False,
+        key_length_step: int = 1,
+        ignore_oom: bool = False,
+        pad_generated_tokens: float = 0,
+    ):
+        t0 = self._get_time(breakdown_latency)
+        # TODO: Implement
+        assert do_prefill
+        assert key_length_step == 1
+        assert pad_generated_tokens == 0
+
+        batch_size = len(batch)
+
+        input_length = max(batch.input_lengths)
+        output_length = input_length + max_new_tokens
+
+        t1 = self._get_time(breakdown_latency)
+        last_time = t1
+        generate_times = {}
+        with torch.inference_mode():
+            for key_length in range(input_length, output_length, key_length_step):
+                try:
+                    generated, batch = self.model.generate_token(batch)
+                    t2 = self._get_time(breakdown_latency)
+                    generate_times[key_length] = t2 - last_time
+                    last_time = t2
+                except torch.cuda.OutOfMemoryError:
+                    if ignore_oom:
+                        logger.warning(f"Out of memory at key length {key_length}")
+                        break
+                    else:
+                        raise
+        output_text = [g.text for g in generated]
+
+        metrics = {}
+        if breakdown_latency:
+            metrics[Metrics.LATENCY_GENERATE_START] = t1 - t0
+            metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = generate_times
+
+        return output_text, metrics
+
+    def __call__(
+        self,
+        text: List[str],
+        max_new_tokens: int,
+        custom_generate: bool = False,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency: bool = False,
+        key_length_step: int = 1,
+        ignore_oom: bool = False,
+        pad_generated_tokens: float = 0,
+    ) -> Tuple[List[str], Dict[str, Any]]:
+        t0 = self._get_time()
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+
+        from text_generation_server.pb import generate_pb2
+        from text_generation_server.models.model import Model
+
+        model: Model = self.model
+
+        batch_pb = generate_pb2.Batch(
+            id=0,
+            requests=[
+                generate_pb2.Request(
+                    id=i,
+                    inputs=input_,
+                    truncate=99999,
+                    parameters=generate_pb2.NextTokenChooserParameters(
+                        temperature=1.0,
+                        top_k=1,
+                        top_p=1,
+                        typical_p=1,
+                        do_sample=False,
+                        seed=0,
+                        repetition_penalty=1.0,
+                        watermark=False,
+                    ),
+                    stopping_parameters=generate_pb2.StoppingCriteriaParameters(
+                        max_new_tokens=max_new_tokens,
+                        stop_sequences=None,
+                        ignore_eos_token=True,
+                    ),
+                )
+                for i, input_ in enumerate(text)
+            ],
+            size=len(text),
+            max_tokens=0,  # Ignored
+        )
+        batch = model.batch_type.from_pb(batch_pb, self.tokenizer, self.device)
+        batch_size = len(batch)
+
+        # TODO: Implement
+        input_length = max(batch.input_lengths)
+        output_length = input_length + max_new_tokens
+
+        output_text, generate_metrics = self._generate_textgen(
+            batch,
+            max_new_tokens,
+            use_cache,
+            do_prefill,
+            breakdown_latency,
+            key_length_step,
+            ignore_oom,
+            pad_generated_tokens,
+        )
+        t1 = self._get_time(True)
+
+        metrics = {
+            **generate_metrics,
+            Metrics.BATCH_SIZE: batch_size,
+            Metrics.INPUT_LENGTH: input_length,
+            Metrics.OUTPUT_LENGTH: output_length,
+            Metrics.TOKENS_SAMPLE: output_length - input_length,
+            Metrics.TOKENS_BATCH: batch_size * (output_length - input_length),
+            Metrics.LATENCY_E2E: t1 - t0,
+        }
+
+        return output_text, metrics
+
 
 _PIPELINE_CLASS_MAP = {
     "HF_Pipeline": HF_Pipeline,
