Skip to content

Commit 353fbb7

Browse files
authored
Support chunked prefill when radix cache is disabled (sgl-project#811)
1 parent 84cf6c0 commit 353fbb7

File tree

9 files changed

+164
-27
lines changed

python/sglang/srt/constrained/base_cache.py renamed to python/sglang/srt/constrained/base_tool_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
limitations under the License.
1414
"""
1515

16-
"""Base cache class."""
16+
"""Base tool cache for constrained decoding tools."""
1717

1818
import time
1919

2020

21-
class BaseCache:
21+
class BaseToolCache:
2222
def __init__(self, enable=True):
2323
self.enable = enable
2424
self.reset()

python/sglang/srt/constrained/fsm_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
"""Cache for the compressed finite state machine."""
1717

1818
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
19-
from sglang.srt.constrained.base_cache import BaseCache
19+
from sglang.srt.constrained.base_tool_cache import BaseToolCache
2020

2121

22-
class FSMCache(BaseCache):
22+
class FSMCache(BaseToolCache):
2323
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
2424
super().__init__(enable=enable)
2525

python/sglang/srt/constrained/jump_forward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
make_byte_level_fsm,
3131
make_deterministic_fsm,
3232
)
33-
from sglang.srt.constrained.base_cache import BaseCache
33+
from sglang.srt.constrained.base_tool_cache import BaseToolCache
3434

3535
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
3636

@@ -151,7 +151,7 @@ def is_jump_forward_symbol_state(self, state):
151151
)
152152

153153

154-
class JumpForwardCache(BaseCache):
154+
class JumpForwardCache(BaseToolCache):
155155
def __init__(self):
156156
super().__init__()
157157

python/sglang/srt/managers/schedule_batch.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sglang.global_config import global_config
2929
from sglang.srt.constrained import RegexGuide
3030
from sglang.srt.constrained.jump_forward import JumpForwardMap
31+
from sglang.srt.mem_cache.chunk_cache import ChunkCache
3132
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
3233
from sglang.srt.mem_cache.radix_cache import RadixCache
3334

@@ -486,15 +487,33 @@ def retract_decode(self):
486487
req = self.reqs[idx]
487488
retracted_reqs.append(req)
488489

489-
# TODO: apply more fine-grained retraction
490-
last_uncached_pos = len(req.prefix_indices)
491-
token_indices = self.req_to_token_pool.req_to_token[
492-
req_pool_indices_cpu[idx]
493-
][last_uncached_pos : seq_lens_cpu[idx]]
494-
self.token_to_kv_pool.free(token_indices)
495-
496-
# release the last node
497-
self.tree_cache.dec_lock_ref(req.last_node)
490+
if isinstance(self.tree_cache, ChunkCache):
491+
# ChunkCache does not have eviction
492+
token_indices = self.req_to_token_pool.req_to_token[
493+
req_pool_indices_cpu[idx]
494+
][: seq_lens_cpu[idx]]
495+
self.token_to_kv_pool.free(token_indices)
496+
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
497+
del self.tree_cache.entries[req.rid]
498+
else:
499+
# TODO: apply more fine-grained retraction
500+
last_uncached_pos = len(req.prefix_indices)
501+
token_indices = self.req_to_token_pool.req_to_token[
502+
req_pool_indices_cpu[idx]
503+
][last_uncached_pos : seq_lens_cpu[idx]]
504+
self.token_to_kv_pool.free(token_indices)
505+
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
506+
507+
# release the last node
508+
self.tree_cache.dec_lock_ref(req.last_node)
509+
510+
# NOTE(lsyin): we should use the newly evictable memory instantly.
511+
residual_size = (
512+
len(sorted_indices) * global_config.retract_decode_steps
513+
- self.token_to_kv_pool.available_size()
514+
)
515+
residual_size = max(0, residual_size)
516+
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
498517

499518
req.prefix_indices = None
500519
req.last_node = None
@@ -575,6 +594,7 @@ def check_for_jump_forward(self, model_runner):
575594
if req_pool_indices_cpu is None:
576595
req_pool_indices_cpu = self.req_pool_indices.tolist()
577596
self.tree_cache.cache_req(
597+
rid=req.rid,
578598
token_ids=cur_all_ids,
579599
last_uncached_pos=len(req.prefix_indices),
580600
req_pool_idx=req_pool_indices_cpu[i],

python/sglang/srt/managers/tp_worker.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ForwardMode,
4444
Req,
4545
)
46+
from sglang.srt.mem_cache.chunk_cache import ChunkCache
4647
from sglang.srt.mem_cache.radix_cache import RadixCache
4748
from sglang.srt.model_config import ModelConfig
4849
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -144,11 +145,20 @@ def __init__(
144145
)
145146

146147
# Init cache
147-
self.tree_cache = RadixCache(
148-
req_to_token_pool=self.model_runner.req_to_token_pool,
149-
token_to_kv_pool=self.model_runner.token_to_kv_pool,
150-
disable=server_args.disable_radix_cache,
151-
)
148+
if (
149+
server_args.chunked_prefill_size is not None
150+
and server_args.disable_radix_cache
151+
):
152+
self.tree_cache = ChunkCache(
153+
req_to_token_pool=self.model_runner.req_to_token_pool,
154+
token_to_kv_pool=self.model_runner.token_to_kv_pool,
155+
)
156+
else:
157+
self.tree_cache = RadixCache(
158+
req_to_token_pool=self.model_runner.req_to_token_pool,
159+
token_to_kv_pool=self.model_runner.token_to_kv_pool,
160+
disable=server_args.disable_radix_cache,
161+
)
152162
self.tree_cache_metrics = {"total": 0, "hit": 0}
153163
self.scheduler = PolicyScheduler(
154164
self.schedule_policy,
@@ -354,7 +364,10 @@ def get_new_prefill_batch(self) -> Optional[Batch]:
354364
# Compute matched prefix length
355365
for req in self.waiting_queue:
356366
req.input_ids = req.origin_input_ids + req.output_ids
357-
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
367+
prefix_indices, last_node = self.tree_cache.match_prefix(
368+
rid=req.rid,
369+
key=req.input_ids,
370+
)
358371
if req.return_logprob:
359372
prefix_indices = prefix_indices[: req.logprob_start_len]
360373
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -614,6 +627,7 @@ def cache_filled_batch(self, batch: Batch):
614627
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
615628
for i, req in enumerate(batch.reqs):
616629
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
630+
rid=req.rid,
617631
token_ids=tuple(req.input_ids),
618632
last_uncached_pos=len(req.prefix_indices),
619633
req_pool_idx=req_pool_indices_cpu[i],
@@ -771,6 +785,7 @@ def handle_finished_requests(self, batch: Batch):
771785
for i in finished_indices:
772786
req = batch.reqs[i]
773787
self.tree_cache.cache_req(
788+
rid=req.rid,
774789
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
775790
last_uncached_pos=len(req.prefix_indices),
776791
req_pool_idx=req_pool_indices_cpu[i],
python/sglang/srt/mem_cache/base_cache.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class BasePrefixCache(ABC):
    """Abstract interface for prefix caches; lookups may be keyed by rid or by key."""

    @abstractmethod
    def reset(self):
        """Clear the cache state."""

    @abstractmethod
    def match_prefix(self, **kwargs):
        """Look up the cached prefix for a request (by ``rid`` and/or ``key``)."""

    @abstractmethod
    def insert(self, **kwargs):
        """Insert an entry into the cache."""

    @abstractmethod
    def cache_req(self, **kwargs):
        """Record (or release) the cache state associated with a request."""

    @abstractmethod
    def evict(self, num_tokens, evict_callback):
        """Evict up to ``num_tokens`` tokens, handing them to ``evict_callback``."""

    @abstractmethod
    def inc_lock_ref(self, node):
        """Increase the lock reference count of ``node``."""

    @abstractmethod
    def dec_lock_ref(self, node):
        """Decrease the lock reference count of ``node``."""

    @abstractmethod
    def evictable_size(self):
        """Return the amount of cache that can currently be evicted."""

    # The two methods below are optional: they are not abstract, so subclasses
    # may leave them out, but calling them without an override raises.
    def total_size(self):
        raise NotImplementedError()

    def pretty_print(self):
        raise NotImplementedError()
python/sglang/srt/mem_cache/chunk_cache.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Cache for chunked prefill, used when RadixCache is disabled."""
2+
3+
from sglang.srt.mem_cache.base_cache import BasePrefixCache
4+
5+
6+
class ChunkCacheEntry:
    """Record pairing a request id with the value cached for that request."""

    def __init__(self, rid, value):
        self.rid, self.value = rid, value
10+
11+
12+
class ChunkCache(BasePrefixCache):
    """Per-request cache for chunked prefill, used when RadixCache is disabled.

    Entries are keyed by request id (``rid``) instead of by token prefix. Each
    entry stores the ``req_to_token`` slice recorded by the most recent
    ``cache_req(..., del_in_memory_pool=False)`` call for that request. Nothing
    is shared between requests and nothing is evictable.
    """

    def __init__(self, req_to_token_pool, token_to_kv_pool):
        # Always True for this cache; RadixCache exposes a ``disable`` flag too.
        self.disable = True
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool

        self.reset()

    def reset(self):
        """Drop every cached entry."""
        self.entries = {}

    def match_prefix(self, rid, **kwargs):
        """Return ``(indices, entry)`` for ``rid``, or ``([], None)`` if unknown.

        Extra keyword arguments (e.g. ``key``) are accepted for interface
        compatibility and ignored.
        """
        entry = self.entries.get(rid)
        if entry is None:
            return [], None
        return entry.value, entry

    def cache_req(
        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
    ):
        """Record or release the cached indices of request ``rid``.

        With ``del_in_memory_pool=True`` the request's slot in
        ``req_to_token_pool`` and its indices in ``token_to_kv_pool`` are freed,
        the entry is forgotten, and ``None`` is returned. Otherwise the entry
        is created or refreshed with the first ``len(token_ids)`` indices of
        the request and ``(indices, entry)`` is returned.
        """
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
        if del_in_memory_pool:
            assert rid in self.entries
            self.req_to_token_pool.free(req_pool_idx)
            self.token_to_kv_pool.free(indices)
            # Forget the entry together with its memory. Leaving it behind
            # would grow ``entries`` without bound and let a later
            # ``match_prefix`` for the same rid hand out freed indices.
            # ``pop`` with a default keeps this safe when asserts are stripped.
            self.entries.pop(rid, None)
            return

        entry = self.entries.get(rid)
        if entry is None:
            entry = self.entries[rid] = ChunkCacheEntry(rid, indices)
        else:
            entry.value = indices
        return indices, entry

    def insert(self, **kwargs):
        """Not supported; signature matches the abstract base method."""
        raise NotImplementedError

    def evict(self, num_tokens, evict_callback):
        """No-op: this cache holds nothing evictable."""
        pass

    def inc_lock_ref(self, node):
        """No-op: entries are not reference-locked; always returns 0."""
        return 0

    def dec_lock_ref(self, node):
        """No-op: entries are not reference-locked; always returns 0."""
        return 0

    def evictable_size(self):
        """Always 0, since nothing can be evicted."""
        return 0

python/sglang/srt/mem_cache/radix_cache.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import torch
2525

26+
from sglang.srt.mem_cache.base_cache import BasePrefixCache
27+
2628

2729
class TreeNode:
2830
def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
4648
return i
4749

4850

49-
class RadixCache:
51+
class RadixCache(BasePrefixCache):
5052
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
5153
self.req_to_token_pool = req_to_token_pool
5254
self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ def reset(self):
6264
self.root_node.lock_ref = 1
6365
self.evictable_size_ = 0
6466

65-
def match_prefix(self, key):
67+
def match_prefix(self, key, **kwargs):
6668
if self.disable:
6769
return [], self.root_node
6870

@@ -90,6 +92,7 @@ def cache_req(
9092
req_pool_idx,
9193
del_in_memory_pool=True,
9294
old_last_node=None,
95+
**kwargs,
9396
):
9497
# Insert the request into radix cache
9598
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]

python/sglang/srt/server_args.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,6 @@ def check_server_args(self):
419419
self.dp_size > 1 and self.node_rank is not None
420420
), "multi-node data parallel is not supported"
421421

422-
assert not (
423-
self.chunked_prefill_size is not None and self.disable_radix_cache
424-
), "chunked prefill is not supported with radix cache disabled currently"
425-
426422

427423
@dataclasses.dataclass
428424
class PortArgs:

0 commit comments

Comments
 (0)