diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 00d965d44cd00..d955ac33f506e 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -239,16 +239,9 @@ async def run_vllm_async(
     return end - start
 
 
-def run_hf(
-    requests: List[Tuple[str, int, int]],
-    model: str,
-    tokenizer: PreTrainedTokenizerBase,
-    n: int,
-    use_beam_search: bool,
-    max_batch_size: int,
-    trust_remote_code: bool,
-    device: str
-) -> float:
+def run_hf(requests: List[Tuple[str, int, int]], model: str,
+           tokenizer: PreTrainedTokenizerBase, n: int, use_beam_search: bool,
+           max_batch_size: int, trust_remote_code: bool, device: str) -> float:
     assert not use_beam_search
     llm = AutoModelForCausalLM.from_pretrained(
         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
diff --git a/examples/offline_inference_npu.py b/examples/offline_inference_npu.py
index 4835d91fbb778..cb78b67f8bae8 100644
--- a/examples/offline_inference_npu.py
+++ b/examples/offline_inference_npu.py
@@ -1,7 +1,11 @@
 import gc
+
 import torch
+
 from vllm import LLM, SamplingParams
-from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
+from vllm.distributed.parallel_state import (destroy_distributed_environment,
+                                             destroy_model_parallel)
+
 
 def clean_up():
     destroy_model_parallel()
@@ -9,6 +13,7 @@ def clean_up():
     gc.collect()
     torch.npu.empty_cache()
 
+
 # Sample prompts.
 prompts = [
     "Hello, my name is",
diff --git a/setup.py b/setup.py
index 58b7c7c6a31e5..35d576f61cac5 100644
--- a/setup.py
+++ b/setup.py
@@ -285,6 +285,7 @@ def _is_openvino() -> bool:
 def _is_xpu() -> bool:
     return VLLM_TARGET_DEVICE == "xpu"
 
+
 def _is_npu() -> bool:
     return VLLM_TARGET_DEVICE == "npu"
 
@@ -294,8 +295,8 @@ def _build_custom_ops() -> bool:
 
 
 def _build_core_ext() -> bool:
-    return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu() or
-                _is_npu())
+    return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()
+                or _is_npu())
 
 
 def get_hipcc_rocm_version():
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 44028028990a5..e2b8cf2fa7e2a 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -18,8 +18,8 @@ from ..utils import multi_gpu_test
 
 
 MODELS = [
-    "facebook/opt-125m",
-    "meta-llama/Llama-2-7b-hf",
+    "facebook/opt-125m", "/home/models/llama-2-7b/"
+    # "meta-llama/Llama-2-7b-hf",
 ]
 
 TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
diff --git a/tests/conftest.py b/tests/conftest.py
index 9a4862011c2fc..6a6147801ba66 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -35,10 +35,9 @@
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
-from vllm.platforms import current_platform
-
 
 logger = init_logger(__name__)
 
diff --git a/vllm/attention/backends/ascend.py b/vllm/attention/backends/ascend.py
index d8a68bb0dbcec..edf3839362fd6 100644
--- a/vllm/attention/backends/ascend.py
+++ b/vllm/attention/backends/ascend.py
@@ -1,6 +1,6 @@
-from dataclasses import dataclass
-from typing import Any, Dict, List, TYPE_CHECKING, Optional, Tuple, Type
 import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
 import torch_npu
@@ -11,11 +11,12 @@
                                             CommonMetadataBuilder,
                                             compute_slot_mapping_start_idx,
                                             is_block_tables_empty)
-from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
+from vllm.attention.ops.paged_attn import (PagedAttention,
+                                           PagedAttentionMetadata)
+
 
 if TYPE_CHECKING:
     from vllm.worker.npu_model_runner import ModelInputForNPUBuilder
-
 
 SHARE_MASK_TRIL_PREFIX_CACHE = None
 SHARE_MASK_TRIL = None
@@ -55,10 +56,9 @@ def swap_blocks(
         dst_indices = src_to_dst[:, 1]
 
         dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-                dst_key_cache.device)
+            dst_key_cache.device)
         dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-                dst_key_cache.device)
-
+            dst_key_cache.device)
 
     @staticmethod
     def copy_blocks(
@@ -222,7 +222,7 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
             encoder_seq_lens=self.encoder_seq_lens,
             encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
             max_encoder_seq_len=self.max_encoder_seq_len,
-            )
+        )
         return self._cached_prefill_metadata
 
     @property
@@ -260,7 +260,7 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
             encoder_seq_lens=self.encoder_seq_lens,
             encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
             max_encoder_seq_len=self.max_encoder_seq_len,
-            )
+        )
         return self._cached_decode_metadata
 
 
@@ -308,10 +308,8 @@ def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id,
             slot_indices.extend([[PAD_SLOT_ID, 0]] * (max_query_len - numel))
 
     def _add_seq_group(
-        self,
-        inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
-        chunked_prefill_enabled: bool
-    ):
+            self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
@@ -319,8 +317,9 @@ def _add_seq_group(
         """
         is_prompt = inter_data.is_prompt
         block_tables = inter_data.block_tables
-        max_query_len = max(max(data.query_lens)
-                            for data in self.input_builder.inter_data_list)
+        max_query_len = max(
+            max(data.query_lens)
+            for data in self.input_builder.inter_data_list)
 
         is_prompt = inter_data.is_prompt
         block_tables = inter_data.block_tables
@@ -401,7 +400,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: List[torch.Tensor],
         attn_metadata: AscendMetadata,
-        kv_scale: float = 1.0,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
@@ -413,29 +411,34 @@
                    num_tokens = batch_size * seq_len
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads * head_size]
-                key_cache [num_blocks, block_size, num_kv_heads * head_size]
-                value_cache [num_blocks, block_size, num_kv_heads * head_size]
+            kv_cache: shape = [2, num_blocks, block_size,
+                               num_kv_heads * head_size]
+                key_cache = [num_blocks, block_size,
+                             num_kv_heads * head_size]
+                value_cache = [num_blocks, block_size,
+                               num_kv_heads * head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size, seq_len * num_heads * head_size]
         """
         assert k_scale == 1.0 and v_scale == 1.0
         if attn_type != AttentionType.DECODER:
-            raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "PallasAttentionBackendImpl"
-            )
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
         # view q k v to BSH
         num_tokens = query.shape[0]
 
         if kv_cache is not None:
             if attn_metadata.num_prefills > 0:
-                slot_indices = attn_metadata.prefill_metadata.slot_mapping
+                slot_indices = (None
+                                if attn_metadata.prefill_metadata is None else
+                                attn_metadata.prefill_metadata.slot_mapping)
             else:
-                slot_indices = attn_metadata.decode_metadata.slot_mapping
+                slot_indices = (None
+                                if attn_metadata.decode_metadata is None else
+                                attn_metadata.decode_metadata.slot_mapping)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
             AscendPagedAttention.write_to_paged_cache(
                 key,
@@ -450,10 +453,8 @@ def forward(
             if num_tokens > 16384:
                 attn_metadata.sparse_mode = 2
             attention_mask = gen_input_mask(
-                attn_metadata.max_prefill_seq_len,
-                self.sliding_window,
-                num_tokens
-            )
+                attn_metadata.max_prefill_seq_len, self.sliding_window,
+                num_tokens)
             attn_metadata.attn_mask = attention_mask
 
             if (self.alibi_slopes is not None
@@ -491,8 +492,7 @@ def forward(
                     sparse_mode=attn_metadata.sparse_mode,
                 )
                 output = output.transpose(1, 2).reshape(
-                    num_tokens, -1, self.num_heads * self.head_size
-                )
+                    num_tokens, -1, self.num_heads * self.head_size)
 
         elif attn_metadata.decode_metadata:
             # FA for decoding phase
@@ -539,15 +539,13 @@ def gen_input_mask(seq_len, sliding_window, len):
     global SHARE_MASK_TRIL
     if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
         SHARE_MASK_TRIL = ~torch.tril(
-            torch.ones(seq_len, seq_len, dtype=bool, device="npu")
-        )
+            torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
 
     attention_mask = SHARE_MASK_TRIL
     if sliding_window is not None:
         attention_mask = ~attention_mask
-        attention_mask = torch.triu(
-            attention_mask, diagonal=1 - sliding_window
-        )
+        attention_mask = torch.triu(attention_mask,
+                                    diagonal=1 - sliding_window)
         attention_mask = ~attention_mask
 
     return attention_mask
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 92c710e700002..7ed3e3debb916 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -14,8 +14,10 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.triton_utils.importing import HAS_TRITON
+
 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
+
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
                         cuda_is_initialized, get_distributed_init_method,
                         get_open_port, get_vllm_instance_id, make_async,
diff --git a/vllm/executor/multiproc_npu_executor.py b/vllm/executor/multiproc_npu_executor.py
index 9d04755450121..e00d70b5c0e91 100644
--- a/vllm/executor/multiproc_npu_executor.py
+++ b/vllm/executor/multiproc_npu_executor.py
@@ -1,11 +1,13 @@
 import os
 
-import torch, torch_npu # noqa
+import torch  # noqa
+import torch_npu  # noqa
+
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
 
 from vllm.executor.npu_executor import NPUExecutor
 from vllm.logger import init_logger
 from vllm.utils import update_environment_variables
-from vllm.executor.multiproc_gpu_executor import (
-    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
 
 logger = init_logger(__name__)
@@ -21,7 +23,7 @@ def _check_executor_parameters(self):
         if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
             update_environment_variables({
                 "ASCEND_RT_VISIBLE_DEVICES":
-                    (",".join(map(str, range(world_size))))
+                (",".join(map(str, range(world_size))))
             })
 
         npu_device_count = torch.npu.device_count()
diff --git a/vllm/executor/npu_executor.py b/vllm/executor/npu_executor.py
index 9923429dac1d5..03441f71080f0 100644
--- a/vllm/executor/npu_executor.py
+++ b/vllm/executor/npu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Callable, Type, Tuple
+from typing import Callable, List, Optional, Tuple, Type
 
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 331b71d7ab961..ca69d07c3f162 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -96,17 +96,14 @@ def forward_npu(
         import torch_npu
 
         if residual is not None:
-            x, _, residual = torch_npu.npu_add_rms_norm(x,
-                                                        residual,
-                                                        self.weight,
-                                                        self.variance_epsilon)
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon)
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
         return x
 
-
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index ddceb15c61bd1..b7844dbab0c50 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -50,10 +50,11 @@
 from .interfaces import SupportsLoRA
 
 
-
 current_backend = "inductor"
 if current_platform.is_npu():
     current_backend = "npu"
+
+
 @torch.compile(backend=current_backend)
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 4c307b692cdb5..598d85225a8ac 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -261,7 +261,7 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        padding_len: Optional[int] = 0
+        padding_len: int = 0
         prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
         sample_indices: List[int] = (sample_obj.sample_indices
diff --git a/vllm/platforms/ascend.py b/vllm/platforms/ascend.py
index bd31b5d72c4ad..f60845f5c213d 100644
--- a/vllm/platforms/ascend.py
+++ b/vllm/platforms/ascend.py
@@ -1,5 +1,5 @@
-from typing import Tuple
 import os
+from typing import Tuple
 
 import torch
 
@@ -36,7 +36,7 @@ def inference_mode(cls):
         return torch.inference_mode()
 
     @classmethod
-    def set_device(cls, device: torch.device) -> torch.device:
+    def set_device(cls, device: torch.device):
         torch.npu.set_device(device)
 
     @classmethod
diff --git a/vllm/worker/npu_model_runner.py b/vllm/worker/npu_model_runner.py
index 13c614f7c00d4..cdb25b91b4fb5 100644
--- a/vllm/worker/npu_model_runner.py
+++ b/vllm/worker/npu_model_runner.py
@@ -1,76 +1,38 @@
 import dataclasses
-from dataclasses import dataclass
-from typing import (Any, Dict, List, Optional, Set, Type,
-                    TypeVar)
+from typing import Any, Dict, List, Optional, Set, Type
 
 import torch
 import torch.distributed
-
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
 from vllm.distributed import get_pp_group
-from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import (supports_lora,
-                                                    supports_multimodal)
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
-                             MultiModalRegistry)
+from vllm.multimodal import MultiModalInputs
 from vllm.platforms import current_platform
 from vllm.prompt_adapter.layers import PromptAdapterMapping
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.prompt_adapter.worker_manager import (
-    LRUCacheWorkerPromptAdapterManager)
-
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceGroupMetadata
-from vllm.utils import (DeviceMemoryProfiler, flatten_2d_lists,
-                        make_tensor_with_pad)
-from vllm.worker.model_runner import (ModelInputForGPU, ModelInputForGPUBuilder,
+from vllm.utils import flatten_2d_lists, make_tensor_with_pad
+from vllm.worker.model_runner import (ModelInputForGPU,
+                                      ModelInputForGPUBuilder,
                                       ModelInputForGPUWithSamplingMetadata,
                                       ModelRunner)
-
 
 logger = init_logger(__name__)
 
 LORA_WARMUP_RANK = 8
 
-TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
-
-
-@dataclass(frozen=True)
-class ModelInputForNPU(ModelInputForGPU):
-    """
-    This base class contains metadata needed for the base model forward pass
-    but not metadata for possible additional steps, e.g., sampling. Model
-    runners that run additional steps should subclass this method to add
-    additional fields.
-    """
-    pass
-
-
-@dataclass(frozen=True)
-class ModelInputForNPUWithSamplingMetadata(
-        ModelInputForGPUWithSamplingMetadata):
-    """
-    Used by the ModelRunner.
-    """
-    pass
-
 
 class ModelInputForNPUBuilder(ModelInputForGPUBuilder):
-    """Build ModelInputForNPU from SequenceGroupMetadata."""
+    """Build ModelInputForGPU from SequenceGroupMetadata."""
 
     # Note: ideally we would be using a dataclass(kw_only=True)
     # here, so that this can be subclassed easily,
     # but kw_only is not supported in python<3.10.
 
-    def build(self) -> ModelInputForNPU:
+    def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
         """
@@ -112,26 +74,21 @@ def build(self) -> ModelInputForNPU:
 
         if self.inter_data_list[0].is_prompt:
             input_tokens_tensor = make_tensor_with_pad(
-                input_tokens, 0,
-                dtype=torch.int,
-                device=self.runner.device)
+                input_tokens, 0, dtype=torch.int, device=self.runner.device)
             input_positions_tensor = make_tensor_with_pad(
-                input_positions, 0,
-                dtype=torch.int,
-                device=self.runner.device)
+                input_positions, 0, dtype=torch.int, device=self.runner.device)
             input_tokens_tensor = torch.flatten(input_tokens_tensor)
             input_positions_tensor = torch.flatten(input_positions_tensor)
             max_seq_len = max(seq_lens)
             seq_lens = len(seq_lens) * [max_seq_len]
         else:
-            input_tokens_tensor = torch.tensor(
-                flatten_2d_lists(input_tokens),
-                dtype=torch.long,
-                device=self.runner.device)
+            input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
+                                               dtype=torch.long,
+                                               device=self.runner.device)
             input_positions_tensor = torch.tensor(
-                flatten_2d_lists(input_positions),
-                dtype=torch.long,
-                device=self.runner.device)
+                flatten_2d_lists(input_positions),
+                dtype=torch.long,
+                device=self.runner.device)
 
         # Sequence and query lengths.
         seq_lens.extend([1] * cuda_graph_pad_size)
@@ -207,108 +164,21 @@ class NPUModelRunner(ModelRunner):
     """
     NPU model runner with sampling step.
    """
-    _model_input_cls: Type[ModelInputForNPUWithSamplingMetadata] = (
-        ModelInputForNPUWithSamplingMetadata)
+    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
+        ModelInputForGPUWithSamplingMetadata)
     _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        mindie_model_config: Optional[dict],
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
-        return_hidden_states: bool = False,
-        observability_config: Optional[ObservabilityConfig] = None,
-        input_registry: InputRegistry = INPUT_REGISTRY,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY):
-        super().__init__(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config,
-            lora_config,
-            kv_cache_dtype,
-            is_driver_worker,
-            prompt_adapter_config,
-            return_hidden_states,
-            observability_config,
-            input_registry,
-            mm_registry,
-        )
-
-        self.mindie_model_config = mindie_model_config
-
     def make_model_input_from_broadcasted_tensor_dict(
         self,
         tensor_dict: Dict[str, Any],
-    ) -> ModelInputForNPUWithSamplingMetadata:
+    ) -> ModelInputForGPUWithSamplingMetadata:
         model_input = \
-            ModelInputForNPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                 tensor_dict,
                 attn_backend=self.attn_backend,
             )
         return model_input
 
-    def model_router(self) -> None:
-        # TODO: goto mindie_model when is_mindie() and model_supports_in_mindie()
-        # TODO: use super().load_model() after unify the DeviceMemoryProfiler
-        self.get_vllm_model()
-
-    def load_model(self) -> None:
-        self.model_router()
-
-    def get_vllm_model(self) -> None:
-        logger.info("Starting to load model %s...", self.model_config.model)
-        with DeviceMemoryProfiler() as m:
-            self.model = get_model(model_config=self.model_config,
-                                   device_config=self.device_config,
-                                   load_config=self.load_config,
-                                   lora_config=self.lora_config,
-                                   parallel_config=self.parallel_config,
-                                   scheduler_config=self.scheduler_config,
-                                   cache_config=self.cache_config)
-
-        self.model_memory_usage = m.consumed_memory
-        logger.info("Loading model weights took %.4f GB",
-                    self.model_memory_usage / float(2**30))
-
-        if self.lora_config:
-            assert supports_lora(self.model), "Model does not support LoRA"
-            assert not supports_multimodal(
-                self.model
-            ), "To be tested: Multi-modal model with LoRA settings."
-
-            self.lora_manager = LRUCacheWorkerLoRAManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size,
-                self.lora_config,
-                self.device,
-                self.model.embedding_modules,
-                self.model.embedding_padding_modules,
-                max_position_embeddings=self.model.config.
-                max_position_embeddings,
-            )
-            self.model = self.lora_manager.create_lora_manager(self.model)
-
-        if self.prompt_adapter_config:
-            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens, self.device,
-                self.prompt_adapter_config)
-            self.model = (
-                self.prompt_adapter_manager.create_prompt_adapter_manager(
-                    self.model))
-
     @current_platform.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
@@ -434,9 +304,13 @@ def prepare_model_input(
             # Sampling metadata is only required for the final pp group
             generators = self.get_generators(finished_requests_ids)
             sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list, model_input.seq_lens,
-                model_input.query_lens, self.device, self.pin_memory,
-                generators, self.sampling_metadata_cache,
+                seq_group_metadata_list,
+                model_input.seq_lens,
+                model_input.query_lens,
+                self.device,
+                self.pin_memory,
+                generators,
+                self.sampling_metadata_cache,
                 pad_for_invariant_seq_len=True)
         else:
             sampling_metadata = None
diff --git a/vllm/worker/npu_worker.py b/vllm/worker/npu_worker.py
index 1574d8ea23ca5..3d814e8e5bb99 100644
--- a/vllm/worker/npu_worker.py
+++ b/vllm/worker/npu_worker.py
@@ -1,6 +1,6 @@
 """A NPU worker class."""
 import gc
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed
@@ -17,6 +17,7 @@
 from vllm.sequence import SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
+from vllm.worker.model_runner import GPUModelRunnerBase
 from vllm.worker.npu_model_runner import NPUModelRunner
 from vllm.worker.worker import Worker
@@ -78,20 +79,13 @@ def __init__(
                 not in ["medusa", "mlp_speculator"]) \
             else {"return_hidden_states": True}
 
-        ModelRunnerClass: Type[NPUModelRunner] = NPUModelRunner
+        ModelRunnerClass: Union[Type[NPUModelRunner],
+                                Type[EmbeddingModelRunner]] = NPUModelRunner
         if model_runner_cls is not None:
             ModelRunnerClass = model_runner_cls
-        elif self.model_config.embedding_mode:
+        elif self._is_embedding_model():
             ModelRunnerClass = EmbeddingModelRunner
-        mindie_model_config = {
-            "backend_type": "atb",
-            "model_id": model_config.model,
-            "rank": rank,
-            "local_rank": local_rank,
-            "world_size": parallel_config.world_size,
-            "npu_device_id": local_rank,
-        }
-        self.model_runner: NPUModelRunner = ModelRunnerClass(
+        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             model_config,
             parallel_config,
             scheduler_config,
@@ -99,7 +93,6 @@ def __init__(
             cache_config,
             load_config=load_config,
             lora_config=self.lora_config,
-            mindie_model_config=mindie_model_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             prompt_adapter_config=prompt_adapter_config,
@@ -115,14 +108,6 @@ def __init__(
 
     def init_device(self) -> None:
         if self.device_config.device.type == "npu":
-            # # torch.distributed.all_reduce does not free the input tensor until
-            # # the synchronization point. This causes the memory usage to grow
-            # # as the number of all_reduce calls increases. This env var disables
-            # # this behavior.
-            # # Related issue:
-            # # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-            # os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
             # # This env var set by Ray causes exceptions with graph building.
             # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
             self.device = torch.device(f"npu:{self.local_rank}")
@@ -193,12 +178,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
 
 def init_worker_distributed_environment(
-    parallel_config: ParallelConfig,
-    rank: int,
-    distributed_init_method: Optional[str] = None,
-    local_rank: int = -1,
-    backend: str = "hccl"
-) -> None:
+        parallel_config: ParallelConfig,
+        rank: int,
+        distributed_init_method: Optional[str] = None,
+        local_rank: int = -1,
+        backend: str = "hccl") -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)