|
43 | 43 | ForwardMode, |
44 | 44 | Req, |
45 | 45 | ) |
| 46 | +from sglang.srt.mem_cache.chunk_cache import ChunkCache |
46 | 47 | from sglang.srt.mem_cache.radix_cache import RadixCache |
47 | 48 | from sglang.srt.model_config import ModelConfig |
48 | 49 | from sglang.srt.model_executor.model_runner import ModelRunner |
@@ -144,11 +145,20 @@ def __init__( |
144 | 145 | ) |
145 | 146 |
|
146 | 147 | # Init cache |
147 | | - self.tree_cache = RadixCache( |
148 | | - req_to_token_pool=self.model_runner.req_to_token_pool, |
149 | | - token_to_kv_pool=self.model_runner.token_to_kv_pool, |
150 | | - disable=server_args.disable_radix_cache, |
151 | | - ) |
| 148 | + if ( |
| 149 | + server_args.chunked_prefill_size is not None |
| 150 | + and server_args.disable_radix_cache |
| 151 | + ): |
| 152 | + self.tree_cache = ChunkCache( |
| 153 | + req_to_token_pool=self.model_runner.req_to_token_pool, |
| 154 | + token_to_kv_pool=self.model_runner.token_to_kv_pool, |
| 155 | + ) |
| 156 | + else: |
| 157 | + self.tree_cache = RadixCache( |
| 158 | + req_to_token_pool=self.model_runner.req_to_token_pool, |
| 159 | + token_to_kv_pool=self.model_runner.token_to_kv_pool, |
| 160 | + disable=server_args.disable_radix_cache, |
| 161 | + ) |
152 | 162 | self.tree_cache_metrics = {"total": 0, "hit": 0} |
153 | 163 | self.scheduler = PolicyScheduler( |
154 | 164 | self.schedule_policy, |
@@ -354,7 +364,10 @@ def get_new_prefill_batch(self) -> Optional[Batch]: |
354 | 364 | # Compute matched prefix length |
355 | 365 | for req in self.waiting_queue: |
356 | 366 | req.input_ids = req.origin_input_ids + req.output_ids |
357 | | - prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) |
| 367 | + prefix_indices, last_node = self.tree_cache.match_prefix( |
| 368 | + rid=req.rid, |
| 369 | + key=req.input_ids, |
| 370 | + ) |
358 | 371 | if req.return_logprob: |
359 | 372 | prefix_indices = prefix_indices[: req.logprob_start_len] |
360 | 373 | req.extend_input_len = len(req.input_ids) - len(prefix_indices) |
@@ -614,6 +627,7 @@ def cache_filled_batch(self, batch: Batch): |
614 | 627 | req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() |
615 | 628 | for i, req in enumerate(batch.reqs): |
616 | 629 | new_prefix_indices, new_last_node = self.tree_cache.cache_req( |
| 630 | + rid=req.rid, |
617 | 631 | token_ids=tuple(req.input_ids), |
618 | 632 | last_uncached_pos=len(req.prefix_indices), |
619 | 633 | req_pool_idx=req_pool_indices_cpu[i], |
@@ -771,6 +785,7 @@ def handle_finished_requests(self, batch: Batch): |
771 | 785 | for i in finished_indices: |
772 | 786 | req = batch.reqs[i] |
773 | 787 | self.tree_cache.cache_req( |
| 788 | + rid=req.rid, |
774 | 789 | token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], |
775 | 790 | last_uncached_pos=len(req.prefix_indices), |
776 | 791 | req_pool_idx=req_pool_indices_cpu[i], |
|
0 commit comments