Commit 0fa9747

fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 8359f83 commit 0fa9747

4 files changed: +30 −19 lines changed

vllm/v1/core/hybrid_allocator.py

Lines changed: 3 additions & 2 deletions
@@ -166,8 +166,9 @@ def find_longest_cache_hit(
         block_hashes = block_hashes[0]
         if len(block_hashes) * self.block_size == num_tokens:
             block_hashes = block_hashes[:-1]
-        return self.allocator.find_longest_cache_hit(block_hashes,
-                                                     self.group_ids)
+        blocks, num_computed_tokens = self.allocator.find_longest_cache_hit(
+            block_hashes, self.group_ids)
+        return [blocks[0]], num_computed_tokens

     def remove_skipped_blocks(
         self,
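
For context, the fixed wrapper now unpacks a (blocks, num_computed_tokens) pair from the inner allocator and returns only the first group's block list. A minimal runnable sketch of that contract; DummyAllocator and KVCacheBlock below are illustrative stand-ins, not vLLM's real classes:

```python
# Sketch only: stand-in types, not vLLM's real allocator or block classes.
from dataclasses import dataclass


@dataclass
class KVCacheBlock:
    block_id: int


class DummyAllocator:
    block_size = 16  # made-up block size for illustration

    def find_longest_cache_hit(self, block_hashes, group_ids):
        # Pretend every hash is a cache hit: one block per hash,
        # for every group, plus the matching computed-token count.
        blocks = [[KVCacheBlock(i) for i in range(len(block_hashes))]
                  for _ in group_ids]
        return blocks, len(block_hashes) * self.block_size


allocator = DummyAllocator()
blocks, num_computed_tokens = allocator.find_longest_cache_hit(
    ["h0", "h1"], group_ids=[0])
# The hybrid wrapper keeps only the first group's blocks:
print([blocks[0]], num_computed_tokens)
# -> [[KVCacheBlock(block_id=0), KVCacheBlock(block_id=1)]] 32
```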

vllm/v1/core/kv_cache_manager.py

Lines changed: 10 additions & 7 deletions
@@ -52,6 +52,7 @@ def __init__(
             kv_cache_config=kv_cache_config,
             block_pool=self.block_pool,
         )
+        self.num_groups = len(kv_cache_config.kv_cache_groups)

     @property
     def usage(self) -> float:
@@ -104,6 +105,7 @@ def get_computed_blocks(

         self.prefix_cache_stats.queries += num_tokens
         self.prefix_cache_stats.hits += num_computed_tokens
+        print(f"computed_blocks: {computed_blocks}")
         return computed_blocks, num_computed_tokens

     def allocate_slots(
@@ -145,12 +147,13 @@ def allocate_slots(
         assert num_input_tokens + num_draft_tokens > 0

         if new_computed_blocks is None:
-            new_computed_blocks = []
+            new_computed_blocks = [[] for _ in range(self.num_groups)]
             assert new_computed_tokens == 0
-        else:
-            assert new_computed_tokens > 0

-        req_blocks = self.req_to_blocks[request.request_id]
+        req_blocks = self.req_to_blocks.get(request.request_id)
+        if req_blocks is None:
+            req_blocks = [[] for _ in range(self.num_groups)]
+            self.req_to_blocks[request.request_id] = req_blocks

         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
@@ -183,9 +186,9 @@ def allocate_slots(
             return None

         # Add the new computed blocks and new blocks to the request.
-        # FIXME
-        req_blocks.extend(new_computed_blocks)
-        req_blocks.extend(new_blocks)
+        for group_id in range(self.num_groups):
+            req_blocks[group_id].extend(new_computed_blocks[group_id])
+            req_blocks[group_id].extend(new_blocks[group_id])
         if not self.enable_caching:
             return new_blocks
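
The manager now tracks one block list per KV cache group. A toy model of the grouped bookkeeping this hunk implements; num_groups and the plain-int "blocks" are simplifications of vLLM's actual KVCacheBlock objects:

```python
# Toy model of the per-group bookkeeping in allocate_slots.
num_groups = 2
req_to_blocks: dict[str, list[list[int]]] = {}


def extend_request_blocks(request_id: str,
                          new_computed_blocks: list[list[int]],
                          new_blocks: list[list[int]]) -> None:
    # Lazily create one empty list per group, mirroring the
    # req_to_blocks.get(...) fallback added in this commit.
    req_blocks = req_to_blocks.get(request_id)
    if req_blocks is None:
        req_blocks = [[] for _ in range(num_groups)]
        req_to_blocks[request_id] = req_blocks
    # Extend each group's list with its own computed and new blocks.
    for group_id in range(num_groups):
        req_blocks[group_id].extend(new_computed_blocks[group_id])
        req_blocks[group_id].extend(new_blocks[group_id])


extend_request_blocks("req-0", [[1], [2]], [[3, 4], [5]])
print(req_to_blocks)  # {'req-0': [[1, 3, 4], [2, 5]]}
```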

vllm/v1/core/sched/scheduler.py

Lines changed: 16 additions & 9 deletions
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import itertools
 import time
 from collections import deque
 from collections.abc import Iterable
@@ -144,7 +145,7 @@ def schedule(self) -> SchedulerOutput:
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}

-        req_to_new_block_ids: dict[str, list[int]] = {}
+        req_to_new_block_ids: dict[str, list[list[int]]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
@@ -165,7 +166,8 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue

-            num_new_tokens = (request.num_tokens_with_spec -
+            num_draft_tokens = len(request.draft_token_ids)
+            num_new_tokens = (request.num_tokens + num_draft_tokens -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
@@ -196,7 +198,8 @@ def schedule(self) -> SchedulerOutput:
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens,
+                    num_new_tokens - num_draft_tokens,
+                    num_draft_tokens=num_draft_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
@@ -233,7 +236,7 @@ def schedule(self) -> SchedulerOutput:
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
             req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
+                [b.block_id for b in blocks] for blocks in new_blocks
             ]
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
@@ -330,7 +333,11 @@ def schedule(self) -> SchedulerOutput:
                 new_encoder_budget = encoder_budget

                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, num_computed_tokens, computed_blocks)
+                    request,
+                    num_new_tokens,
+                    new_computed_tokens=num_computed_tokens,
+                    new_computed_blocks=computed_blocks,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
@@ -355,9 +362,9 @@ def schedule(self) -> SchedulerOutput:

                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = [
-                    b.block_id for b in computed_blocks + new_blocks
-                ]
+                req_to_new_block_ids[request.request_id] = [[
+                    b.block_id for b in itertools.chain(b1, b2)
+                ] for b1, b2 in zip(computed_blocks, new_blocks)]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
@@ -463,7 +470,7 @@ def _make_cached_request_data(
         request: Request,
         num_scheduled_tokens: int,
         num_scheduled_spec_tokens: int,
-        new_block_ids: list[int],
+        new_block_ids: list[list[int]],
         resumed_from_preemption: bool,
     ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
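
For resumed requests, the per-group computed and new blocks are merged with zip plus itertools.chain, yielding one flat ID list per group. A small self-contained sketch of that construction, with a stand-in Block dataclass in place of vLLM's block type:

```python
# Sketch of the nested block-ID construction; Block is a stand-in
# for vLLM's KVCacheBlock, and the IDs are made up.
import itertools
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int


# One inner list per KV cache group, for both computed and new blocks.
computed_blocks = [[Block(0), Block(1)], [Block(10)]]
new_blocks = [[Block(2)], [Block(11), Block(12)]]

req_to_new_block_ids = [[
    b.block_id for b in itertools.chain(b1, b2)
] for b1, b2 in zip(computed_blocks, new_blocks)]
print(req_to_new_block_ids)  # [[0, 1, 2], [10, 11, 12]]
```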

vllm/v1/worker/gpu_input_batch.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class CachedRequestState:
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]

-    block_ids: list[int]
+    block_ids: list[list[int]]
     num_computed_tokens: int
     output_token_ids: list[int]

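The shape change ripples into the worker's cached request state: block_ids goes from one flat list to one list per KV cache group, so consumers must index by group first. An illustrative example with made-up block IDs:

```python
# Illustration of the block_ids shape change; IDs are made up.
old_block_ids: list[int] = [7, 8, 9]             # flat: one list for all blocks
new_block_ids: list[list[int]] = [[7, 8, 9],     # group 0 (e.g., full attention)
                                  [20, 21]]      # group 1 (e.g., sliding window)
for group_id, group_blocks in enumerate(new_block_ids):
    print(group_id, group_blocks)
```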
