
[Misc] Clean up useless code for LLM initialize #1373

Merged
merged 1 commit on Jun 25, 2025
161 changes: 58 additions & 103 deletions vllm_ascend/worker/model_runner_v1.py
@@ -49,8 +49,7 @@
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, LazyLoader, cdiv)
from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
@@ -137,82 +136,69 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.block_size = vllm_config.cache_config.block_size
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
self.dtype = self.model_config.dtype
self.sampler = Sampler()
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
mm_registry=self.mm_registry)

# Lazy initialization, these will be set after __init__
self.kv_caches: List[torch.Tensor] = []
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
self.attn_mask = None
self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None

ascend_config = get_ascend_config()
if ascend_config.ascend_scheduler_config.enabled:
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
else:
self.chunked_prefill_enabled = True
self.device = device

self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size

self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.model_config.get_hidden_size()),
dtype=self.dtype,
device=self.device)

self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(self.max_num_reqs,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)

# Model-related.
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
vllm_config.parallel_config, LayerBlockType.attention)
self.hidden_size = self.model_config.get_hidden_size()
self.dtype = self.model_config.dtype
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]

self.head_size = self.model_config.get_head_size()
# Set up Attention
self.attn_backend = get_attn_backend(
self.head_size,
0,
self.dtype,
self.kv_cache_dtype,
None,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 NPUModelRunner.")

self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))

# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = self.model_config.uses_mrope

self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
mm_registry=self.mm_registry)

# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
self.use_spec_decode = False
self.spec_attn_mask = None
self.use_eagle = False
self.drafter = None
if self.speculative_config:
self.use_spec_decode = True
self.spec_attn_mask = torch.triu(torch.ones(2048,
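
The consolidated `__init__` keeps a single definition for each size-related field; `max_num_blocks_per_req` is a ceiling division of the model's maximum length by the KV-cache block size, and the same rounded-up count gives `graph_block_tables` its second dimension. A minimal sketch of that arithmetic, using a helper equivalent in effect to the `cdiv` imported from vllm.utils and illustrative numbers rather than values from any real config:

import numpy as np

def cdiv(a: int, b: int) -> int:
    # Ceiling division, equivalent in effect to vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 4000   # hypothetical model limit
block_size = 128       # hypothetical KV-cache block size
max_num_reqs = 8       # hypothetical scheduler cap

max_num_blocks_per_req = cdiv(max_model_len, block_size)   # 32, rounded up from 31.25
graph_block_tables = np.zeros((max_num_reqs, max_num_blocks_per_req),
                              dtype=np.int32)
assert graph_block_tables.shape == (8, 32)
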
@@ -235,10 +221,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
f"{self.speculative_config.method}")
self.rejection_sampler = AscendRejectionSampler()

# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.

self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
@@ -251,9 +233,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None

self.uses_mrope = self.model_config.uses_mrope
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
@@ -275,12 +256,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device="cpu",
pin_memory=True)

if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)

# OPTIMIZATION: Cache the tensors rather than creating them every step.
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
self.max_num_reqs + 1, self.model_config.max_model_len,
@@ -304,24 +279,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device="cpu",
pin_memory=True)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()

self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()

self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()

self.input_positions_cpu = torch.arange(0,
self.max_num_tokens,
device="cpu")
self.attn_mask = None
self.attn_state = None
self.use_aclgraph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
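
The persistent host-side buffers (`slot_mapping_cpu`, `query_start_loc_cpu`, `seq_lens_cpu`, ...) follow one pattern: allocate a CPU tensor once, keep a NumPy view of the same storage, fill the view with NumPy ops each step, and copy to the device when needed. A generic sketch of that idea; the size is a placeholder, and plain pageable memory is used here so the snippet runs anywhere, whereas the runner pins these buffers (pin_memory=True) for faster host-to-device copies:

import numpy as np
import torch

max_num_reqs = 8  # placeholder size

# Allocated once; .numpy() returns a view that shares the tensor's storage.
seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32, device="cpu")
seq_lens_np = seq_lens_cpu.numpy()

# Per-step bookkeeping writes through the NumPy view, with no reallocation.
seq_lens_np[:3] = np.array([17, 5, 42], dtype=np.int32)

# The torch tensor observes the same values.
assert seq_lens_cpu[:3].tolist() == [17, 5, 42]
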
@@ -338,38 +306,27 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
# Therefore, an environment variable is added here to dynamically set
# the size of the pre-constructed mask matrix based on requirements.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
self.attn_mask_len = min(self.model_config.max_model_len,
int(mask_len))
attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
self.attn_mask_len, self.dtype)

self.sampler = Sampler()
attn_mask_len, self.dtype)

self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes

if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()

if len(self.torchair_graph_batch_sizes) == 0:
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
self.torchair_graph_batch_sizes = [
self.scheduler_config.max_num_seqs
]
self.torchair_graph_batch_sizes = [self.max_num_reqs]

torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
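
The size of the pre-built attention mask in `__init__` comes from the `PAGED_ATTENTION_MASK_LEN` environment variable, clamped to the model's maximum length; after this change the clamped value is a local passed straight to `AttentionMaskBuilder.initialize_from_len` rather than being stored on `self`. A rough sketch of just the clamping step (the default of 10000 is the one in the diff, the model limit is illustrative):

import os

max_model_len = 8192  # hypothetical model limit

# os.getenv returns a string when the variable is set, else the int default;
# int() normalizes both cases.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
attn_mask_len = min(max_model_len, int(mask_len))

print(attn_mask_len)  # 8192 unless the env var asks for something smaller
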
@@ -1692,8 +1649,7 @@ def _dummy_run(
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
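
With `max_num_reqs` now read from the attribute set in `__init__`, the dummy-run request count is simply the token budget clamped to the request cap; the conditional is equivalent to a `min`. A small standalone check:

def dummy_num_reqs(num_tokens: int, max_num_reqs: int) -> int:
    # Same result as: max_num_reqs if num_tokens >= max_num_reqs else num_tokens
    return min(num_tokens, max_num_reqs)

assert dummy_num_reqs(7, 256) == 7       # fewer tokens than request slots
assert dummy_num_reqs(1024, 256) == 256  # clamped to the request cap
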
@@ -1795,14 +1751,13 @@ def profile_run(self) -> None:

# For the profile run, use the maximum num_reqs, which collectively carry
# the maximum num_tokens.
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req = num_tokens // num_reqs
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs

num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
num_scheduled_tokens_list[
-1] += self.max_num_tokens % self.max_num_reqs
assert sum(num_scheduled_tokens_list) == self.max_num_tokens
assert len(num_scheduled_tokens_list) == self.max_num_reqs

num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
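
For the profile run the whole token budget is spread evenly over `max_num_reqs`, with the remainder folded into the last request, which is why the asserts on the sum and the length always hold. Worked through with hypothetical numbers:

max_num_tokens = 1030  # hypothetical token budget
max_num_reqs = 4       # hypothetical request cap

min_tokens_per_req = max_num_tokens // max_num_reqs   # 257
counts = [min_tokens_per_req] * max_num_reqs          # [257, 257, 257, 257]
counts[-1] += max_num_tokens % max_num_reqs           # last request takes the remainder (2)

assert counts == [257, 257, 257, 259]
assert sum(counts) == max_num_tokens
assert len(counts) == max_num_reqs
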
@@ -1830,15 +1785,14 @@ def load_model(self) -> None:

with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if hasattr(self, "drafter"):
if self.drafter:
logger.info("Loading drafter model...")
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
self.drafter.load_model()
if self.use_aux_hidden_state_outputs:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
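
Because `drafter` is now always assigned in `__init__` (it defaults to `None` and is only replaced when speculative decoding is configured), the loader can test the value itself instead of `hasattr(self, "drafter")`. The pattern reduced to its essentials, with placeholder classes standing in for the real runner and drafter:

from typing import Optional

class FakeDrafter:  # placeholder for the real drafter
    def load_model(self) -> None:
        print("drafter loaded")

class Runner:  # placeholder for the model runner
    def __init__(self, use_spec_decode: bool) -> None:
        # Always defined, so later code never needs hasattr().
        self.drafter: Optional[FakeDrafter] = (
            FakeDrafter() if use_spec_decode else None)

    def load_model(self) -> None:
        if self.drafter:  # replaces hasattr(self, "drafter")
            self.drafter.load_model()

Runner(use_spec_decode=True).load_model()   # prints "drafter loaded"
Runner(use_spec_decode=False).load_model()  # no-op
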
@@ -1924,7 +1878,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
device=self.device,
pin_memory=True,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
block_sizes=[self.block_size],
)

kv_cache_sizes = {}
@@ -2004,7 +1958,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""

forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
@@ -2016,7 +1969,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
block_size=self.block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
@@ -2105,6 +2058,7 @@ def _generate_draft_token_ids(
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
assert self.drafter is not None
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
@@ -2161,6 +2115,7 @@ def _generate_mtp_token_ids(
dtype=torch.int32,
device=self.device,
)
assert self.drafter is not None
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
@@ -2169,7 +2124,7 @@
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]

assert self.drafter is not None
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
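
The added `assert self.drafter is not None` lines narrow the `Optional` attribute for the type checker before the drafter's methods are called; at runtime they also fail fast if speculative decoding was never configured. The idiom in isolation, with a placeholder class:

from typing import Optional

class FakeDrafter:  # placeholder
    def propose(self, token_ids: list[int]) -> list[int]:
        return token_ids[-1:]  # dummy proposal

drafter: Optional[FakeDrafter] = FakeDrafter()

assert drafter is not None  # narrows Optional[FakeDrafter] to FakeDrafter
print(drafter.propose([1, 2, 3]))  # [3]
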
@@ -2190,7 +2145,7 @@ def init_torchair_graph_batch_sizes(self):
# NOTE: When using all2all | mc2, we need to slice the `num_tokens` dimension into `tp_size` blocks
start_graph_batch_size = max(start_graph_batch_size, tp_size)

while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
while (start_graph_batch_size <= self.max_num_reqs):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2
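
`init_torchair_graph_batch_sizes` grows the capture sizes by doubling from a starting batch that is at least `tp_size` (per the NOTE about all2all/mc2 slicing) until the next size would exceed `max_num_reqs`. A standalone sketch of that loop; the initial value of 1 is an assumption, since the real starting value is set above this hunk:

def torchair_graph_batch_sizes(max_num_reqs: int, tp_size: int,
                               start: int = 1) -> list[int]:
    # Start no smaller than tp_size so num_tokens can be sliced into tp_size blocks.
    size = max(start, tp_size)
    sizes = []
    while size <= max_num_reqs:
        sizes.append(size)
        size *= 2
    return sizes

print(torchair_graph_batch_sizes(max_num_reqs=64, tp_size=4))  # [4, 8, 16, 32, 64]
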
