Commit a09a0e3

improvements
1 parent 2918c96 commit a09a0e3

2 files changed: +100 -16
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+
+# Santacoder
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 v2_
+
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 v2_
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 v2_
+./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_
+
+# Large model
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_ # OOM?
+
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM?
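The invocations above sweep batch size per model while holding the other arguments fixed. A minimal Python sketch of driving the same sweep programmatically, assuming the third positional argument is the batch size and the remaining arguments keep the order shown above (the tuple field names are guesses, not taken from the script):

import subprocess

# Each tuple mirrors one block of invocations above; field meanings are assumptions.
SWEEPS = [
    ("santacoder", "bigcode/gpt_bigcode-santacoder", [1, 32, 256], "2040", "5", "0", "v2_", []),
    ("santacoder", "bigcode/gpt_bigcode-santacoder", [1, 32, 256], "2040", "11", "1", "v2_", []),
    ("large_model", "./data/large-model", [1, 8, 32, 256], "8190", "11", "0", "v2_", []),
    ("large_model", "./data/large-model", [1, 8, 32, 256], "8190", "29", "1", "v2_", ["1"]),
]

for name, model, batch_sizes, length, step, flag, prefix, extra in SWEEPS:
    for batch_size in batch_sizes:
        cmd = [
            "./scripts/run_textgen_benchmark_breakdown.sh",
            name, model, str(batch_size), length, step, flag, prefix, *extra,
        ]
        print(" ".join(cmd))
        subprocess.run(cmd, check=False)  # keep sweeping even if a large batch OOMs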

src/pipeline.py

Lines changed: 80 additions & 16 deletions
@@ -75,7 +75,6 @@ def __init__(
         else:
             self.model = self._load_pretrained(pretrained_model)
 
-        self.model.eval()
         t3 = self._get_time()
         self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0
         self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1
@@ -101,7 +100,7 @@ def _create_model(self) -> PreTrainedModel:
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
         self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2
 
-        return model
+        return model.eval()
 
     def _reload_model(self):
         self._save_pretrained("tmp")
@@ -136,7 +135,7 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
         model = model.to(self.device)
         t2 = self._get_time()
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
-        return model
+        return model.eval()
 
     def _get_config(
         self,
@@ -386,8 +385,8 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]):
         breakdown = all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, [])
 
         mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0}
-        throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E]
-        model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL]
+        throughput = mean_metrics.get(Metrics.TOKENS_BATCH, 0) / mean_metrics.get(Metrics.LATENCY_E2E, 1)
+        model_throughput = mean_metrics.get(Metrics.TOKENS_BATCH, 0) / mean_metrics.get(Metrics.LATENCY_MODEL, 1)
 
         if len(breakdown) > 0:
             mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = {
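The switch from indexing to .get(...) makes the aggregation tolerant of missing metrics. A small stand-alone illustration of the difference, using plain string keys as stand-ins for the Metrics constants (the empty dict is a hypothetical case, e.g. every iteration failed before reporting anything):

# Old: mean_metrics[Metrics.TOKENS_BATCH] raised KeyError when nothing was recorded.
# New: missing tokens default to 0 and missing latency to 1, so throughput degrades to 0.0.
mean_metrics = {}
throughput = mean_metrics.get("tokens_batch", 0) / mean_metrics.get("latency_e2e", 1)
print(throughput)  # 0.0 instead of a crash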
@@ -487,10 +486,13 @@ def __call__(
 class TG_Pipeline(Pipeline):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        # TODO: Ignoring dtype
 
         if self.device != torch.device("cuda:0"):
             raise ValueError(f"Textgen does not support device {self.device}")
 
+        self.config = self.model.model.transformer.config
+
     def _get_config(
         self,
         model_type: Optional[str],
@@ -512,14 +514,77 @@ def _load_pretrained(self, pretrained_model: str):
         from text_generation_server.models import get_model
 
         pretrained_model, revision = parse_revision(pretrained_model)
-        return get_model(pretrained_model, revision, False, False)
+
+        with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
+            return get_model(pretrained_model, revision, False, False)
 
     def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
         raise NotImplementedError()
 
     def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
         raise NotImplementedError()
 
+    def get_num_parameters(self) -> int:
+        return 0
+
+    def _update_generate_batch(self, batch, use_cache, do_prefill, key_length):
+        from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
+
+        assert do_prefill or use_cache
+
+        if isinstance(batch, FlashCausalLMBatch):
+            # Tested for flash santacoder only
+            assert max(batch.input_lengths) == batch.max_seqlen
+            seqlen_diff = key_length - batch.max_seqlen
+            assert seqlen_diff >= 0
+            if batch.past_key_values is None:
+                mock_cache = use_cache and not do_prefill
+            else:
+                if not use_cache:
+                    batch.past_key_values = None
+                mock_cache = use_cache and seqlen_diff > 0
+            if mock_cache:
+                batch.past_key_values = []
+
+            for i, old_length in enumerate(batch.input_lengths):
+                length = old_length + seqlen_diff
+                batch.input_lengths[i] = length
+                batch.max_seqlen = max(batch.max_seqlen, length)
+                add_tokens = [self.tokenizer.pad_token_id] * seqlen_diff
+                batch.all_input_ids[i].extend(add_tokens)
+                batch.all_input_ids_tensor[i][old_length:length] = torch.tensor(add_tokens)
+                batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + length
+
+                if use_cache and batch.past_key_values is not None:
+                    # Decode
+                    batch.input_ids[i] = batch.all_input_ids_tensor[i][length - 1 : length]
+                    batch.position_ids[i] = length - 1
+                    if mock_cache:
+                        batch.stopping_criterias[i].current_tokens = max(batch.stopping_criterias[i].current_tokens, 1)
+                        batch.past_key_values.append(
+                            torch.randn(
+                                [self.config.n_layer, length, 2, 1, self.config.n_embd // self.config.n_head],
+                                dtype=self.model.dtype,
+                                device=self.device,
+                            )
+                        )
+                        batch.past_key_values.append(
+                            torch.zeros(
+                                [self.config.n_layer, 1, 2, 1, self.config.n_embd // self.config.n_head],
+                                dtype=self.model.dtype,
+                                device=self.device,
+                            )
+                        )
+                else:
+                    # Prefill
+                    batch.input_ids[i] = batch.all_input_ids_tensor[i][:length]
+                    batch.position_ids[i] = torch.arange(0, length, dtype=torch.int32, device=self.device)
+
+            assert batch.max_seqlen == key_length
+
+        else:
+            raise NotImplementedError()
+
     def _generate_textgen(
         self,
         batch,
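The mock past_key_values in the hunk above let the benchmark jump straight to a given key_length without running a real prefill. As a rough sanity check on its cost, a back-of-the-envelope for the per-sequence tensor of shape [n_layer, length, 2, 1, n_embd // n_head]; the config numbers below are assumed santacoder-like values for illustration, not read from the model:

# Assumed values; only the tensor shape comes from the diff above.
n_layer, n_embd, n_head = 24, 2048, 16
length = 2040            # key length used in the benchmark script
bytes_per_elem = 2       # float16

head_dim = n_embd // n_head                  # 128
elems = n_layer * length * 2 * 1 * head_dim
print(f"{elems * bytes_per_elem / 2**20:.1f} MiB per sequence")  # ~23.9 MiB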
@@ -532,13 +597,10 @@ def _generate_textgen(
         pad_generated_tokens: float = 0,
     ):
         t0 = self._get_time(breakdown_latency)
-        # TODO: Implement
-        assert do_prefill
-        assert key_length_step == 1
+        assert do_prefill or use_cache
+        # TODO: Implement?
         assert pad_generated_tokens == 0
 
-        batch_size = len(batch)
-
         input_length = max(batch.input_lengths)
         output_length = input_length + max_new_tokens
 
@@ -548,6 +610,9 @@
         with torch.inference_mode():
             for key_length in range(input_length, output_length, key_length_step):
                 try:
+                    if key_length_step > 1 or not use_cache or not do_prefill:
+                        self._update_generate_batch(batch, use_cache, do_prefill, key_length)
+                    last_time = self._get_time(breakdown_latency)
                     generated, batch = self.model.generate_token(batch)
                     t2 = self._get_time(breakdown_latency)
                     generate_times[key_length] = t2 - last_time
@@ -558,7 +623,7 @@
                         break
                     else:
                         raise
-        output_text = [g.text for g in generated]
+        output_text = ["" if g.generated_text is None else g.generated_text.text for g in generated]
 
         metrics = {}
         if breakdown_latency:
@@ -580,7 +645,6 @@ def __call__(
         pad_generated_tokens: float = 0,
     ) -> Tuple[List[str], Dict[str, Any]]:
         t0 = self._get_time()
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
 
         from text_generation_server.pb import generate_pb2
         from text_generation_server.models.model import Model
@@ -592,7 +656,7 @@ def __call__(
             requests=[
                 generate_pb2.Request(
                     id=i,
-                    inputs=input_,
+                    inputs=t,
                     truncate=99999,
                     parameters=generate_pb2.NextTokenChooserParameters(
                         temperature=1.0,
@@ -610,9 +674,9 @@
                         ignore_eos_token=True,
                     ),
                 )
-                for i, input_ in enumerate(inputs)
+                for i, t in enumerate(text)
             ],
-            size=len(inputs),
+            size=len(text),
             max_tokens=0, # Ignored
         )
         batch = model.batch_type.from_pb(batch_pb, self.tokenizer, self.device)
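The last hunks also change what gets iterated: the old inputs was the tokenizer's BatchEncoding, which behaves like a dict, so enumerate(inputs) and len(inputs) reflected its keys rather than the prompts; the new code passes the raw strings and leaves tokenization to model.batch_type.from_pb. A short illustration (the dict literal stands in for the BatchEncoding, the prompts are made up):

# BatchEncoding is dict-like: iterating over it yields key names, not per-prompt entries.
text = ["def hello():", "print("]  # made-up prompts
encoded = {"input_ids": [[1, 2], [3, 0]], "attention_mask": [[1, 1], [1, 0]]}

print(list(enumerate(encoded)))  # [(0, 'input_ids'), (1, 'attention_mask')]
print(list(enumerate(text)))     # [(0, 'def hello():'), (1, 'print(')]
# With the change, each raw string goes into generate_pb2.Request.inputs and
# tokenization happens inside model.batch_type.from_pb instead of __call__.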
