Commit 70d09b0
[FIX] Fix styles in automatic prefix caching & add an automatic prefix caching benchmark (vllm-project#3158)

zhuohan123 authored Mar 3, 2024
1 parent ad2542d commit 70d09b0
Showing 4 changed files with 69 additions and 18 deletions.
59 changes: 59 additions & 0 deletions benchmarks/benchmark_prefix_caching.py
@@ -0,0 +1,59 @@
import argparse
import time

from vllm import LLM
from vllm import SamplingParams

PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat's the content in the (1,1) cells\n"


def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None):
    start_time = time.time()
    # Whether to pass a fixed prefix length (manual prefix_pos hint).
    if prefix_len is not None:
        # Start inference with the shared-prefix position hint.
        llm.generate(prompts,
                     sampling_params=sampling_params,
                     prefix_pos=prefix_len)
    else:
        llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"cost time {end_time - start_time}")


def main(args):
    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              enable_prefix_caching=args.enable_prefix_caching)

    num_prompts = 100
    prompts = [PROMPT] * num_prompts
    sampling_params = SamplingParams(temperature=0, max_tokens=100)

    print("------warm up------")
    test_prefix(
        llm=llm,
        prompts=prompts[:1],
        sampling_params=sampling_params,
    )

    print("------start generating------")
    test_prefix(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance with or without automatic '
        'prefix caching.')
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    args = parser.parse_args()
    main(args)
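
To compare the two modes, run the script once without the flag and once with it: python benchmarks/benchmark_prefix_caching.py, then python benchmarks/benchmark_prefix_caching.py --enable-prefix-caching. Because all 100 prompts are copies of the same PROMPT, the warm-up call can populate the prefix cache, so the timed run with caching enabled should mostly reuse already-computed KV blocks.
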
5 changes: 4 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -303,7 +303,10 @@ def main(args: argparse.Namespace):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument("--enable_prefix_caching", action='store_true')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
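
Note that argparse replaces dashes with underscores when building the parsed namespace, so the renamed --enable-prefix-caching flag is still read as args.enable_prefix_caching; the rename and the added help string are purely CLI-style fixes, and no call sites need to change.
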
15 changes: 4 additions & 11 deletions vllm/core/block_manager.py
@@ -236,13 +236,6 @@ def _is_last_block_full(
         token_ids_len = len(seq.data.get_token_ids())
         return token_ids_len > 0 and token_ids_len % seq.block_size == 0
 
-    def _is_last_block(
-        self,
-        seq: Sequence,
-        index: int,
-    ) -> bool:
-        return index == len(seq.logical_token_blocks) - 1
-
     def _maybe_promote_last_block(
         self,
         seq: Sequence,
@@ -436,7 +429,7 @@ def access_all_blocks_in_seq(
     def compute_last_full_block_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // seq.block_size - 1
+        max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
@@ -451,9 +444,9 @@ def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
                     return [b.block_number for b in block_table[:block_idx + 1]]
         return []
 
+    # Can return non-empty result only with prefix caching enabled.
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
-        # Can return non-empty result only with prefix caching enabled.
         if not self.enable_caching:
             return []
 
@@ -463,9 +456,9 @@ def get_common_computed_block_ids(self,
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
+    # We only mark the last full block because with prefix caching,
+    # all blocks until the marked one are guaranteed to be computed.
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
                 self.compute_last_full_block_in_seq(seq)
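
The return statement above relies on a commonprefix helper (in vLLM it is imported from os.path); since that function is duck-typed over sequences, it also yields the longest shared leading sublist of lists. Below is a toy sketch of how the common computed blocks fall out; this is not vLLM internals, and the block ids are made up:

from os.path import commonprefix

# Hypothetical per-sequence lists of physical block ids that are already
# computed, shaped like get_all_block_ids_till_computed results.
ids_list = [
    [0, 1, 2, 5],       # sequence A
    [0, 1, 2, 7, 8],    # sequence B
    [],                 # sequence C: nothing computed yet
]

# Empty lists are filtered out first, mirroring the diff above.
non_empty = [ids for ids in ids_list if ids != []]
print(commonprefix(non_empty))  # [0, 1, 2]: blocks every sequence can reuse
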
8 changes: 2 additions & 6 deletions vllm/sequence.py
@@ -160,10 +160,10 @@ def lora_int_id(self) -> int:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    # TODO The current hashing function is O(L^2). We should optimize this in
+    # the future.
     def hash_of_block(self, logical_idx: int) -> int:
         # Compute the number of tokens in the sequence
-        # TODO: The current hashing function is O(L^2). We should optimize
-        # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
         return hash(tuple(self.data.get_token_ids()[0:num_tokens]))

@@ -308,10 +308,6 @@ def prompt_token_ids(self) -> List[int]:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
-    @property
-    def block_size(self) -> int:
-        return next(iter(self.seqs_dict.values())).block_size
-
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
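
The TODO above flags hash_of_block as O(L^2): the hash for block i covers every token from the start of the sequence through the end of that block, so hashing all blocks of an L-token prompt with block size B touches B + 2B + ... + L, roughly L^2 / (2B) tokens. A standalone toy sketch of that cost shape; this is not vLLM code, and BLOCK_SIZE, token_ids, and the num_tokens formula are simplified stand-ins:

BLOCK_SIZE = 16                 # hypothetical block size B
token_ids = list(range(256))    # hypothetical prompt, L = 256 tokens


def hash_of_block(logical_idx: int) -> int:
    # Tokens covered by blocks 0..logical_idx (stand-in for
    # num_hashed_tokens_of_block).
    num_tokens = (logical_idx + 1) * BLOCK_SIZE
    # The full prefix is re-hashed for every block: the O(L^2) part.
    return hash(tuple(token_ids[0:num_tokens]))


block_hashes = [hash_of_block(i) for i in range(len(token_ids) // BLOCK_SIZE)]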
