
Commit 8e6a5e7

LucasWilkinson authored and yewentao256 committed
[BugFix] AssertionError: Do not capture num_reqs > max_num_reqs for uniform batch (#25505)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent faae7a7 commit 8e6a5e7

File tree: 1 file changed (+16, -16 lines)


vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 16 deletions
@@ -2828,7 +2828,7 @@ def _get_mm_dummy_batch(
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ def _dummy_run(
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+                - if not set will determine the cudagraph mode based on using
+                    the self.cudagraph_dispatcher.
                 - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                 - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                 - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ def _dummy_run(
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
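For context, a minimal sketch of the calling convention this change enables. The `runner` handle and the `CUDAGraphMode` import path are assumptions for illustration, not taken from this diff; only the `_dummy_run` parameters shown above are from the source.

from vllm.config import CUDAGraphMode  # import path assumed for illustration

# New default: pass no mode and let the cudagraph dispatcher decide.
runner._dummy_run(num_tokens=256, uniform_decode=True)

# Explicitly forcing CUDAGraphMode.NONE is still accepted, e.g. for an
# eager warm-up pass before cudagraph capture, even when the dispatcher
# would pick PIECEWISE or FULL for this shape.
runner._dummy_run(num_tokens=256,
                  cudagraph_runtime_mode=CUDAGraphMode.NONE,
                  uniform_decode=True)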

@@ -2899,10 +2901,6 @@ def _dummy_run(
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
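The removed assertion is the one named in the commit title. Illustrative arithmetic only (the numbers below are made up, not taken from the bug report): since num_reqs is derived from the dummy token count via cdiv, a uniform-decode capture size larger than max_num_reqs * max_query_len would have tripped it.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching how num_reqs is derived from num_tokens.
    return -(a // -b)

num_tokens = 512     # hypothetical cudagraph capture size
max_query_len = 1    # uniform decode: one token per request
max_num_reqs = 256   # hypothetical scheduler limit

num_reqs = cdiv(num_tokens, max_query_len)   # -> 512
print(num_reqs > max_num_reqs)               # True -> the old assertion fired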
@@ -3043,18 +3041,20 @@ def _dummy_run(

             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (
+
+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
             num_tokens = num_tokens // 2
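Condensed, the new hunk always dispatches on the batch shape and then reconciles the result with the caller's request. Below is a standalone sketch of that control flow, using a stand-in enum rather than vLLM's own types, so it can run in isolation.

from enum import Enum
from typing import Optional


class CUDAGraphMode(Enum):  # stand-in for vLLM's enum, for illustration only
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def resolve_runtime_mode(requested: Optional[CUDAGraphMode],
                         dispatched: CUDAGraphMode) -> CUDAGraphMode:
    # requested is None -> trust the dispatcher's decision.
    # requested is NONE -> allowed even if the dispatcher disagrees
    #                      (eager warm-up before cudagraph capture).
    # anything else     -> must match the dispatcher's decision.
    if requested is None:
        return dispatched
    assert requested == CUDAGraphMode.NONE or requested == dispatched, (
        f"Cudagraph runtime mode mismatch: expected {dispatched}, "
        f"got {requested}.")
    return requested


# A warm-up run forces NONE even though the dispatcher picked FULL.
print(resolve_runtime_mode(CUDAGraphMode.NONE, CUDAGraphMode.FULL))   # CUDAGraphMode.NONE
print(resolve_runtime_mode(None, CUDAGraphMode.PIECEWISE))            # CUDAGraphMode.PIECEWISE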
