
Commit 1c68e1c

ywang96 (Roger Wang) and knlnguyen1802 authored and committed
[Bugfix] Fix scheduling when repeated images in one request (vllm-project#23544)
Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent 4892eb6 commit 1c68e1c
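In short, as the diff below shows: the old try_allocate() combined the capacity check with the cache-counter update, so an image appearing more than once in a single request could be scheduled and budgeted more than once within one scheduling step. The fix renames the check to can_allocate(), which now also accounts for tokens already earmarked in the current step via a new num_tokens_to_schedule parameter; moves the counter updates into allocate(); and has the scheduler track mm_hashes_to_schedule so a repeated image is scheduled at most once per step.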

3 files changed: 96 additions, 39 deletions


tests/v1/core/test_encoder_cache_manager.py

Lines changed: 38 additions & 11 deletions
@@ -22,7 +22,7 @@ def test_basic_allocate_and_reuse():
     req = MockRequest("r1", ["imgA"], [4])
 
     assert not cache.check_and_update_cache(req, 0)
-    assert cache.try_allocate(req, 0, int(1e9))
+    assert cache.can_allocate(req, 0, int(1e9), 0)
 
     cache.allocate(req, 0)
 
@@ -44,7 +44,7 @@ def test_freeing_decreases_refcount_and_moves_to_freeable():
     manager = EncoderCacheManager(cache_size=10)
     req = MockRequest("req2", ["img3"], [5])
 
-    assert manager.try_allocate(req, 0, int(1e9))
+    assert manager.can_allocate(req, 0, int(1e9), 0)
     manager.allocate(req, 0)
 
     assert len(manager.cached["img3"]) == 1
@@ -60,10 +60,10 @@ def test_free_request_frees_all_inputs():
     manager = EncoderCacheManager(cache_size=10)
     req = MockRequest("req3", ["a", "b"], [2, 3])
 
-    assert manager.try_allocate(req, 0, int(1e9))
+    assert manager.can_allocate(req, 0, int(1e9), 0)
     manager.allocate(req, 0)
 
-    assert manager.try_allocate(req, 1, int(1e9))
+    assert manager.can_allocate(req, 1, int(1e9), 0)
     manager.allocate(req, 1)
 
     assert len(manager.cached["a"]) == 1
@@ -84,11 +84,11 @@ def test_eviction_when_cache_is_full():
     req1 = MockRequest("req1", ["x"], [6])
     req2 = MockRequest("req2", ["y"], [5])
 
-    assert manager.try_allocate(req1, 0, int(1e9))
+    assert manager.can_allocate(req1, 0, int(1e9), 0)
     manager.allocate(req1, 0)
     manager.free_encoder_input(req1, 0)
 
-    assert manager.try_allocate(req2, 0, int(1e9))
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
     manager.allocate(req2, 0)
 
     # 'x' should have been evicted.
@@ -100,10 +100,10 @@ def test_get_cached_input_ids():
     manager = EncoderCacheManager(cache_size=10)
     req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
 
-    assert manager.try_allocate(req, 0, int(1e9))
+    assert manager.can_allocate(req, 0, int(1e9), 0)
     manager.allocate(req, 0)
 
-    assert manager.try_allocate(req, 2, int(1e9))
+    assert manager.can_allocate(req, 2, int(1e9), 0)
     manager.allocate(req, 2)
 
     cached_ids = manager.get_cached_input_ids(req)
@@ -114,7 +114,7 @@ def test_has_cache_restores_from_freeable():
     manager = EncoderCacheManager(cache_size=10)
    req = MockRequest("reqY", ["imgZ"], [4])
 
-    assert manager.try_allocate(req, 0, int(1e9))
+    assert manager.can_allocate(req, 0, int(1e9), 0)
     manager.allocate(req, 0)
 
     manager.free_encoder_input(req, 0)
@@ -131,14 +131,41 @@ def test_get_freed_mm_hashes_clears_freed_list():
     req1 = MockRequest("reqA", ["a"], [5])
     req2 = MockRequest("reqB", ["b"], [6])
 
-    assert manager.try_allocate(req1, 0, int(1e9))
+    assert manager.can_allocate(req1, 0, int(1e9), 0)
     manager.allocate(req1, 0)
     manager.free_encoder_input(req1, 0)
 
     # Should trigger eviction of 'a'.
-    assert manager.try_allocate(req2, 0, int(1e9))
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
     manager.allocate(req2, 0)
 
     freed = manager.get_freed_mm_hashes()
     assert "a" in freed
     assert manager.get_freed_mm_hashes() == []
+
+
+def test_schedule_request_multi_images_respect_space_limit():
+    manager = EncoderCacheManager(cache_size=10)
+    req = MockRequest("reqA", ["a", "b"], [5, 6])
+    compute_budget = 100
+
+    num_tokens_to_schedule = 0
+    assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
+    num_tokens_to_schedule += req.get_num_encoder_tokens(0)
+    compute_budget -= req.get_num_encoder_tokens(0)
+
+    assert not manager.can_allocate(req, 1, compute_budget,
+                                    num_tokens_to_schedule)
+
+
+def test_schedule_request_multi_images_respect_compute_limit():
+    manager = EncoderCacheManager(cache_size=100)
+    req = MockRequest("reqA", ["a", "b"], [5, 6])
+    compute_budget = 10
+    num_tokens_to_schedule = 0
+    assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
+    num_tokens_to_schedule += req.get_num_encoder_tokens(0)
+    compute_budget -= req.get_num_encoder_tokens(0)
+
+    assert not manager.can_allocate(req, 1, compute_budget,
+                                    num_tokens_to_schedule)
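The two new tests pin down the two failure modes separately. For concreteness, the compute-limit case reduces to the arithmetic below; this is a standalone sketch with plain ints, not code from the repo (the space-limit case is worked through under can_allocate() further down):

# test_schedule_request_multi_images_respect_compute_limit, by the numbers:
# a 10-token compute budget admits the first 5-token image, leaving 5,
# which is not enough for the second 6-token image.
compute_budget = 10
sizes = [5, 6]
assert sizes[0] <= compute_budget        # image 0: 5 <= 10, schedulable
compute_budget -= sizes[0]               # 5 tokens of compute remain
assert not sizes[1] <= compute_budget    # image 1: 6 > 5, must wait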

vllm/v1/core/encoder_cache_manager.py

Lines changed: 20 additions & 12 deletions
@@ -99,8 +99,9 @@ def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         self.cached[mm_hash].add(request.request_id)
         return True
 
-    def try_allocate(self, request: Request, input_id: int,
-                     encoder_budget: int) -> bool:
+    def can_allocate(self, request: Request, input_id: int,
+                     encoder_compute_budget: int,
+                     num_tokens_to_schedule: int) -> bool:
         """Check if there's sufficient cache space for a multimodal input.
         If there is, return True and update EncoderCacheManager state.
@@ -116,6 +117,10 @@ def try_allocate(self, request: Request, input_id: int,
         Args:
             request: The request containing the multimodal input.
             input_id: Index of the multimodal input within the request.
+            encoder_compute_budget: Number of encoder tokens allowed to be
+                computed when this method is invoked.
+            num_tokens_to_schedule: Number of tokens already scheduled to be
+                allocated with cache space when this method is invoked.
 
         Returns:
             True if there's enough capacity to hold the encoder output for this
@@ -128,13 +133,13 @@ def try_allocate(self, request: Request, input_id: int,
         num_tokens = request.get_num_encoder_tokens(input_id)
 
         # Not enough compute budget
-        if num_tokens > encoder_budget:
+        if num_tokens > encoder_compute_budget:
             return False
 
+        num_tokens += num_tokens_to_schedule
+
         # Enough free slots
         if num_tokens <= self.num_free_slots:
-            self.num_free_slots -= num_tokens
-            self.num_freeable_slots -= num_tokens
             return True
 
         # Not enough reclaimable slots
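The added num_tokens += num_tokens_to_schedule line is the core of the space fix: can_allocate() no longer debits the slot counters itself, so tokens earmarked for earlier inputs of the same scheduling step must be folded into the comparison against num_free_slots. The space-limit test's numbers, as a standalone sketch (plain ints, not repo code):

# cache_size=10; a 5-token image is already earmarked this step, so the
# 6-token image must compare 6 + 5 = 11 against the 10 free slots.
num_free_slots = 10
num_tokens_to_schedule = 5    # earmarked earlier in the step, not yet committed
num_tokens = 6                # the input being checked now
num_tokens += num_tokens_to_schedule
assert not num_tokens <= num_free_slots   # 11 > 10 -> can_allocate returns False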
@@ -149,8 +154,6 @@ def try_allocate(self, request: Request, input_id: int,
             del self.cached[mm_hash]
             self.freed.append(mm_hash)
             self.num_free_slots += num_free_token
-        self.num_free_slots -= num_tokens
-        self.num_freeable_slots -= num_tokens
         return True
 
     def allocate(self, request: Request, input_id: int) -> None:
@@ -161,19 +164,24 @@ def allocate(self, request: Request, input_id: int) -> None:
         the model runner; this method updates the manager's bookkeeping.
 
         Note:
-            This method assumes try_allocate() returned True for the same input.
+            This method assumes can_allocate() returned True for the same input.
         """
-        # Encoder cache space budget should be already updated for the
-        # multimodal input and non-negative after try_allocate() is called.
-        assert self.num_free_slots >= 0
-        assert self.num_freeable_slots >= 0
 
         mm_hash = request.mm_hashes[input_id]
         request_id = request.request_id
         if mm_hash not in self.cached:
             self.cached[mm_hash] = set()
 
+        num_encoder_tokens = request.get_num_encoder_tokens(input_id)
+
+        # NOTE: Encoder cache should always have enough space for encoder inputs
+        # that are scheduled since eviction takes place at can_allocate().
+        assert self.num_free_slots >= num_encoder_tokens
+        assert self.num_freeable_slots >= num_encoder_tokens
+
         self.cached[mm_hash].add(request_id)
+        self.num_free_slots -= num_encoder_tokens
+        self.num_freeable_slots -= num_encoder_tokens
 
     def get_cached_input_ids(self, request: Request) -> set[int]:
         """Get all cached multimodal input IDs for a request.
vllm/v1/core/sched/scheduler.py

Lines changed: 38 additions & 16 deletions
@@ -182,7 +182,7 @@ def schedule(self) -> SchedulerOutput:
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
         scheduled_encoder_inputs: dict[str, list[int]] = {}
-        encoder_budget = self.max_num_encoder_input_tokens
+        encoder_compute_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
 
@@ -211,12 +211,13 @@ def schedule(self) -> SchedulerOutput:
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
-            new_encoder_budget = encoder_budget
+            new_encoder_compute_budget = encoder_compute_budget
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
-                 new_encoder_budget) = self._try_schedule_encoder_inputs(
+                 new_encoder_compute_budget
+                 ) = self._try_schedule_encoder_inputs(
                     request, request.num_computed_tokens, num_new_tokens,
-                    encoder_budget)
+                    encoder_compute_budget)
 
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
@@ -298,7 +299,7 @@ def schedule(self) -> SchedulerOutput:
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
-                encoder_budget = new_encoder_budget
+                encoder_compute_budget = new_encoder_compute_budget
 
         # Record the LoRAs in scheduled_running_reqs
         scheduled_loras: set[int] = set()
@@ -382,7 +383,7 @@ def schedule(self) -> SchedulerOutput:
                 num_computed_tokens = request.num_computed_tokens
 
                 encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
+                new_encoder_compute_budget = encoder_compute_budget
 
                 # KVTransfer: loading remote KV, do not allocate for new work.
                 if load_kv_async:
@@ -413,10 +414,10 @@ def schedule(self) -> SchedulerOutput:
                 # Schedule encoder inputs.
                 if request.has_encoder_inputs:
                     (encoder_inputs_to_schedule, num_new_tokens,
-                     new_encoder_budget
+                     new_encoder_compute_budget
                      ) = self._try_schedule_encoder_inputs(
                          request, num_computed_tokens, num_new_tokens,
-                         encoder_budget)
+                         encoder_compute_budget)
                     if num_new_tokens == 0:
                         # The request cannot be scheduled.
                         break
@@ -495,7 +496,7 @@ def schedule(self) -> SchedulerOutput:
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
-                encoder_budget = new_encoder_budget
+                encoder_compute_budget = new_encoder_compute_budget
 
         # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
@@ -658,7 +659,7 @@ def _try_schedule_encoder_inputs(
         request: Request,
         num_computed_tokens: int,
         num_new_tokens: int,
-        encoder_budget: int,
+        encoder_compute_budget: int,
     ) -> tuple[list[int], int, int]:
         """
         Determine which encoder inputs need to be scheduled in the current step,
@@ -680,11 +681,17 @@ def _try_schedule_encoder_inputs(
             blocks and externally cached blocks (via KVConnector).
         """
         if num_new_tokens == 0 or not request.has_encoder_inputs:
-            return [], num_new_tokens, encoder_budget
+            return [], num_new_tokens, encoder_compute_budget
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
         assert mm_positions is not None
         assert len(mm_positions) > 0
+
+        # NOTE: since scheduler operates on the request level (possibly with
+        # multiple encoder inputs per request), we need to create temporary
+        # trackers for accounting at the encoder input level.
+        mm_hashes_to_schedule = set()
+        num_tokens_to_schedule = 0
         for i, pos_info in enumerate(mm_positions):
             start_pos = pos_info.offset
             num_encoder_tokens = pos_info.length
@@ -695,13 +702,20 @@ def _try_schedule_encoder_inputs(
             if start_pos >= num_computed_tokens + num_new_tokens:
                 # The encoder input is not needed in this step.
                 break
+
             if start_pos + num_encoder_tokens <= num_computed_tokens:
                 # The encoder input is already computed and stored
                 # in the decoder's KV cache.
                 continue
 
+            # The same encoder input has already been scheduled in the current
+            # step.
+            if request.mm_hashes[i] in mm_hashes_to_schedule:
+                continue
+
             if self.encoder_cache_manager.check_and_update_cache(request, i):
-                # The encoder input is already computed and cached.
+                # The encoder input is already computed and cached from a
+                # previous step.
                 continue
 
             # If no encoder input chunking is allowed, we do not want to
@@ -714,8 +728,9 @@ def _try_schedule_encoder_inputs(
                 num_new_tokens = start_pos - num_computed_tokens
                 break
 
-            if not self.encoder_cache_manager.try_allocate(
-                    request, i, encoder_budget):
+            if not self.encoder_cache_manager.can_allocate(
+                    request, i, encoder_compute_budget,
+                    num_tokens_to_schedule):
                 # The encoder cache is full or the encoder budget is exhausted.
                 # NOTE(woosuk): We assume that the encoder input tokens should
                 # be processed altogether, as the encoder usually uses
@@ -732,9 +747,16 @@ def _try_schedule_encoder_inputs(
                 num_new_tokens = 0
                 break
 
-            encoder_budget -= num_encoder_tokens
+            num_tokens_to_schedule += num_encoder_tokens
+            encoder_compute_budget -= num_encoder_tokens
+            mm_hashes_to_schedule.add(request.mm_hashes[i])
             encoder_inputs_to_schedule.append(i)
-        return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
+
+        return (
+            encoder_inputs_to_schedule,
+            num_new_tokens,
+            encoder_compute_budget,
+        )
 
     def get_grammar_bitmask(
         self,
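To see the three mechanisms interact (per-step dedup, compute budget, space earmarking), here is a self-contained distillation of the loop above; it uses plain data and mirrors the diff's variable names, but it is a sketch, not the scheduler itself:

# One request with a repeated 4-token image and a 6-token image.
mm_hashes = ["imgA", "imgA", "imgB"]
num_encoder_tokens = {"imgA": 4, "imgB": 6}
encoder_compute_budget = 8
num_free_slots = 16

mm_hashes_to_schedule: set[str] = set()
num_tokens_to_schedule = 0
encoder_inputs_to_schedule: list[int] = []
for i, mm_hash in enumerate(mm_hashes):
    if mm_hash in mm_hashes_to_schedule:
        continue                                  # repeat within this step
    n = num_encoder_tokens[mm_hash]
    if n > encoder_compute_budget:
        break                                     # compute budget exhausted
    if n + num_tokens_to_schedule > num_free_slots:
        break                                     # no cache space left
    num_tokens_to_schedule += n
    encoder_compute_budget -= n
    mm_hashes_to_schedule.add(mm_hash)
    encoder_inputs_to_schedule.append(i)

# imgA is scheduled once (4 tokens) and its repeat is skipped; imgB needs 6
# tokens but only 4 remain in the compute budget, so it waits for a later step.
assert encoder_inputs_to_schedule == [0]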
