32 changes: 16 additions & 16 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2828,7 +2828,7 @@ def _get_mm_dummy_batch(
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ def _dummy_run(
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+                - if not set, the cudagraph mode is determined by
+                  self.cudagraph_dispatcher.
                 - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                 - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                 - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ def _dummy_run(
             (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
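
Note: after this change the precondition accepts None in addition to the three explicit modes, with None meaning "defer to the dispatcher". A minimal standalone sketch of the accepted values; the CUDAGraphMode enum is stubbed locally for illustration, and check_mode is a hypothetical helper, not vLLM API:

    # Standalone sketch: which values the relaxed assert now accepts.
    # CUDAGraphMode is stubbed locally for illustration.
    from enum import Enum
    from typing import Optional

    class CUDAGraphMode(Enum):
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    def check_mode(mode: Optional[CUDAGraphMode]) -> None:
        # None now means "defer to the cudagraph dispatcher".
        assert mode is None or mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }

    check_mode(None)                # ok: dispatcher decides
    check_mode(CUDAGraphMode.NONE)  # ok: force eager, e.g. warm-up runs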

@@ -2899,10 +2901,6 @@ def _dummy_run(
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
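
Note: with the assert removed, this branch simply splits num_tokens into per-request chunks of max_query_len, letting the tail request be shorter. A self-contained restatement of that arithmetic; uniform_decode_split is a hypothetical helper, and cdiv is re-implemented here to mimic vLLM's ceiling-division utility:

    # Hypothetical helper mirroring the batch shaping above.
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as in vLLM's cdiv utility.
        return -(a // -b)

    def uniform_decode_split(num_tokens: int, max_query_len: int) -> list[int]:
        num_reqs = cdiv(num_tokens, max_query_len)
        tokens = [max_query_len] * num_reqs
        if num_tokens % max_query_len != 0:
            tokens[-1] = num_tokens % max_query_len  # shorter tail request
        return tokens

    # 10 tokens with max_query_len=4 -> three requests: [4, 4, 2]
    assert uniform_decode_split(10, 4) == [4, 4, 2]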
@@ -3043,18 +3041,20 @@ def _dummy_run(

         intermediate_tensors = self.sync_and_slice_intermediate_tensors(
             num_tokens, None, False)
-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (

+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
             num_tokens = num_tokens // 2

Review comment (maintainer), suggesting a simpler form of the new assert:

    assert cudagraph_runtime_mode in [CUDAGraphMode.NONE, _cg_mode], (
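
Note: the net effect of the rewrite is that the dispatcher always runs, and an explicitly requested mode must either match the dispatched one or be CUDAGraphMode.NONE (the escape hatch used for capture warm-up). A minimal sketch of that resolution rule, written with the reviewer's suggested in-form of the assert; CUDAGraphMode is stubbed and resolve_runtime_mode is a hypothetical stand-in for the logic inside _dummy_run, not the actual dispatcher:

    # Sketch of the mode-resolution rule introduced above; self-contained,
    # with the dispatcher reduced to an argument.
    from enum import Enum
    from typing import Optional

    class CUDAGraphMode(Enum):
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    def resolve_runtime_mode(requested: Optional[CUDAGraphMode],
                             dispatched: CUDAGraphMode) -> CUDAGraphMode:
        if requested is None:
            return dispatched  # no explicit request: trust the dispatcher
        # Forcing NONE is allowed even if the dispatcher disagrees, so
        # cudagraph capture can run eager warm-up passes first.
        assert requested in [CUDAGraphMode.NONE, dispatched]
        return requested

    assert resolve_runtime_mode(None, CUDAGraphMode.FULL) is CUDAGraphMode.FULL
    assert resolve_runtime_mode(CUDAGraphMode.NONE,
                                CUDAGraphMode.PIECEWISE) is CUDAGraphMode.NONE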