Commit 9c90421

Commit message: stuff
1 parent a09a0e3 commit 9c90421

File tree: 4 files changed, +30 −22 lines

scripts/run_all_benchmark_breakdown.sh

Lines changed: 8 additions & 8 deletions
@@ -9,12 +9,12 @@
 ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_
 
 # Large model
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_ # OOM?
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 1 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 8 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 32 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 256 8190 11 0 v2_ # OOM?
 
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM?
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 1 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 8 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 32 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/bigcode_large-model 256 8190 29 1 v2_ 1 # OOM?

scripts/run_all_textgen_benchmark_breakdown.sh

Lines changed: 8 additions & 8 deletions
@@ -9,12 +9,12 @@
 ./scripts/run_textgen_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_
 
 # Large model
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_ # OOM?
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 1 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 8 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 32 8190 11 0 v2_
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 256 8190 11 0 v2_ # OOM?
 
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1
-./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM?
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 1 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 8 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 32 8190 29 1 v2_ 1
+./scripts/run_textgen_benchmark_breakdown.sh large_model ./data/bigcode_large-model 256 8190 29 1 v2_ 1 # OOM?

scripts/run_textgen_benchmark_breakdown.sh

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ CYCLES=${8:-10}
 
 SAVE_DIR=data/benchmarks/v2
 #BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256"
-RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom"
+RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom --no_fast_init"
 
 
 RUNTIME=("")

src/pipeline.py

Lines changed: 13 additions & 5 deletions
@@ -534,17 +534,26 @@ def _update_generate_batch(self, batch, use_cache, do_prefill, key_length):
 
         if isinstance(batch, FlashCausalLMBatch):
             # Tested for flash santacoder only
+            # TODO: Fix batch size 1
             assert max(batch.input_lengths) == batch.max_seqlen
             seqlen_diff = key_length - batch.max_seqlen
             assert seqlen_diff >= 0
+            kv_shape = [2, 1, self.config.n_embd // self.config.n_head]
             if batch.past_key_values is None:
                 mock_cache = use_cache and not do_prefill
             else:
                 if not use_cache:
                     batch.past_key_values = None
                 mock_cache = use_cache and seqlen_diff > 0
             if mock_cache:
-                batch.past_key_values = []
+                if len(batch.input_lengths) > 1:
+                    batch.past_key_values = []
+                else:
+                    batch.past_key_values = torch.randn(
+                        [self.config.n_layer, batch.max_tokens, *kv_shape],
+                        dtype=self.model.dtype,
+                        device=self.device,
+                    )
 
             for i, old_length in enumerate(batch.input_lengths):
                 length = old_length + seqlen_diff
@@ -559,18 +568,18 @@ def _update_generate_batch(self, batch, use_cache, do_prefill, key_length):
                     # Decode
                     batch.input_ids[i] = batch.all_input_ids_tensor[i][length - 1 : length]
                     batch.position_ids[i] = length - 1
-                if mock_cache:
+                if mock_cache and len(batch.input_lengths) > 1:
                     batch.stopping_criterias[i].current_tokens = max(batch.stopping_criterias[i].current_tokens, 1)
                     batch.past_key_values.append(
                         torch.randn(
-                            [self.config.n_layer, length, 2, 1, self.config.n_embd // self.config.n_head],
+                            [self.config.n_layer, length, *kv_shape],
                             dtype=self.model.dtype,
                             device=self.device,
                         )
                     )
                     batch.past_key_values.append(
                         torch.zeros(
-                            [self.config.n_layer, 1, 2, 1, self.config.n_embd // self.config.n_head],
+                            [self.config.n_layer, 1, *kv_shape],
                             dtype=self.model.dtype,
                             device=self.device,
                         )
@@ -660,7 +669,6 @@ def __call__(
                 truncate=99999,
                 parameters=generate_pb2.NextTokenChooserParameters(
                     temperature=1.0,
-                    top_k=1,
                     top_p=1,
                     typical_p=1,
                     do_sample=False,
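
Note on the src/pipeline.py change: the mock KV-cache allocation is refactored so the per-token slot shape [2, 1, n_embd // n_head] lives in a single kv_shape variable, and a batch holding exactly one sequence gets its fake cache pre-allocated as one padded tensor of max_tokens slots instead of a per-sequence list. Below is a minimal standalone sketch of that allocation pattern; the function name mock_kv_cache and the example dimensions are hypothetical illustrations, not code from the repository (the real code fills batch.past_key_values on a FlashCausalLMBatch and uses float16 on CUDA).

import torch


def mock_kv_cache(n_layer, n_embd, n_head, input_lengths, max_tokens,
                  dtype=torch.float32, device="cpu"):
    # Per-token KV slot for one attention head: [2 (key/value), 1 head, head_dim].
    kv_shape = [2, 1, n_embd // n_head]
    if len(input_lengths) > 1:
        # Batch > 1: one random tensor per sequence, sized to that sequence's
        # length (the diff also appends a one-token zero buffer per sequence,
        # omitted here for brevity).
        return [
            torch.randn([n_layer, length, *kv_shape], dtype=dtype, device=device)
            for length in input_lengths
        ]
    # Batch size 1: a single padded tensor covering max_tokens cache slots.
    return torch.randn([n_layer, max_tokens, *kv_shape], dtype=dtype, device=device)


# Example with deliberately tiny, illustrative dimensions: a 2-layer model with
# hidden size 256 and 4 heads, one sequence of length 16 padded to 32 slots.
cache = mock_kv_cache(n_layer=2, n_embd=256, n_head=4,
                      input_lengths=[16], max_tokens=32)
print(cache.shape)  # torch.Size([2, 32, 2, 1, 64])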
