Commit 1b00dcc

fix
1 parent fd75d7c commit 1b00dcc

File tree

2 files changed (+4, -4 lines):
  scripts/run_textgen_benchmark_breakdown.sh
  src/pipeline.py


scripts/run_textgen_benchmark_breakdown.sh

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@ SAVE_DIR=data/benchmarks/v3
 RUN="python3 -m src.main --pipeline_class=TG_Pipeline --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom --no_fast_init "
 
 
-IMPL=("flash" "causal" "vector" "bigcode")
+IMPL=("flash" "causal" "vector" "bigcode" "bigcode2")
 
 
 STEP=("" "--no_cache")
@@ -38,7 +38,7 @@ run () { # run(step, runtime, attn)
     fi
 }
 
-for impl in {0..3}
+for impl in {0..4}
 do
     if [ "${STEP_ID}" -eq "0" ]
     then
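Note that the brace range in the loop has to be kept in step with the size of IMPL by hand (here {0..3} becomes {0..4} after "bigcode2" is appended). As a hypothetical alternative, not part of this commit, iterating over the array's own indices removes that coupling:

#!/usr/bin/env bash
# Hypothetical sketch: iterate over the indices of IMPL directly, so adding
# a new implementation never requires editing a hard-coded {0..N} range.
IMPL=("flash" "causal" "vector" "bigcode" "bigcode2")

for impl in "${!IMPL[@]}"; do
    echo "benchmarking attention implementation: ${IMPL[impl]}"
done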

src/pipeline.py

Lines changed: 2 additions & 2 deletions

@@ -619,10 +619,10 @@ def _generate_textgen(
         with torch.inference_mode():
             for key_length in range(input_length, output_length, key_length_step):
                 try:
-                    if (key_length_step > 1 and key_length > key_length) or not use_cache or not do_prefill:
+                    if (key_length_step > 1 and key_length > input_length) or not use_cache or not do_prefill:
                         if not hasattr(self.model, "fast_forward"):
                             raise NotImplementedError()
-                        self.model.fast_forward(batch, key_length, use_cache)
+                        self.model.fast_forward(batch, key_length, self.dtype if use_cache else None)
                     last_time = self._get_time(breakdown_latency)
                     generated, batch = self.model.generate_token(batch)
                     t2 = self._get_time(breakdown_latency)
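Two things change here. First, the old guard compared key_length to itself, so "key_length > key_length" was always False and the fast-forward path could only trigger via the cache/prefill flags; comparing against input_length makes it fire on every step after the prefill. Second, fast_forward now receives the KV-cache dtype (or None when the cache is disabled) rather than the boolean use_cache flag, presumably to match an updated fast_forward signature. A minimal, self-contained sketch of the corrected predicate (names mirror the diff; the helper itself is hypothetical):

# Hypothetical sketch of the predicate fixed in this commit: the old guard
# compared key_length to itself, a tautological False, so fast-forwarding
# was skipped on the steps after the prefill.
def needs_fast_forward(key_length: int, input_length: int, key_length_step: int,
                       use_cache: bool, do_prefill: bool) -> bool:
    # Old (buggy): key_length > key_length  -> always False.
    # New: fast-forward on any step past the initial prefill length.
    return (key_length_step > 1 and key_length > input_length) or not use_cache or not do_prefill

# The first iteration (key_length == input_length) stays on the normal path;
# later iterations fast-forward the KV cache to the current key length.
for key_length in range(8, 16, 2):
    print(key_length, needs_fast_forward(key_length, 8, 2, True, True))
# -> 8 False, 10 True, 12 True, 14 True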
