
[Feature] Use max_num_seqs tokens with profile_run for decode #1110


Status: Draft · wants to merge 1 commit into main
10 changes: 8 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -36,6 +36,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.kv_transfer import has_kv_transfer_group
 from vllm.distributed.parallel_state import get_dp_group, get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -1195,8 +1196,13 @@ def profile_run(self) -> None:
         # maximum num_tokens.
         num_reqs = self.scheduler_config.max_num_seqs
         num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
+        # TODO: Decoding doesn't need `max_num_tokens` for profiling; just
+        # `max_num_seqs` is enough. However, for MTP, this might require
+        # adjustment.
+        if has_kv_transfer_group() and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer:
+            num_tokens = num_reqs
+
+        min_tokens_per_req = num_tokens // num_reqs
         num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
         num_scheduled_tokens_list[-1] += num_tokens % num_reqs
         assert sum(num_scheduled_tokens_list) == num_tokens
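The token-splitting logic in the hunk above can be sketched as a standalone function (names mirror the diff; this is an illustration, not vLLM code). It shows why setting `num_tokens = num_reqs` on a KV consumer yields a pure decode shape: every dummy request is scheduled exactly one token.

```python
def split_tokens(num_tokens: int, num_reqs: int) -> list[int]:
    """Evenly split num_tokens across num_reqs dummy requests,
    giving the remainder to the last request (mirrors profile_run)."""
    min_tokens_per_req = num_tokens // num_reqs
    scheduled = [min_tokens_per_req] * num_reqs
    scheduled[-1] += num_tokens % num_reqs
    assert sum(scheduled) == num_tokens
    return scheduled

# With this PR, a KV consumer profiles with num_tokens == num_reqs,
# so each dummy request gets exactly one token (decode-style):
print(split_tokens(8, 8))   # [1, 1, 1, 1, 1, 1, 1, 1]
# Without the change, num_tokens == max_num_tokens (prefill-style):
print(split_tokens(10, 4))  # [2, 2, 2, 4]
```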
@@ -1215,7 +1221,7 @@ def profile_run(self) -> None:
         ]

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens)
+        hidden_states = self._dummy_run(num_tokens)

         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
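The net effect of the PR is that a pure KV consumer (the decode node in a disaggregated-prefill setup) profiles memory with one token per sequence instead of the full `max_num_tokens` budget. A minimal sketch of that selection (`profile_num_tokens` is a hypothetical helper, not part of the PR):

```python
def profile_num_tokens(max_num_tokens: int, max_num_seqs: int,
                       is_kv_consumer: bool) -> int:
    """Pick the dummy-run token count for memory profiling.

    A KV consumer only ever runs decode steps, so profiling with
    max_num_seqs tokens (one per request) is sufficient; all other
    roles still profile with the full max_num_tokens budget.
    """
    return max_num_seqs if is_kv_consumer else max_num_tokens

print(profile_num_tokens(8192, 256, is_kv_consumer=True))   # 256
print(profile_num_tokens(8192, 256, is_kv_consumer=False))  # 8192
```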