diff --git a/examples/offline_inference_npu.py b/examples/offline_inference_npu.py new file mode 100644 index 0000000000000..7d476fe70c467 --- /dev/null +++ b/examples/offline_inference_npu.py @@ -0,0 +1,22 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + # "Hello, my name is", + "The president of the United States is", + # "The capital of France is", + # "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(max_tokens=100, temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/pyproject.toml b/pyproject.toml index 22a25d9cf32e6..46010a93fc05c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.4.0", + "torch == 2.1.0", "wheel", "jinja2", ] diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 0000000000000..4132288471a64 --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,11 @@ +# Common dependencies +-r requirements-common.txt + +decorator +pyyaml +scipy +setuptools +torch == 2.1.0 +torch_npu == 2.1.0.post6 +# torch == 2.4.0 +# torch_npu == 2.4.0.rc1 \ No newline at end of file diff --git a/setup.py b/setup.py index 1e08a5bd70cd3..1d9d7696e3624 100644 --- a/setup.py +++ b/setup.py @@ -277,6 +277,9 @@ def _is_openvino() -> bool: def _is_xpu() -> bool: return VLLM_TARGET_DEVICE == "xpu" +def _is_npu() -> bool: + return VLLM_TARGET_DEVICE == "npu" + def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() @@ -389,6 +392,8 @@ def get_vllm_version() -> str: version += "+cpu" elif _is_xpu(): version += "+xpu" + elif _is_npu(): + version += "+npu" else: raise RuntimeError("Unknown runtime environment") @@ -444,10 +449,13 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-cpu.txt") elif _is_xpu(): requirements = _read_requirements("requirements-xpu.txt") + elif _is_npu(): + requirements = _read_requirements("requirements-npu.txt") else: raise ValueError( "Unsupported platform, please use CUDA, ROCm, Neuron, " "OpenVINO, or CPU.") + print("requirements", requirements) return requirements diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index ec7c2ba3e3ce0..0ddaf46994f39 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -29,7 +29,7 @@ def test_vllm_gc_ed(): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["ASCEND_TORCH"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) diff --git a/tests/conftest.py b/tests/conftest.py index cd0091b7cba68..13040d25efe4f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - 
identity, is_cpu) + identity, is_cpu, is_npu) logger = init_logger(__name__) @@ -213,6 +213,10 @@ def wrap_device(self, input: _T) -> _T: if hasattr(input, 'device') and input.device.type == "cuda": return input # Already on GPU, no need to move return input.to("cuda") + elif is_npu(): + if hasattr(input, 'device') and input.device.type == "npu": + return input # Already on GPU, no need to move + return input.to("npu") else: # Check if the input is already on the CPU if hasattr(input, 'device') and input.device.type == "cpu": diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_custom_executor.py index bff0fc99ed022..725c5e5ae5bdb 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_custom_executor.py @@ -7,7 +7,14 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync +from vllm.executor.npu_executor import NPUExecutor, NPUExecutorAsync from vllm.sampling_params import SamplingParams +from vllm.attention.backends.ascend import AscendPagedAttention +from vllm.attention.ops.paged_attn import PagedAttention + + +# NOTE (cmq): do monkey patch +PagedAttention = AscendPagedAttention class Mock: @@ -23,6 +30,15 @@ def execute_model(self, *args, **kwargs): return super().execute_model(*args, **kwargs) +class CustomNPUExecutor(NPUExecutor): + + def execute_model(self, *args, **kwargs): + # Drop marker to show that this was ran + with open(".marker", "w"): + ... + return super().execute_model(*args, **kwargs) + + class CustomGPUExecutorAsync(GPUExecutorAsync): async def execute_model_async(self, *args, **kwargs): @@ -41,9 +57,13 @@ def test_custom_executor_type_checking(model): engine_args = AsyncEngineArgs(model=model, distributed_executor_backend=Mock) AsyncLLMEngine.from_engine_args(engine_args) + # with pytest.raises(TypeError): + # engine_args = AsyncEngineArgs( + # model=model, distributed_executor_backend=CustomGPUExecutor) + # AsyncLLMEngine.from_engine_args(engine_args) with pytest.raises(TypeError): engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutor) + model=model, distributed_executor_backend=CustomNPUExecutor) AsyncLLMEngine.from_engine_args(engine_args) @@ -55,7 +75,7 @@ def test_custom_executor(model, tmpdir): assert not os.path.exists(".marker") engine_args = EngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutor) + model=model, distributed_executor_backend=CustomNPUExecutor) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -67,25 +87,25 @@ def test_custom_executor(model, tmpdir): os.chdir(cwd) -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -def test_custom_executor_async(model, tmpdir): - cwd = os.path.abspath(".") - os.chdir(tmpdir) - try: - assert not os.path.exists(".marker") +# @pytest.mark.parametrize("model", ["facebook/opt-125m"]) +# def test_custom_executor_async(model, tmpdir): +# cwd = os.path.abspath(".") +# os.chdir(tmpdir) +# try: +# assert not os.path.exists(".marker") - engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutorAsync) - engine = AsyncLLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(max_tokens=1) +# engine_args = AsyncEngineArgs( +# model=model, distributed_executor_backend=CustomGPUExecutorAsync) +# engine = AsyncLLMEngine.from_engine_args(engine_args) +# sampling_params = SamplingParams(max_tokens=1) - async def t(): - 
stream = await engine.add_request("0", "foo", sampling_params) - async for x in stream: - ... +# async def t(): +# stream = await engine.add_request("0", "foo", sampling_params) +# async for x in stream: +# ... - asyncio.run(t()) +# asyncio.run(t()) - assert os.path.exists(".marker") - finally: - os.chdir(cwd) +# assert os.path.exists(".marker") +# finally: +# os.chdir(cwd) diff --git a/vllm/attention/backends/ascend.py b/vllm/attention/backends/ascend.py new file mode 100644 index 0000000000000..acbdc45319170 --- /dev/null +++ b/vllm/attention/backends/ascend.py @@ -0,0 +1,661 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, TYPE_CHECKING, Optional, Tuple, Type + +import torch +try: + import torch_npu +except: + raise ImportError("torch-npu not found. 'pip install torch-npu' if using Ascend backend") + +import math +import numpy as np + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention +if TYPE_CHECKING: + from vllm.worker.npu_model_runner import ModelInputForNPUBuilder + +from vllm.utils import make_tensor_with_pad + +SHARE_MASK_TRIL_PREFIX_CACHE = None +SHARE_MASK_TRIL = None + + +class AscendAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: + return AscendAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AscendMetadata"]: + return AscendMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + # return (2, num_blocks, block_size, num_kv_heads * head_size) + return (2, num_blocks, block_size, num_kv_heads * head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: List[torch.Tensor], + dst_kv_cache: List[torch.Tensor], + src_to_dst: Dict[int, int], + ) -> None: + # TODO (cmq): check me + src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] + dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] + for src, dst in src_to_dst.items(): + dst_key_cache[dst] = src_key_cache[src].to(dst_key_cache.device) + dst_value_cache[dst] = src_value_cache[src].to(dst_key_cache.device) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + # TODO (cmq): check me + key_caches = kv_caches[0] + value_caches = kv_caches[1] + layers = len(key_caches) + for src_id, dsts in src_to_dists.items(): + for dst_id in dsts: + key_caches[:][dst_id] = key_caches[:][src_id] + value_caches[:][dst_id] = value_caches[:][src_id] + + @staticmethod + def get_builder_cls() -> Type["AscendMetadataBuilder"]: + return AscendMetadataBuilder + + @classmethod + def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + +class AscendPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_indices: torch.Tensor, + ) -> None: + torch_npu.npu_scatter_nd_update_(key_cache, slot_indices, key) + torch_npu.npu_scatter_nd_update_(value_cache, slot_indices, value) + + +@dataclass(kw_only=True) +class AscendMetadata(AttentionMetadata): + # Currently, input 
sequences can only contain all prefills + # or all decoding. + is_prompt: bool + seq_lens: Optional[List[int]] + seq_lens_tensor: Optional[torch.Tensor] + max_seq_len: Optional[int] + + # metadata for NPU + max_query_len: Optional[int] + subquery_start_loc: Optional[torch.Tensor] + seq_start_loc: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] + block_size: Optional[int] = 0 + slot_mapping: Optional[torch.Tensor] = None + slot_indices: Optional[torch.Tensor] = None + use_cuda_graph: bool = False # TODO (cmq) is this neccesary? + + pse_shift: Optional[torch.Tensor] = None + sparse_mode: Optional[int] = 0 + + attn_mask: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["AscendMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + # assert self.block_tables is None + # assert self.context_lens is None + return self + + @property + def decode_metadata(self) -> Optional["AscendMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class AscendMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): + + def __init__(self, input_builder: "ModelInputForNPUBuilder"): + # slot mapping: mapping of sequence offset to physical address + self.slot_mapping: List[List[int]] = [] + self.slot_indices: List[List[List[int]]] = [] + + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + # use_v2_block_manager not supported in Ascend + self.use_v2_block_manager = False + + def compute_slot_indices( + self, + is_profile_run: bool, + slot_indices: List[List[int]], + seq_id: int, + seq_len: int, + context_len: int, + start_idx: int, + block_size: int, + block_tables: Dict[int, List[int]], + ): + """ + Compute slot indices. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just pass the slot indices updating. + return + block_table = block_tables[seq_id] + for i in range(max(start_idx, context_len), seq_len): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot_indices.append([block_number, block_offset]) + + def _add_seq_group( + self, + inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, + prefix_cache_hit: bool, + ): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for ( + seq_id, + token_len, + seq_len, + curr_seq_len, + query_len, + context_len, + curr_sliding_window_block, + ) in zip( + inter_data.seq_ids, + [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, + inter_data.seq_lens, + inter_data.query_lens, + inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + ): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert ( + query_len == 1 + ), "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len + ) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ( + chunked_prefill_enabled or not is_prompt + ) and block_tables is not None: + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping and slot indices + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, + query_len, + context_len, + self.sliding_window, + self.use_v2_block_manager, + ) + compute_slot_mapping( + is_profile_run, + self.slot_mapping, + seq_id, + seq_len, + context_len, + start_idx, + self.block_size, + inter_data.block_tables, + ) + self.compute_slot_indices( + is_profile_run, + self.slot_indices, + seq_id, + seq_len, + context_len, + start_idx, + self.block_size, + inter_data.block_tables, + ) + + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and self.sliding_window is not None: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager" + ) + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + self.slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
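(Editor's note, a minimal sketch that is not part of the patch: it only illustrates how the flat slot ids produced by compute_slot_mapping relate to the [block_number, block_offset] pairs that compute_slot_indices collects for the NPU scatter update above. The block table, block size and token position below are made-up values.)

    # Hypothetical values, for illustration only.
    block_size = 4
    block_table = [7, 3]                          # logical block -> physical block
    i = 5                                         # token position within the sequence
    block_number = block_table[i // block_size]   # -> 3
    block_offset = i % block_size                 # -> 1
    flat_slot = block_number * block_size + block_offset   # 13, as stored in slot_mapping
    pair_slot = [block_number, block_offset]               # [3, 1], as stored in slot_indices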
+ # block_table = block_tables[seq_id] + # self.slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) + # self.slot_mapping.append([]) + # self.slot_indices.append([]) + + # for i in range(max(start_idx, context_len), seq_len): + # block_number = block_table[i // self.block_size] + # block_offset = i % self.block_size + # slot = block_number * self.block_size + block_offset + # self.slot_mapping[-1].append(slot) + # self.slot_indices[-1].append([block_number, block_offset]) + + def build( + self, + seq_lens: List[int], + query_lens: List[int], + cuda_graph_pad_size: int, + batch_size: int, + ): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any( + [ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ] + ) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group( + inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit + ) + + device = self.runner.device + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + context_lens_tensor = torch.tensor( + self.context_lens, dtype=torch.int, device=device + ) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + query_lens_tensor = torch.tensor(query_lens, dtype=torch.long, device=device) + query_start_loc = torch.zeros( + query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device + ) + seq_start_loc = torch.zeros( + seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device + ) + torch.cumsum( + seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:] + ) + torch.cumsum( + query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:], + ) + + slot_mapping_tensor = torch.tensor( + self.slot_mapping, dtype=torch.long, device=device + ) + max_seq_len = max(seq_lens) + pad_slot_indices = [] + for idx in self.slot_indices: + pad_slot_indices.append(idx) + if len(idx) < max_seq_len: + pad_slot_indices += [[np.iinfo(np.int_).max, 0]] * ( + max_seq_len - len(idx) + ) + slot_indices_tensor = torch.tensor( + self.slot_indices, dtype=torch.int64, device=device + ) + + return AscendMetadata( + is_prompt=True, # TODO (cmq): check me + max_seq_len=max_seq_len, + num_prefills=self.num_prefills, + slot_indices=slot_indices_tensor, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + subquery_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens=context_lens_tensor, + block_tables=block_tables, + block_size=self.block_size, + use_cuda_graph=False, # not support in NPU + ) + + +class AscendAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + 
blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def attn_free_mask_pfa(self): + global SHARE_MASK_TRIL_PREFIX_CACHE + if SHARE_MASK_TRIL_PREFIX_CACHE is None: + SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu( + torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"), + diagonal=1, + ) + return SHARE_MASK_TRIL_PREFIX_CACHE + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: List[torch.Tensor], + attn_metadata: AscendMetadata, + kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + + Args: + query: shape = [batch_size, seq_len * num_heads * head_size] + key: shape = [batch_size, seq_len * num_kv_heads * head_size] + value: shape = [batch_size, seq_len * num_kv_heads * head_size] + key_cache = [num_blocks, block_size, num_kv_heads * head_size] + value_cache = [num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len * num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "AscendAttentionBackendImpl" + ) + # view q k v to BSH + batch_size = query.shape[0] + + if kv_cache is not None: + if attn_metadata.num_prefills > 0: + slot_indices = attn_metadata.prefill_metadata.slot_indices + else: + slot_indices = attn_metadata.decode_metadata.slot_indices + key_cache, value_cache = kv_cache[0], kv_cache[1] + AscendPagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + slot_indices, + ) + + if attn_metadata.num_prefills > 0: + # TODO (cmq): modify attn_metadata.sparse_mode, attention_mask ...
+ # TODO: add an env var to turn mask-free attention on/off and make the 16384 threshold configurable + # batch_size = len(attn_metadata.seq_lens) + if attn_metadata.attn_mask is None and query.shape[0] > 16384: + attn_metadata.attn_mask = self.attn_free_mask_pfa() + attn_metadata.sparse_mode = 2 + + if attn_metadata.attn_mask is None: + query_len = attn_metadata.seq_lens_tensor + kv_len = torch.zeros_like(query_len).to(torch.long) + attention_mask = gen_input_mask( + len(attn_metadata.seq_lens), + attn_metadata.max_seq_len, + query_len, + kv_len, + ) + # attention_mask = gen_input_mask(batch_size, attn_metadata.max_seq_len, query_len, kv_len) + + if self.sliding_window is not None: + attention_mask = ~attention_mask + attention_mask = torch.triu( + attention_mask, diagonal=1 - self.sliding_window + ) + attention_mask = ~attention_mask + attn_metadata.attn_mask = attention_mask + + if self.alibi_slopes is not None and attn_metadata.pse_shift is None: + attn_metadata.pse_shift = _make_alibi_bias( + self.alibi_slopes, + self.num_kv_heads, + dtype=query.dtype, + seq_len=attn_metadata.max_seq_len, + batch_size=batch_size, + ) + # shape of q/k/v [B,S*H] --> [B,S,N,D] + query = query.view( + -1, attn_metadata.max_seq_len, self.num_heads, self.head_size + ).transpose(1, 2) + key = key.view( + -1, attn_metadata.max_seq_len, self.num_kv_heads, self.head_size + ).transpose(1, 2) + value = value.view( + -1, attn_metadata.max_seq_len, self.num_kv_heads, self.head_size + ).transpose(1, 2) + + # FA for prefill phase + output = torch_npu.npu_prompt_flash_attention( + query, + key, + value, + pse_shift=attn_metadata.pse_shift, + atten_mask=attn_metadata.attn_mask, + num_heads=self.num_heads, + scale_value=1 / math.sqrt(self.head_size), + input_layout="BNSD", + num_key_value_heads=self.num_kv_heads, + pre_tokens=65535, + next_tokens=0, + sparse_mode=attn_metadata.sparse_mode, + ) + output = output.transpose(1, 2).reshape( + batch_size, -1, self.num_heads * self.head_size + ) + if output.shape[1] == 1: + output = output.squeeze(1) + elif decode_meta := attn_metadata.decode_metadata: + # FA for decoding phase + assert kv_cache is not None + # shape of query [B,S*H] --> [B,S,H] + query = query.view( + -1, + 1, + self.head_size * self.num_kv_heads, + ) + output = torch_npu.npu_incre_flash_attention( + query, + key_cache, + value_cache, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + scale_value=self.scale, + input_layout="BSH", + block_table=attn_metadata.block_tables, + block_size=attn_metadata.block_size, # max val of block_size == 512 + actual_seq_lengths=attn_metadata.seq_lens, + ).squeeze(1) + + return output + + +# TODO: add padding input +# def pad_input(attn_metadata: AscendMetadata, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor): + + +def gen_input_mask( + batch_size, seq_len, query_len: torch.LongTensor, kv_len: torch.LongTensor +): + """ + Generate a lower-triangular attention mask. + """ + global SHARE_MASK_TRIL + if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len: + SHARE_MASK_TRIL = ~torch.tril( + torch.ones(seq_len, seq_len, dtype=bool, device="npu") + ) + range_idx = torch.arange(seq_len, device=query_len.device).expand(batch_size, -1) + select_idx = range_idx + kv_len.unsqueeze(1) + attn_mask = torch.index_select( + SHARE_MASK_TRIL, index=select_idx.view(-1), dim=0 + ).view(batch_size, seq_len, -1) + padding_idx = range_idx >= query_len.unsqueeze(1) + padding_idx = padding_idx.unsqueeze(2) + attn_mask = attn_mask.masked_fill(padding_idx, 1) + q_len = attn_mask.shape[1] + attn_mask = 
attn_mask[:, :, :q_len] + + return attn_mask.unsqueeze(1)[0].unsqueeze(0) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, + batch_size: int, +): + bias = torch.arange(seq_len, dtype=dtype, device="npu") + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + batch_size, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + + return bias diff --git a/vllm/attention/backends/ascend_mindie.py b/vllm/attention/backends/ascend_mindie.py new file mode 100644 index 0000000000000..e15061b5fea7c --- /dev/null +++ b/vllm/attention/backends/ascend_mindie.py @@ -0,0 +1,242 @@ +from dataclasses import dataclass +import torch +from typing import List, Optional, Tuple, Type + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, + AttentionMetadataBuilder) +from vllm.worker.npu_model_runner import ModelInputForNPUBuilder + +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.utils import make_tensor_with_pad + +class AscendMindIEBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "ascend-mindie-atb-models" + + @staticmethod + def get_impl_cls() -> Optional[Type["AttentionImpl"]]: + # MindIE does not use the single-op AttentionImpl path. + return None + + @staticmethod + def get_metadata_cls() -> Type["AscendMindIEMetadata"]: + return AscendMindIEMetadata + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AscendMindIEMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + def get_builder_cls() -> Type["AscendMindIEMetadataBuilder"]: + return AscendMindIEMetadataBuilder + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AscendMindIEMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + # TODO + pass + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + # TODO + pass + + +@dataclass +class AscendMindIEMetadata(AttentionMetadata): + """ + Metadata for AscendMindIEBackend. + """ + + # Currently, input sequences can only contain all prefills + # or all decoding. + is_prompt: bool + seq_lens: Optional[List[int]] + seq_lens_tensor: Optional[torch.Tensor] + max_seq_len: Optional[int] + + # metadata for NPU + max_query_len: Optional[int] + subquery_start_loc: Optional[torch.Tensor] + seq_start_loc: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] + + use_cuda_graph: bool = False # TODO (cmq): is this necessary?
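(Editor's note, a hedged sketch that is not part of the patch: it shows how a cache tensor with the shape returned by AscendMindIEBackend.get_kv_cache_shape above could be allocated and split into key/value halves, matching the kv_cache[0]/kv_cache[1] indexing used by the Ascend backends. The sizes and dtype are made-up examples, and device="npu" assumes torch_npu is installed.)

    import torch
    import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

    # Hypothetical cache dimensions, for illustration only.
    num_blocks, block_size, num_kv_heads, head_size = 1024, 128, 8, 128
    shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    kv_cache = torch.zeros(shape, dtype=torch.float16, device="npu")
    key_cache, value_cache = kv_cache[0], kv_cache[1]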
+ + +class AscendMindIEMetadataBuilder(AttentionMetadataBuilder[AscendMindIEMetadata]): + + def __init__(self, input_builder: "ModelInputForNPUBuilder") -> None: + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def _add_seq_group( + self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size.
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + + if use_captured_graph: + self.block_tables.extend([] * cuda_graph_pad_size) + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + return AscendMindIEMetadata( + is_prompt=True, # TODO + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_seq_len=max(seq_lens), + subquery_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) \ No newline at end of file diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..139e75b888378 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -76,6 +76,8 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() + + # Attention operator invocation. attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, block_size, blocksparse_params diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..ff97357a6ad05 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -10,7 +10,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu +from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_mindie, is_npu, is_openvino, is_xpu +from vllm.model_executor.model_loader.ascend_mindie import model_supports_in_mindie logger = init_logger(__name__) @@ -24,6 +25,8 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() + ASCEND_TORCH = enum.auto() + ASCEND_MINDIE = enum.auto() def backend_name_to_enum(backend_name: str) -> _Backend: @@ -146,6 +149,23 @@ def get_attn_backend( logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend + # TODO + # Decide whether MindIE and Torch-NPU should share one backend or use separate backends. + # elif backend == _Backend.Ascend_MINDIE: + # from vllm.attention.backends.ascend import AscendAttentionBackend + # return AscendAttentionBackend + # elif backend == _Backend.Ascend_TORCH: + # from vllm.attention.backends.ascend import AscendAttentionBackend + # return AscendAttentionBackend + elif backend == _Backend.ASCEND_TORCH: + logger.info("Using ASCEND_TORCH backend.") + # from vllm.attention.backends.ascend import AscendTorchAttentionBackend + # return AscendTorchAttentionBackend + from vllm.attention.backends.ascend import AscendAttentionBackend + return AscendAttentionBackend + elif backend == _Backend.ASCEND_MINDIE: + from vllm.attention.backends.ascend_mindie import AscendMindIEBackend + return AscendMindIEBackend else: raise ValueError("Invalid attention backend.") @@ -210,6 +230,14 @@ def which_attn_to_use( logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH + if is_npu(): + # TODO: torch and mindie + # Ascend NPU + if selected_backend not in (_Backend.ASCEND_TORCH, _Backend.ASCEND_MINDIE): + logger.info("Cannot use %s backend on NPU.", selected_backend) + if is_mindie() and model_supports_in_mindie(): + return _Backend.ASCEND_MINDIE + return _Backend.ASCEND_TORCH # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: if current_platform.get_device_capability()[0] < 8: diff --git a/vllm/config.py b/vllm/config.py index 8f5e02e35f28d..08d73de2cc2e9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -18,7 +18,7 @@ get_hf_text_config) from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_openvino, is_xpu, + is_hip, is_neuron, is_npu, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -60,8 +60,8 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. - It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode.
"auto" will use the fast tokenizer if available, "slow" will always use the slow tokenizer, and @@ -112,15 +112,15 @@ class ModelConfig: skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. - override_neuron_config: Initialize non default neuron config or - override default neuron config that are specific to Neuron devices, - this argument will be used to configure the neuron config that - can not be gathered from the vllm arguments. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. """ @@ -771,9 +771,8 @@ class LoadConfig: fast weight loading. "bitsandbytes" will load nf4 type weights. ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's + Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO @@ -950,7 +949,7 @@ class SchedulerConfig: enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. embedding_mode: Whether the running model is for embedding. - preemption_mode: Whether to perform preemption by swapping or + preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than swapping. However, when the sequence group has multiple sequences @@ -1068,6 +1067,8 @@ def __init__(self, device: str = "auto") -> None: self.device_type = "cpu" elif is_xpu(): self.device_type = "xpu" + elif is_npu(): + self.device_type = "npu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked @@ -1160,7 +1161,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the @@ -1170,7 +1171,7 @@ def maybe_create_spec_config( If set to False, token log probabilities are returned according to the log probability settings in SamplingParams. If not specified, it defaults to True. - + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. 
@@ -1415,13 +1416,13 @@ def __init__( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. disable_logprobs: If set to True, token log probabilities will not - be returned even if requested by sampling parameters. This + be returned even if requested by sampling parameters. This reduces latency by skipping logprob calculation in proposal sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be @@ -1786,10 +1787,10 @@ def _get_and_verify_max_len( def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """ - If the input is a non-empty list, the first model_name in - `served_model_name` is taken. - If the input is a non-empty string, it is used directly. - For cases where the input is either an empty string or an + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an empty list, the fallback is to use `self.model`. """ if not served_model_name: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94271c4a93151..c3ed4a62f778d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -150,7 +150,7 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. @@ -537,6 +537,13 @@ def _get_executor_cls(cls, "multiprocessing distributed executor backend does not " "support VLLM_USE_RAY_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor + elif engine_config.device_config.device_type == "npu": + if distributed_executor_backend == "ray": + # TODO + pass + else: + from vllm.executor.npu_executor import NPUExecutor + executor_class = NPUExecutor else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor @@ -879,7 +886,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be . However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -942,7 +949,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. 
- + Arguments: * inputs: an input prompt @@ -1273,7 +1280,7 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - + """ now = time.time() diff --git a/vllm/executor/npu_executor.py b/vllm/executor/npu_executor.py new file mode 100644 index 0000000000000..33e861b7808a8 --- /dev/null +++ b/vllm/executor/npu_executor.py @@ -0,0 +1,59 @@ +from typing import List, Optional + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import make_async +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class NPUExecutor(GPUExecutor): + + def _init_executor(self) -> None: + # TODO: document what these two settings control on the NPU backend. + assert not self.scheduler_config.chunked_prefill_enabled, ( + "Chunked prefill is not yet supported for NPU backend") + assert not self.speculative_config, ( + "Speculative decoding is not yet supported for NPU backend") + + # Instantiate the worker and load the model to the device. + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + worker_module_name = "vllm.worker.npu_worker" + worker_class_name = "NPUWorker" + + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + # TODO: does the output format need to change?
+ def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + +class NPUExecutorAsync(NPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req) + return output diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 49247cd5de42a..61018fb3f77a7 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -3,6 +3,9 @@ from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_xpu +# TODO: add rms_norm op for torch-npu +# AttributeError: '_OpNamespace' '_C' object has no attribute 'rms_norm' + class CustomOp(nn.Module): diff --git a/vllm/model_executor/layers/ascend_sampler.py b/vllm/model_executor/layers/ascend_sampler.py new file mode 100644 index 0000000000000..0ee602f77431c --- /dev/null +++ b/vllm/model_executor/layers/ascend_sampler.py @@ -0,0 +1,403 @@ +import torch +import torch.nn as nn +import numpy as np +import random +from array import array +from typing import Dict, List, Optional, Tuple +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sampling_params import SamplingType +from vllm.model_executor.sampling_metadata import SamplingMetadata, SequenceGroupToSample +from vllm.model_executor.layers.sampler import (get_logprobs, + _modify_greedy_probs_inplace, + _multinomial, + _random_sample, + _greedy_sample, + _build_sampler_output, + ) +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) +from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam + +SampleResultType = List[Tuple[List[int], List[int]]] +_SAMPLING_EPS = 1e-5 + + +def _to_npu_tensor(data, dtype=None): + if dtype: + return torch.tensor(data, dtype=dtype, device=torch.device("npu")) + else: + return torch.tensor(data, device=torch.device("npu")) + + +def _sample_with_mindie( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + output_tokens: torch.Tensor, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + """ + Create output tensor for sampled token ids. + """ + # NOTE (cmq): overwrite _sample_with_torch in vllm/model_executor/layers/sampler.py + categorized_seq_group_ids: Dict[SamplingType, List[int]] = { + t: [] for t in SamplingType + } + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + sampling_params = seq_group.sampling_params + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata: Dict[ + SamplingType, Tuple[List[int], List[SequenceGroupToSample]] + ] = {} + multinomial_samples: Dict[SamplingType, torch.Tensor] = {} + + # Create output tensor for sampled token ids. 
+ if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.empty( + logprobs.shape[0], 1, dtype=torch.long, device=logprobs.device + ) + else: + sampled_token_ids_tensor = None + + for sampling_type in SamplingType: + # TODO (cmq): verify why using categorized_sample_indices[sampling_type][:, 1] + sample_indices = categorized_sample_indices[sampling_type][:, 1] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: + greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[long_sample_indices] = ( + greedy_samples.unsqueeze(-1) + ) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace( + logprobs, probs, long_sample_indices, greedy_samples + ) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + max_best_of_in_batch = 1 + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params + max_best_of_in_batch = max( + max_best_of_in_batch, sampling_params.best_of + ) + seeded_args = ( + {} + if sampling_type == SamplingType.RANDOM + else { + "seq_groups": seq_groups, + } + ) + + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], max_best_of_in_batch, **seeded_args + ) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[long_sample_indices] = multinomial_samples[ + sampling_type + ] + else: + raise ValueError(f"Unsupported sampling type in MindIE: {sampling_type}") + + if not sampling_metadata.skip_sampler_cpu_output: + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + # NOTE (cmq): why greedy do same logic as random + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, output_tokens) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + sample_results = [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + else: + sample_results = [] + return sample_results, sampled_token_ids_tensor + + +class AscendSampler(nn.Module): + def __init__(self, model=None): + super().__init__() + self.model = model + self.include_gpu_probs_tensor = False + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + assert logits is not None + _, vocab_size = logits.shape + + mindie_sampling_data, mindie_sampling_param = self.init_sampling_data( + sampling_metadata, vocab_size + ) + assert mindie_sampling_data is not None + # Sample the next tokens by model.sample + next_tokens = self.model.sample( + logits, + sampling_data=mindie_sampling_data, + sampling_param=mindie_sampling_param, + ) + + # # TODO (cmq): confirm if this is done in self.model.sample? + # # Apply presence and frequency penalties. 
+ # # NOTE (cmq): penalty and top-k/p sampling done in self.model.sample? + # logits = _apply_min_tokens_penalty(logits, sampling_metadata) + # if mindie_sampling_param.penalty_meta.has_penalty: + # logits = _apply_penalties(logits, mindie_sampling_data.all_input_ids, + # mindie_sampling_data.output_ids, + # mindie_sampling_param.penalty_meta.presence_penalty, + # mindie_sampling_param.penalty_meta.frequency_penalty, + # mindie_sampling_param.penalty_meta.repetition_penalty) + + # # Use in-place division to avoid creating a new tensor. + # logits.div_(mindie_sampling_param.temperature.unsqueeze(dim=1)) + + # if params["do_top_p_top_k"]: + # logits = _apply_top_k_top_p(logits, mindie_sampling_param.top_p_meta.top_p_tensor, + # mindie_sampling_param.top_k_meta.top_k_tensor) + + # if params["do_min_p"]: + # print("Not supported") + # # logits = _apply_min_p(logits, mindie_sampling_param.top_p_meta..min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + sample_results, maybe_sampled_tokens_tensor = _sample_with_mindie( + probs=probs, + logprobs=logprobs, + sampling_metadata=sampling_metadata, + output_tokens=torch.from_numpy(next_tokens).unsqueeze(1), + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=False, + ) + + if self.include_gpu_probs_tensor: + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + on_device_tensors = None + + # Get the logprobs query results. + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results + ) + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output, + ) + + def init_sampling_data( + self, + sampling_metadata: SamplingMetadata, + vocab_size: int, + ) -> Tuple[SamplingData, SamplingParam]: + """Initalize SamplingData and SamplingParam for MindIE. + + SamplingData receives all_input_tokens (prompt_tokens and output_tokens), + rather than only prompt_tokens. + + output: + mindie_sampling_param: SamplingParam + including params of sampling, including repetition_penalty, frequency_penalty, + presence_penalty, temperature, top-k, top-p, etc. + [!Note] Not support min-p now. 
+ mindie_sampling_data: SamplingData, torch.tensor on NPU + the input and output tokens of self.model.sample + """ + # same params as SamplingTensors.from_sampling_metadata + # get tuple tokens + output_tokens: List[Tuple[int]] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + sampling_seeds: List[int] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + # AscendSampler specific params + all_input_tokens: List[Tuple[int]] = [] + do_samples: List[bool] = [] + + assert sampling_metadata.seq_groups is not None + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + + do_samples.append(seq_group.do_sample) + + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + seed = sampling_params.seed + if seed is None: + # create base seed + if is_greedy: + seed = 0 + else: + lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max + seed = random.randint(lo, hi) + + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. + temperature = 1.0 + if not do_top_p_top_k and ( + top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size + ): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and ( + abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS + ): + do_penalties = True + + is_prompt = seq_group.is_prompt + if is_prompt and sampling_params.prompt_logprobs is not None: + # For tokens in the prompt that we only need to get + # their logprobs + query_len = seq_group.query_len + assert query_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + # TODO (cmq): check me? 
+                do_samples += [seq_group.do_sample] * prefill_len
+                top_ps += [top_p] * prefill_len
+                top_ks += [top_k] * prefill_len
+                presence_penalties += [0] * prefill_len
+                frequency_penalties += [0] * prefill_len
+                repetition_penalties += [1] * prefill_len
+
+                sampling_seeds += [seed] * prefill_len
+                # output_tokens.extend([] for _ in range(prefill_len))
+                # all_input_tokens.extend([] for _ in range(prefill_len))
+
+            if seq_group.do_sample:
+                sample_lens = len(seq_group.sample_indices)
+                assert sample_lens == len(seq_ids)
+                temperatures += [temperature] * len(seq_ids)
+                top_ps += [top_p] * len(seq_ids)
+                top_ks += [top_k] * len(seq_ids)
+                sampling_seeds += [seed] * len(seq_ids)
+                presence_penalties += [p] * len(seq_ids)
+                frequency_penalties += [f] * len(seq_ids)
+                repetition_penalties += [r] * len(seq_ids)
+
+        if do_penalties:
+            for seq_group in sampling_metadata.seq_groups:
+                seq_ids = seq_group.seq_ids
+                sampling_params = seq_group.sampling_params
+                if seq_group.is_prompt and sampling_params.prompt_logprobs is not None:
+                    prefill_len = len(seq_group.prompt_logprob_indices)
+                    output_tokens.extend(array("l") for _ in range(prefill_len))
+                if seq_group.do_sample:
+                    for seq_id in seq_ids:
+                        seq_data = seq_group.seq_data[seq_id]
+                        output_tokens.append(seq_data.output_token_ids)
+                        all_input_tokens.append(
+                            seq_data.prompt_token_ids + seq_data.output_token_ids
+                        )
+
+        repetition_penalties = np.array(repetition_penalties, dtype=np.float32)
+        frequency_penalties = np.array(frequency_penalties, dtype=np.float32)
+        presence_penalties = np.array(presence_penalties, dtype=np.float32)
+        temperatures = np.array(temperatures, dtype=np.float32)
+        top_ks = np.array(top_ks, dtype=np.int32)
+        top_ps = np.array(top_ps, dtype=np.float32)
+        sampling_seeds = np.array(sampling_seeds)
+        do_samples = np.array(do_samples)
+
+        # Pad the input and output tokens, then move them to the NPU.
+        max_tokens_len = max([len(tokens) for tokens in all_input_tokens], default=0)
+        padded_all_input_tokens = [
+            tokens + [vocab_size] * (max_tokens_len - len(tokens))
+            for tokens in all_input_tokens
+        ]
+        padded_all_input_tokens = np.array(padded_all_input_tokens, dtype=np.int32)
+        output_max_len = max([len(tokens) for tokens in output_tokens], default=0)
+        padded_output_tokens = [
+            tokens + [vocab_size] * (output_max_len - len(tokens))
+            for tokens in output_tokens
+        ]
+        padded_output_tokens = np.array(padded_output_tokens, dtype=np.int32)
+
+        all_input_ids_tensor = None
+        output_ids_tensor = None
+        if padded_all_input_tokens is not None:
+            all_input_ids_tensor = _to_npu_tensor(padded_all_input_tokens, torch.int32)
+        if padded_output_tokens is not None:
+            output_ids_tensor = _to_npu_tensor(padded_output_tokens, torch.int32)
+        # Construct SamplingData with the padded input and output tokens.
+        mindie_sampling_data = SamplingData(
+            all_input_ids_tensor, output_ids=output_ids_tensor
+        )
+
+        # Construct SamplingParam.
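# For illustration only: a standalone sketch of the right-padding used above,
# with made-up values. `vocab_size` serves as the pad id, so padded slots lie
# outside the real token-id range and cannot collide with actual tokens.
import numpy as np

vocab_size = 8
all_input_tokens = [[1, 2, 3], [4, 5]]
max_tokens_len = max((len(t) for t in all_input_tokens), default=0)
padded = np.array(
    [t + [vocab_size] * (max_tokens_len - len(t)) for t in all_input_tokens],
    dtype=np.int32,
)
assert padded.tolist() == [[1, 2, 3], [4, 5, 8]]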
+ if is_greedy: + mindie_sampling_param = None + else: + mindie_sampling_param = SamplingParam.from_numpy( + repetition_penalty=repetition_penalties, + frequency_penalty=frequency_penalties, + presence_penalty=presence_penalties, + temperature=temperatures, + top_k=top_ks, + top_p=top_ps, + seed=sampling_seeds, + do_sample=do_samples, + to_tensor=_to_npu_tensor, + ) + + return mindie_sampling_data, mindie_sampling_param diff --git a/vllm/model_executor/model_loader/ascend_mindie.py b/vllm/model_executor/model_loader/ascend_mindie.py new file mode 100644 index 0000000000000..abf3b58854bef --- /dev/null +++ b/vllm/model_executor/model_loader/ascend_mindie.py @@ -0,0 +1,260 @@ +"""Utilities for selecting and loading neuron models.""" + +import contextlib +import importlib +import os +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch + +from vllm.config import DeviceConfig, ModelConfig, LoadConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.ascend_sampler import AscendSampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.attention.backends.ascend import AscendMetadata +from vllm.model_executor.models.interfaces import supports_lora + +MINDIE_SUPPORT_DTYPE = [torch.float16, torch.float32, torch.bfloat16] + + +class MindIECasualLM(nn.Module): + + def __init__( + self, + model_config, + linear_method=None, + lora_config=None, + ) -> None: + super().__init__() + self.model_config = model_config + self.model = None + self.sampler = None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + attn_metadata: AscendMetadata, + ) -> torch.Tensor: + # is_prompt = attn_metadata.num_prefill_tokens > 0 + # TODO (cmq): check me + is_prompt = attn_metadata.is_prompt + + if kv_caches[0][0] is None: + # block_size = 128 is recommand in MindIE + # (https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/ascendtbapi/ascendtb_01_0070.html) + block_size = 128 + num_kv_heads = self.model.model_wrapper.model_runner.num_kv_heads + head_size = self.model.model_wrapper.model_runner.head_size + num_layers = self.model.model_wrapper.model_runner.num_layers + kv_caches = self.create_kv_caches_with_random( + 1, + block_size, + num_layers, + num_kv_heads, + head_size, + cache_dtype=torch.float32, + model_dtype=torch.float32, + seed=0, + device="npu", + ) + max_seq_len = attn_metadata.prefill_metadata.max_seq_len + batch_size = len(attn_metadata.prefill_metadata.seq_lens_tensor) + num_blocks = math.ceil(max_seq_len / block_size) + block_tables, slot_mapping = self.create_block_table_with_random( + input_ids, num_blocks, block_size, batch_size, device="npu" + ) + else: + block_tables = ( + torch.tensor([0], dtype=torch.int32, device="npu") + if is_prompt + else attn_metadata.decode_metadata.block_tables + ) + slot_mapping = attn_metadata.slot_mapping + + if is_prompt: + input_lengths = attn_metadata.prefill_metadata.seq_lens_tensor.to( + torch.int32 + ) + max_seq_len = attn_metadata.prefill_metadata.max_seq_len + 
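# The statement that follows builds lm_head_indices with a cumulative sum: in
# the packed prefill batch, only each prompt's last token needs logits from the
# LM head. A standalone sketch with made-up sequence lengths:
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
lm_head_indices = (seq_lens.cumsum(dim=-1) - 1).to(torch.int64)
assert lm_head_indices.tolist() == [2, 7, 9]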
lm_head_indices = ( + attn_metadata.prefill_metadata.seq_lens_tensor.cumsum(dim=-1) - 1 + ).to(torch.int64) + else: + input_lengths = attn_metadata.decode_metadata.seq_lens_tensor + max_seq_len = attn_metadata.decode_metadata.max_seq_len + lm_head_indices = None + + logits = self.model.forward_tensor( + input_ids, + positions, + is_prompt, + kv_caches, + block_tables, + slot_mapping, + input_lengths, + max_seq_len, + lm_head_indices, + ) + + return logits + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + assert ( + load_format in ["auto", "safetensors", "pt"], + f"Unsupported load_format in MindIE: {load_format}. load_format in MindIE supports [safetensors, pt]", + ) + + self.weight_dtype = torch.get_default_dtype() + # TODO (cmq): check if set_default_dtype is required + torch.set_default_dtype(torch.float32) + + self.model = GeneratorTorch(self.model_config) + self.sampler = AscendSampler(self.model) + + torch.set_default_dtype(self.weight_dtype) + + def create_kv_caches_with_random( + self, + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "npu", + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + assert cache_dtype in MINDIE_SUPPORT_DTYPE + torch.random.manual_seed(seed) + if torch.npu.is_available(): + torch.npu.manual_seed(seed) + + scale = head_size**-0.5 + cache_shape = (num_blocks, block_size, num_heads, head_size) + kv_caches: List[Tuple(torch.Tensor, torch.Tensor)] = [] + for _ in range(num_layers): + key_cache = torch.empty( + size=cache_shape, dtype=self.weight_dtype, device=device + ) + value_cache = torch.empty( + size=cache_shape, dtype=self.weight_dtype, device=device + ) + if cache_dtype in MINDIE_SUPPORT_DTYPE: + key_cache.uniform_(-scale, scale) + value_cache.uniform_(-scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype} in MindIE" + ) + kv_caches.append((key_cache, value_cache)) + + return kv_caches + + def create_block_table_with_random( + self, + input_ids, + num_blocks: int, + block_size: int, + batch_size: int, + device: Optional[str] = "npu", + ): + + block_tables = torch.zeros(batch_size, num_blocks, dtype=int, device=device) + prefill_len = len(input_ids) + num_slots = (prefill_len + block_size - 1) // block_size + slot_mapping = np.concatenate( + [ + np.arange(min(block_size, prefill_len - i * block_size)) + for i in range(num_slots) + ] + ) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) + return block_tables, slot_mapping + + +def get_mindie_model( + model_config: ModelConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + mindie_model_config, + **kwargs, +) -> nn.Module: + lora_config = kwargs.get("lora_config", None) + + # TODO (cmq): pass in linear_method? + # Get the (maybe quantized) linear method. 
+ linear_method = None + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + # TODO (cmq): check me + # if hasattr(MindIECasualLM, "supported_lora_modules"): + if supports_lora(MindIECasualLM): + model = MindIECasualLM(mindie_model_config, linear_method, lora_config) + elif lora_config: + raise ValueError( + f"Model {MindIECasualLM.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github." + ) + else: + model = MindIECasualLM(mindie_model_config, linear_method) + if load_config.load_format == "dummy": + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded checkpoint. + model.load_weights( + model_config.model, + load_config.download_dir, + load_config.load_format, + model_config.revision, + ) + model = model.to(target_device) + return model.eval() + + +def model_supports_in_mindie(model_config: ModelConfig) -> bool: + model_type = model_config.hf_config.model_type.lower() + + atb_llm_base_path = importlib.import_module("atb_llm").__path__[0] + "/models" + mindie_supported_models = list() + for model_name in os.listdir(atb_llm_base_path): + if model_name.startswith("_") or model_name == "base": + # skip base, __init__.py and __pycache__ + continue + mindie_supported_models.append(model_name) + + if model_type not in mindie_supported_models: + return False + return True diff --git a/vllm/utils.py b/vllm/utils.py index a22081ebe8df0..2b08e46d33d44 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -355,6 +355,20 @@ def is_xpu() -> bool: return False return hasattr(torch, "xpu") and torch.xpu.is_available() +@lru_cache(maxsize=None) +def is_npu() -> bool: + try: + import torch_npu + except ImportError: + torch_npu = None + return torch_npu is not None + + +@lru_cache(maxsize=None) +def is_mindie() -> bool: + # TODO + return False + @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: @@ -743,7 +757,7 @@ def is_pin_memory_available() -> bool: return True -class CudaMemoryProfiler: +class DeviceMemoryProfiler: def __init__(self, device: Optional[torch.types.Device] = None): self.device = device @@ -756,6 +770,9 @@ def current_memory_usage(self) -> float: elif is_xpu(): torch.xpu.reset_peak_memory_stats(self.device) # type: ignore mem = torch.xpu.max_memory_allocated(self.device) # type: ignore + elif is_npu(): + torch.npu.reset_peak_memory_stats(self.device) # type: ignore + mem = torch.npu.max_memory_allocated(self.device) # type: ignore return mem def __enter__(self): @@ -1065,7 +1082,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. 
- + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 74f7d4e0860d3..ace8207d90580 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -44,7 +44,7 @@ LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, +from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) from vllm.worker.model_runner_base import ( @@ -913,7 +913,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) - with CudaMemoryProfiler() as m: + with DeviceMemoryProfiler() as m: self.model = get_model(model_config=self.model_config, device_config=self.device_config, load_config=self.load_config, diff --git a/vllm/worker/npu_model_runner.py b/vllm/worker/npu_model_runner.py new file mode 100644 index 0000000000000..b76bdb702c1dc --- /dev/null +++ b/vllm/worker/npu_model_runner.py @@ -0,0 +1,1175 @@ +import dataclasses +import gc +import time +import warnings +import weakref +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + + +BatchDecodeWithPagedKVCacheWrapper = None +CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None +BatchPrefillWithPagedKVCacheWrapper = None +FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import vllm.envs as envs +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group +from vllm.distributed.parallel_state import graph_capture +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.ascend_mindie import ( + get_mindie_model, model_supports_in_mindie) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.interfaces import (supports_lora, + supports_multimodal) +from vllm.model_executor.models.utils import set_cpu_offload_max_bytes +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs, MultiModalRegistry) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) + +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import (DeviceMemoryProfiler, flatten_2d_lists, + get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, is_mindie, is_npu) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, 
ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] +_NUM_WARMUP_ITERS = 2 + +TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU") + +#TODO: 直接继承GPUINPUT相关类 +@dataclass(frozen=True) +class ModelInputForNPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None + finished_requests_ids: Optional[List[str]] = None + virtual_engine: int = 0 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForNPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForNPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. 
+ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForNPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): + """Build ModelInputForNPU from SequenceGroupMetadata.""" + + # Note: ideally we would be using a dataclass(kw_only=True) + # here, so that this can be subclassed easily, + # but kw_only is not supported in python<3.10. + class InterDataForSeqGroup: + """Intermediate data for the current sequence group.""" + + def __init__( + self, + *, + # From sequence group metadata. + request_id: str, + seq_ids: List[int], + is_prompt: bool, + block_tables: Optional[Dict[int, List[int]]], + computed_block_nums: List[int], + n_seqs: int = 0, + + # Input tokens and positions. + input_tokens: Optional[List[List[int]]] = None, + input_positions: Optional[List[List[int]]] = None, + + # The sequence length (may be capped to the sliding window). + seq_lens: Optional[List[int]] = None, + # The original sequence length (before applying sliding window). + # This is used to compute slot mapping. + orig_seq_lens: Optional[List[int]] = None, + # The query length. + query_lens: Optional[List[int]] = None, + # The number of tokens that are already computed. + context_lens: Optional[List[int]] = None, + # The current sliding window block. + curr_sliding_window_blocks: Optional[List[int]] = None, + + # LoRA inputs. + lora_index_mapping: Optional[List[List[int]]] = None, + lora_prompt_mapping: Optional[List[List[int]]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + + # Prompt adapter inputs. + prompt_adapter_index_mapping: Optional[List[int]] = None, + prompt_adapter_prompt_mapping: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + + # Multi-modal inputs. + multi_modal_inputs: Optional[MultiModalInputs] = None, + + # Whether the prefix cache is hit (prefill only). 
+ prefix_cache_hit: bool = False, + ): + self.request_id = request_id + self.seq_ids = seq_ids + self.is_prompt = is_prompt + self.block_tables = block_tables + self.computed_block_nums = computed_block_nums + self.n_seqs = n_seqs + self.input_tokens = input_tokens or [] + self.input_positions = input_positions or [] + self.seq_lens = seq_lens or [] + self.orig_seq_lens = orig_seq_lens or [] + self.query_lens = query_lens or [] + self.context_lens = context_lens or [] + self.curr_sliding_window_blocks = curr_sliding_window_blocks or [] + + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + + self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping + or []) + self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping + or []) + self.prompt_adapter_request = prompt_adapter_request + + self.multi_modal_inputs = multi_modal_inputs + self.prefix_cache_hit = prefix_cache_hit + + self.__post_init__() + + def __post_init__(self): + self.n_seqs = len(self.seq_ids) + + self.input_tokens = [[] for _ in range(self.n_seqs)] + self.input_positions = [[] for _ in range(self.n_seqs)] + self.seq_lens = [0] * self.n_seqs + self.orig_seq_lens = [0] * self.n_seqs + self.query_lens = [0] * self.n_seqs + self.context_lens = [0] * self.n_seqs + self.curr_sliding_window_blocks = [0] * self.n_seqs + + self.lora_index_mapping = [[] for _ in range(self.n_seqs)] + self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)] + + def __init__(self, + runner: "NPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None): + super().__init__() + # Compute functions for each sequence in a sequence group. + # WARNING: The order of the functions matters! + self.per_seq_compute_fns = [ + self._compute_lens, + self._compute_for_prefix_cache_hit, + self._compute_for_sliding_window, + self._compute_lora_input, + ] + # Compute functions for each sequence group. + # WARNING: The order of the functions matters! + self.per_seq_group_compute_fns = [ + self._compute_prompt_adapter_input, + self._compute_multi_modal_input, + ] + + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.scheduler_config = self.runner.scheduler_config + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None + self.enable_prompt_adapter = (self.runner.prompt_adapter_config + is not None) + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.finished_requests_ids = finished_requests_ids + self.decode_only = True + + # Intermediate data (data in CPU before going to GPU) for + # the current sequence group. + self.inter_data_list: List[ + ModelInputForNPUBuilder.InterDataForSeqGroup] = [] + + # Attention metadata inputs. + self.attn_metadata_builder = self.attn_backend.make_metadata_builder( + weakref.proxy(self)) + + # Engine/Model configurations. 
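# The sliding-window bookkeeping configured in this block reduces to simple
# block arithmetic; the sizes below are made up for illustration.
sliding_window, block_size = 1000, 16
sliding_window_blocks = (sliding_window + block_size - 1) // block_size  # ceiling division
block_aligned_sliding_window = sliding_window_blocks * block_size
assert (sliding_window_blocks, block_aligned_sliding_window) == (63, 1008)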
+ self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Compute context length, sequence length and tokens + for the given sequence data. + """ + seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] + token_chunk_size = seq_group_metadata.token_chunk_size + + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() + if inter_data.is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_len - 1 + seq_len = min(seq_len, context_len + token_chunk_size) + + # Compute tokens. + if inter_data.is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + inter_data.seq_lens[seq_idx] = seq_len + inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.context_lens[seq_idx] = context_len + inter_data.input_tokens[seq_idx] = tokens + inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) + inter_data.query_lens[ + seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + + def _compute_for_prefix_cache_hit( + self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Check if hit prefix cache (i.e., some blocks are already computed). + If hit, update input tokens and positions to only compute the + remaining blocks. + """ + computed_block_nums = inter_data.computed_block_nums + + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and inter_data.is_prompt) + inter_data.prefix_cache_hit = prefix_cache_hit + if self.chunked_prefill_enabled and prefix_cache_hit: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now.") + + # If prefix cache is hit, advance context length to bypass + # hit blocks. Accordingly, input tokens, position and query length + # have to be updated. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][context_len:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][context_len:] + inter_data.context_lens[seq_idx] = context_len + inter_data.query_lens[ + seq_idx] = inter_data.seq_lens[seq_idx] - context_len + + def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Update seq_len and curr_sliding_window_block for the given + sequence data (only required by decoding) if sliding window is enabled. 
+ """ + curr_sliding_window_block = 0 + sliding_seq_len = inter_data.seq_lens[seq_idx] + if not inter_data.is_prompt and self.sliding_window is not None: + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + curr_sliding_window_block = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = inter_data.seq_lens[seq_idx] % self.block_size + sliding_seq_len = min( + inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_block += 1 + else: + sliding_seq_len = min(inter_data.seq_lens[seq_idx], + self.sliding_window) + + inter_data.curr_sliding_window_blocks[ + seq_idx] = curr_sliding_window_block + inter_data.seq_lens[seq_idx] = sliding_seq_len + + def _compute_lora_input(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """If LoRA is enabled, compute LoRA index and prompt mapping.""" + if not self.enable_lora: + return + + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + inter_data.lora_requests.add(seq_group_metadata.lora_request) + query_len = inter_data.query_lens[seq_idx] + inter_data.lora_index_mapping.append([lora_id] * query_len) + inter_data.lora_prompt_mapping.append( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) + + def _compute_prompt_adapter_input( + self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If prompt adapter is enabled, compute index and prompt mapping. + """ + # Note that when is_prompt=True, we expect only one sequence + # in the group. + if not self.enable_prompt_adapter: + return + + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id <= 0 or not inter_data.is_prompt: + return + + # We expect only one sequence in the group when is_prompt=True. 
+ assert inter_data.n_seqs == 1 + query_len = inter_data.query_lens[0] + inter_data.prompt_adapter_request = ( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens + inter_data.prompt_adapter_index_mapping = [ + prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( + query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1) + + def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If multi-modal data is given, add it to the input.""" + mm_data = seq_group_metadata.multi_modal_data + if not mm_data: + return + + mm_kwargs = self.multi_modal_input_mapper(mm_data) + inter_data.multi_modal_inputs = mm_kwargs + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + """Add a sequence group to the builder.""" + seq_ids = list(seq_group_metadata.seq_data.keys()) + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + inter_data = self.InterDataForSeqGroup( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums) + self.inter_data_list.append(inter_data) + + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + per_seq_group_fn(inter_data, seq_group_metadata) + + def _use_captured_graph(self, batch_size: int, + max_decode_seq_len: int) -> bool: + return (self.decode_only and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + + def build(self) -> ModelInputForNPU: + """Finalize the builder intermediate data and + create on-device tensors. + """ + # Combine and flatten intermediate data. + input_tokens = flatten_2d_lists([ + flatten_2d_lists(inter_data.input_tokens) + for inter_data in self.inter_data_list + ]) + if not input_tokens: + # This may happen when all prefill requests hit + # prefix caching and there is no decode request. + return self.model_input_cls() + input_positions = flatten_2d_lists([ + flatten_2d_lists(inter_data.input_positions) + for inter_data in self.inter_data_list + ]) + seq_lens = [] + max_decode_seq_len = 0 + for inter_data in self.inter_data_list: + seq_lens.extend(inter_data.seq_lens) + if not inter_data.is_prompt: + max_decode_seq_len = max(max_decode_seq_len, + max(inter_data.seq_lens)) + query_lens = flatten_2d_lists( + [inter_data.query_lens for inter_data in self.inter_data_list]) + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + request_ids_to_seq_ids = { + data.request_id: data.seq_ids + for data in self.inter_data_list + } + + batch_size = len(input_tokens) + use_captured_graph = self._use_captured_graph(batch_size, + max_decode_seq_len) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. 
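# The padding below rounds the decode batch up to one of the pre-captured
# graph sizes. This helper mirrors _get_graph_batch_size at the bottom of this
# file: 1 and 2 map to themselves, 3-4 map to 4, and anything larger is rounded
# up to a multiple of _BATCH_SIZE_ALIGNMENT (8).
def _round_up_to_graph_batch(batch_size: int, alignment: int = 8) -> int:
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + alignment - 1) // alignment * alignment

assert [_round_up_to_graph_batch(n) for n in (1, 3, 9, 17)] == [1, 4, 16, 24]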
+ cuda_graph_pad_size = -1 + if use_captured_graph: + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + batch_size = graph_batch_size + + # Tokens and positions. + input_tokens.extend([0] * cuda_graph_pad_size) + input_positions.extend([0] * cuda_graph_pad_size) + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.runner.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.runner.device) + + # Sequence and query lengths. + seq_lens.extend([1] * cuda_graph_pad_size) + + # Attention metadata. + attn_metadata = self.attn_metadata_builder.build( + seq_lens, query_lens, cuda_graph_pad_size, batch_size) + + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) + + # Prompt adapter data. + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + prompt_adapter_mapping = None + if self.enable_prompt_adapter: + prompt_adapter_requests = set( + data.prompt_adapter_request for data in self.inter_data_list + if data.prompt_adapter_request is not None) + prompt_adapter_index_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_index_mapping + for inter_data in self.inter_data_list + ]) + prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + prompt_adapter_prompt_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_prompt_mapping + for inter_data in self.inter_data_list + ]) + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + + # Multi-modal data. + multi_modal_inputs_list = [ + data.multi_modal_inputs for data in self.inter_data_list + if data.multi_modal_inputs is not None + ] + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests) + + +####TODO: 直接继承GPURUNNER相关类 +class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): + """ + Helper class for shared methods between GPU model runners. 
+ """ + _model_input_cls: Type[TModelInputForNPU] + _builder_cls: Type[ModelInputForNPUBuilder] + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + mindie_model_config: Optional[dict], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.mindie_model_config = mindie_model_config + self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config + self.return_hidden_states = return_hidden_states + # TODO: support https://github.com/vllm-project/vllm/pull/7089 + self.observability_config = observability_config + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + + # TODO: Graph + # self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + # {} for _ in range(self.parallel_config.pipeline_parallel_size) + # ] + # self.graph_memory_pool: Optional[Tuple[ + # int, int]] = None # Set during graph capture. + + self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( + parallel_config) + + # When using CUDA graph, the input block tables must be padded to + # max_seq_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + # TODO: 增加ATB不支持模型的处理逻辑 + self.attn_backend = get_attn_backend( + num_attn_heads, + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) if num_attn_heads else None + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization + self.model: nn.Module # Set after load_model + # Set after load_model. 
+ self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None + + self.flashinfer_decode_workspace_buffer = None + self.flashinfer_decode_wrapper = None + self.flashinfer_prefill_workspace_buffer = None + self.flashinfer_prefill_wrapper = None + + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + + def model_router(self) -> None: + # TODO: model_support_in_mindie, get_mindie_model + if is_mindie and model_supports_in_mindie(self.model_config): + self.model = get_mindie_model(self.model_config, + self.device_config, + self.load_config, + self.mindie_model_config) + else: + self.get_vllm_model() + + def load_model(self) -> None: + self.model_router() + + def get_vllm_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora(self.model), "Model does not support LoRA" + assert not supports_multimodal( + self.model + ), "To be tested: multi-modal model with LoRA settings." + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config. + max_position_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + from vllm.model_executor.model_loader.loader import TensorizerLoader + TensorizerLoader.save_model( + self.model, + tensorizer_config=tensorizer_config, + ) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> TModelInputForNPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. 
For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + return builder.build() # type: ignore + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + seq_data, dummy_multi_modal_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + multi_modal_data=dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. 
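# The profiling loop above spreads max_num_batched_tokens as evenly as possible
# over max_num_seqs dummy sequences; a standalone sketch with made-up limits:
max_num_batched_tokens, max_num_seqs = 10, 3
seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
assert seq_lens == [4, 3, 3] and sum(seq_lens) == max_num_batched_tokens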
+ num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.npu.synchronize() + return + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() + + @torch.inference_mode() + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. 
Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ + pass + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + +class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): + """ + NPU model runner with sampling step. + """ + _model_input_cls: Type[ModelInputForNPUWithSamplingMetadata] = ( + ModelInputForNPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForNPUWithSamplingMetadata: + model_input = \ + ModelInputForNPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForNPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + if get_pp_group().is_last_rank: + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, model_input.seq_lens, + model_input.query_lens, self.device, self.pin_memory, + generators) + else: + sampling_metadata = None + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. 
+ virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + indices = model_input.sampling_metadata.selected_token_indices + if model_input.is_prompt: + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) + elif decode_meta.use_cuda_graph: + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states + + output.hidden_states = hidden_states + + # TODO: return output? + return [output] + + +def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... 
+ """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) diff --git a/vllm/worker/npu_worker.py b/vllm/worker/npu_worker.py new file mode 100644 index 0000000000000..78a9c6993ed0b --- /dev/null +++ b/vllm/worker/npu_worker.py @@ -0,0 +1,393 @@ +"""A GPU worker class.""" +import gc +import os +from typing import List, Optional, Set, Tuple, Type + +import torch +import torch.distributed +import torch_npu + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.npu_model_runner import NPUModelRunnerBase, NPUModelRunner +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput + + +class NPUWorker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a NPU. + + Each worker is associated with a single NPU. The worker is responsible for + maintaining the KV cache and executing the model on the NPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[NPUModelRunnerBase]] = None, + observability_config: Optional[ObservabilityConfig] = None, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." 
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+        self.observability_config = observability_config
+
+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator"]) \
+            else {"return_hidden_states": True}
+
+        ModelRunnerClass: Type[NPUModelRunnerBase] = NPUModelRunner
+        if model_runner_cls is not None:
+            ModelRunnerClass = model_runner_cls
+        elif self.model_config.embedding_mode:
+            ModelRunnerClass = EmbeddingModelRunner
+        mindie_model_config = {
+            "backend_type": "atb",
+            "model_id": model_config.model,
+            "rank": rank,
+            "local_rank": local_rank,
+            "world_size": parallel_config.world_size,
+            "npu_device_id": local_rank,
+        }
+        self.model_runner: NPUModelRunnerBase = ModelRunnerClass(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            cache_config,
+            load_config=load_config,
+            lora_config=self.lora_config,
+            mindie_model_config=mindie_model_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=is_driver_worker,
+            prompt_adapter_config=prompt_adapter_config,
+            observability_config=observability_config,
+            **speculative_args,
+        )
+        # Uninitialized cache engine. Will be initialized by
+        # initialize_cache.
+        self.cache_engine: List[CacheEngine]
+        # Initialize npu_cache as embedding models don't initialize kv_caches
+        self.npu_cache: Optional[List[List[torch.Tensor]]] = None
+
+    def init_device(self) -> None:
+        if self.device_config.device.type == "npu":
+            # # torch.distributed.all_reduce does not free the input tensor until
+            # # the synchronization point. This causes the memory usage to grow
+            # # as the number of all_reduce calls increases. This env var disables
+            # # this behavior.
+            # # Related issue:
+            # # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+            # os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+            # # This env var set by Ray causes exceptions with graph building.
+            # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            self.device = torch.device(f"npu:{self.local_rank}")
+            torch.npu.set_device(self.device)
+
+            _check_if_npu_supports_dtype(self.model_config.dtype)
+            torch.npu.empty_cache()
+            self.init_npu_memory = torch.npu.mem_get_info()[0]
+        else:
+            raise RuntimeError(
+                f"Unsupported device type: {self.device_config.device}")
+        # Initialize the distributed environment.
+        # TODO: adapt the initialization for HCCL.
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self):
+        self.model_runner.load_model()
+
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.model_runner.save_sharded_state(
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
+    def save_tensorized_model(
+        self,
+        tensorizer_config: TensorizerConfig,
+    ) -> None:
+        self.model_runner.save_tensorized_model(
+            tensorizer_config=tensorizer_config, )
+
+    @torch.inference_mode()
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model to determine how many
+        KV blocks may be allocated without OOMs.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the maximum possible number of NPU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        .. tip::
+            You may limit the usage of NPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        torch.npu.empty_cache()
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.npu.synchronize()
+        free_npu_memory, total_npu_memory = torch.npu.mem_get_info()
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # NPU did not change their memory usage during the profiling.
+        peak_memory = self.init_npu_memory - free_npu_memory
+        assert peak_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_npu_memory}, current free memory"
+            f" {free_npu_memory}. This happens when the NPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        cache_block_size = self.get_cache_block_size_bytes()
+        num_npu_blocks = int(
+            (total_npu_memory * self.cache_config.gpu_memory_utilization -
+             peak_memory) // cache_block_size)
+        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
+                             cache_block_size)
+        num_npu_blocks = max(num_npu_blocks, 0)
+        num_cpu_blocks = max(num_cpu_blocks, 0)
+        if self.model_runner.lora_manager:
+            self.model_runner.remove_all_loras()
+        gc.collect()
+        torch.npu.empty_cache()
+        return num_npu_blocks, num_cpu_blocks
+
+    def initialize_cache(self, num_npu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Allocate NPU and CPU KV cache with the specified number of blocks.
+
+        This also warms up the model, which may record CANN graphs.
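+
+        The resulting cache must be able to hold at least ``max_model_len``
+        tokens (i.e. ``num_npu_blocks * block_size >= max_model_len``);
+        otherwise ``raise_if_cache_size_invalid`` below raises a ValueError.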
+ """ + raise_if_cache_size_invalid(num_npu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_npu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.npu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + + def _warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.npu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.npu_cache + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch aclrtmemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a npu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within npus. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + @torch.inference_mode() + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. 
+        if (worker_input.blocks_to_swap_in is not None
+                and worker_input.blocks_to_swap_in.numel() > 0):
+            self.cache_engine[virtual_engine].swap_in(
+                worker_input.blocks_to_swap_in)
+        if (worker_input.blocks_to_swap_out is not None
+                and worker_input.blocks_to_swap_out.numel() > 0):
+            self.cache_engine[virtual_engine].swap_out(
+                worker_input.blocks_to_swap_out)
+        if (worker_input.blocks_to_copy is not None
+                and worker_input.blocks_to_copy.numel() > 0):
+            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+
+    def add_prompt_adapter(
+            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
+        return self.model_runner.add_prompt_adapter(prompt_adapter_request)
+
+    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
+
+    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
+
+    def list_prompt_adapters(self) -> Set[int]:
+        return self.model_runner.list_prompt_adapters()
+
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    def get_cache_block_size_bytes(self) -> int:
+        """Get the size of a single KV cache block in bytes."""
+        return CacheEngine.get_cache_block_size(self.cache_config,
+                                                self.model_config,
+                                                self.parallel_config)
+
+
+def init_worker_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
+    backend: str = "hccl"
+) -> None:
+    """Initialize the distributed environment."""
+    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+
+    init_distributed_environment(parallel_config.world_size, rank,
+                                 distributed_init_method, local_rank, backend)
+
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
+
+def _check_if_npu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the NPU supports the dtype.
+    # if torch_dtype == torch.bfloat16:
+    #     compute_capability = current_platform.get_device_capability()
+    #     if compute_capability[0] < 8:
+    #         gpu_name = torch.cuda.get_device_name()
+    #         raise ValueError(
+    #             "Bfloat16 is only supported on GPUs with compute capability "
+    #             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+    #             f"{compute_capability[0]}.{compute_capability[1]}. "
+    #             "You can use float16 instead by explicitly setting the"
+    #             "`dtype` flag in CLI, for example: --dtype=half.")
+    # TODO: add a dtype support check for the NPU.
+    pass
+
+
+def raise_if_cache_size_invalid(num_npu_blocks, block_size,
+                                max_model_len) -> None:
+    if num_npu_blocks <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+    max_seq_len = block_size * num_npu_blocks
+    if max_model_len > max_seq_len:
+        raise ValueError(
+            f"The model's max seq len ({max_model_len}) "
+            "is larger than the maximum number of tokens that can be "
+            f"stored in KV cache ({max_seq_len}). Try increasing "
Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f9037625d4af9..99e13051b615d 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -21,7 +21,7 @@ MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad +from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -149,99 +149,6 @@ def build(self) -> ModelInputForXPU: ) def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) - - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
-            start_idx = 0
-            if self.sliding_window is not None:
-                start_idx = max(0, seq_len - self.sliding_window)
-
-            for i in range(computed_len, seq_len):
-                if i < start_idx:
-                    slot_mapping.append(_PAD_SLOT_ID)
-                    continue
-
-                block_number = block_table[i //
-                                           self.block_size]  # type: ignore
-                block_offset = i % self.block_size  # type: ignore
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
-
-        num_prompt_tokens = len(input_tokens)
-
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)  # type: ignore
-        input_positions = torch.tensor(input_positions,
-                                       dtype=torch.long,
-                                       device=self.device)  # type: ignore
-        slot_mapping = torch.tensor(slot_mapping,
-                                    dtype=torch.long,
-                                    device=self.device)  # type: ignore
-
-        max_seqlen = max(seq_lens)
-        tmp = [0]
-        tmp.extend(seq_lens)
-        seqlen = torch.tensor(tmp)
-        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
-
-        attn_metadata = self.attn_backend.make_metadata(
-            is_prompt=True,
-            slot_mapping=slot_mapping,
-            seq_lens=seq_lens,
-            seqlen_q=seqlen_q,
-            max_seqlen=max_seqlen,
-            seq_lens_tensor=torch.tensor([]),
-            max_decode_seq_len=0,
-            num_prefills=len(seq_lens),
-            num_prefill_tokens=num_prompt_tokens,
-            num_decode_tokens=0,
-            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
-        )
-
-        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
-
-        return (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_kwargs)
 
     def _prepare_decode(
@@ -391,7 +298,7 @@ def __init__(
         self.model: nn.Module  # Set after init_Model
 
     def load_model(self) -> None:
-        with CudaMemoryProfiler() as m:
+        with DeviceMemoryProfiler() as m:
             self.model = get_model(
                 model_config=self.model_config,
                 device_config=self.device_config,
@@ -586,4 +493,4 @@ def execute_model(
             # the communication time as well.
             output.model_forward_time = model_forward_time
 
-        return [output]
+        return [output]
\ No newline at end of file
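For reference, a minimal sketch (with made-up numbers) of the KV-block arithmetic that NPUWorker.determine_num_available_blocks performs; the real values come from torch.npu.mem_get_info(), the profiling run, and CacheEngine.get_cache_block_size:

    # Illustration only: mirrors the block-count formula used above.
    total_npu_memory = 32 * 1024**3      # hypothetical 32 GiB device
    init_free_memory = 30 * 1024**3      # hypothetical free memory at init
    free_after_profile = 20 * 1024**3    # hypothetical free memory after profiling
    gpu_memory_utilization = 0.9         # same config knob as on CUDA
    cache_block_size = 2 * 1024**2       # hypothetical 2 MiB per KV block

    peak_memory = init_free_memory - free_after_profile
    num_npu_blocks = int(
        (total_npu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    print(num_npu_blocks)  # 9625 with these numbers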