
Commit

fixes
wangshuai09 committed Oct 8, 2024
1 parent 23c45ae commit 1df3978
Showing 15 changed files with 104 additions and 248 deletions.
13 changes: 3 additions & 10 deletions benchmarks/benchmark_throughput.py
@@ -239,16 +239,9 @@ async def run_vllm_async(
return end - start


def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
device: str
) -> float:
def run_hf(requests: List[Tuple[str, int, int]], model: str,
tokenizer: PreTrainedTokenizerBase, n: int, use_beam_search: bool,
max_batch_size: int, trust_remote_code: bool, device: str) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
7 changes: 6 additions & 1 deletion examples/offline_inference_npu.py
@@ -1,14 +1,19 @@
import gc

import torch

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
from vllm.distributed.parallel_state import (destroy_distributed_environment,
destroy_model_parallel)


def clean_up():
destroy_model_parallel()
destroy_distributed_environment()
gc.collect()
torch.npu.empty_cache()


# Sample prompts.
prompts = [
"Hello, my name is",
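For orientation, a minimal sketch of how an offline NPU example like the one above is typically driven end to end, assuming the standard vLLM API (LLM, SamplingParams, generate); the model name and sampling values are placeholders, not taken from this commit.

    # Sketch only: placeholder model and sampling values.
    from vllm import LLM, SamplingParams

    prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(model="facebook/opt-125m")  # placeholder; use an NPU-validated model
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")

    del llm
    clean_up()  # the helper defined at the top of this example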
5 changes: 3 additions & 2 deletions setup.py
@@ -285,6 +285,7 @@ def _is_openvino() -> bool:
def _is_xpu() -> bool:
return VLLM_TARGET_DEVICE == "xpu"


def _is_npu() -> bool:
return VLLM_TARGET_DEVICE == "npu"

@@ -294,8 +295,8 @@ def _build_custom_ops() -> bool:


def _build_core_ext() -> bool:
return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu() or
_is_npu())
return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()
or _is_npu())


def get_hipcc_rocm_version():
4 changes: 2 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
@@ -18,8 +18,8 @@
from ..utils import multi_gpu_test

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"facebook/opt-125m", "/home/models/llama-2-7b/"
# "meta-llama/Llama-2-7b-hf",
]

TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
3 changes: 1 addition & 2 deletions tests/conftest.py
@@ -35,10 +35,9 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_cpu)
from vllm.platforms import current_platform


logger = init_logger(__name__)

74 changes: 36 additions & 38 deletions vllm/attention/backends/ascend.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Any, Dict, List, TYPE_CHECKING, Optional, Tuple, Type
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
import torch_npu
@@ -11,11 +11,12 @@
CommonMetadataBuilder,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)

if TYPE_CHECKING:
from vllm.worker.npu_model_runner import ModelInputForNPUBuilder


SHARE_MASK_TRIL_PREFIX_CACHE = None
SHARE_MASK_TRIL = None

@@ -55,10 +56,9 @@ def swap_blocks(
dst_indices = src_to_dst[:, 1]

dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)

dst_key_cache.device)

@staticmethod
def copy_blocks(
@@ -222,7 +222,7 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
)
)
return self._cached_prefill_metadata

@property
@@ -260,7 +260,7 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
)
)
return self._cached_decode_metadata


@@ -308,19 +308,18 @@ def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id,
slot_indices.extend([[PAD_SLOT_ID, 0]] * (max_query_len - numel))

def _add_seq_group(
self,
inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool
):
self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
max_query_len = max(max(data.query_lens)
for data in self.input_builder.inter_data_list)
max_query_len = max(
max(data.query_lens)
for data in self.input_builder.inter_data_list)

is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
@@ -401,7 +400,6 @@ def forward(
value: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AscendMetadata,
kv_scale: float = 1.0,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
@@ -413,29 +411,34 @@
num_tokens = batch_size * seq_len
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size, num_kv_heads * head_size]
key_cache [num_blocks, block_size, num_kv_heads * head_size]
value_cache [num_blocks, block_size, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads * head_size]
key_cache = [num_blocks, block_size,
num_kv_heads * head_size]
value_cache = [num_blocks, block_size,
num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len * num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# view q k v to BSH
num_tokens = query.shape[0]

if kv_cache is not None:
if attn_metadata.num_prefills > 0:
slot_indices = attn_metadata.prefill_metadata.slot_mapping
slot_indices = (None
if attn_metadata.prefill_metadata is None else
attn_metadata.prefill_metadata.slot_mapping)
else:
slot_indices = attn_metadata.decode_metadata.slot_mapping
slot_indices = (None
if attn_metadata.decode_metadata is None else
attn_metadata.decode_metadata.slot_mapping)
key_cache, value_cache = kv_cache[0], kv_cache[1]
AscendPagedAttention.write_to_paged_cache(
key,
@@ -450,10 +453,8 @@ def forward(
if num_tokens > 16384:
attn_metadata.sparse_mode = 2
attention_mask = gen_input_mask(
attn_metadata.max_prefill_seq_len,
self.sliding_window,
num_tokens
)
attn_metadata.max_prefill_seq_len, self.sliding_window,
num_tokens)
attn_metadata.attn_mask = attention_mask

if (self.alibi_slopes is not None
@@ -491,8 +492,7 @@ def forward(
sparse_mode=attn_metadata.sparse_mode,
)
output = output.transpose(1, 2).reshape(
num_tokens, -1, self.num_heads * self.head_size
)
num_tokens, -1, self.num_heads * self.head_size)

elif attn_metadata.decode_metadata:
# FA for decoding phase
@@ -539,15 +539,13 @@ def gen_input_mask(seq_len, sliding_window, len):
global SHARE_MASK_TRIL
if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
SHARE_MASK_TRIL = ~torch.tril(
torch.ones(seq_len, seq_len, dtype=bool, device="npu")
)
torch.ones(seq_len, seq_len, dtype=bool, device="npu"))

attention_mask = SHARE_MASK_TRIL
if sliding_window is not None:
attention_mask = ~attention_mask
attention_mask = torch.triu(
attention_mask, diagonal=1 - sliding_window
)
attention_mask = torch.triu(attention_mask,
diagonal=1 - sliding_window)
attention_mask = ~attention_mask

return attention_mask
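For reference, the mask construction in gen_input_mask above can be reproduced with plain PyTorch on CPU; this is an editorial illustration of the tril/triu pattern, not code from the commit. Here True marks positions excluded by the causal/sliding-window pattern.

    from typing import Optional

    import torch


    def causal_sliding_window_mask(seq_len: int,
                                   sliding_window: Optional[int]) -> torch.Tensor:
        # Causal part: mask (True) everything strictly above the diagonal.
        mask = ~torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        if sliding_window is not None:
            # Keep only the last `sliding_window` positions of each row visible.
            visible = ~mask
            visible = torch.triu(visible, diagonal=1 - sliding_window)
            mask = ~visible
        return mask


    print(causal_sliding_window_mask(5, sliding_window=3).int())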
2 changes: 2 additions & 0 deletions vllm/executor/multiproc_gpu_executor.py
@@ -14,8 +14,10 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON

if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager

from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, get_vllm_instance_id, make_async,
10 changes: 6 additions & 4 deletions vllm/executor/multiproc_npu_executor.py
@@ -1,11 +1,13 @@
import os
import torch, torch_npu # noqa

import torch # noqa
import torch_npu # noqa

from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.npu_executor import NPUExecutor
from vllm.logger import init_logger
from vllm.utils import update_environment_variables
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)

logger = init_logger(__name__)

@@ -21,7 +23,7 @@ def _check_executor_parameters(self):
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"ASCEND_RT_VISIBLE_DEVICES":
(",".join(map(str, range(world_size))))
(",".join(map(str, range(world_size))))
})

npu_device_count = torch.npu.device_count()
2 changes: 1 addition & 1 deletion vllm/executor/npu_executor.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Callable, Type, Tuple
from typing import Callable, List, Optional, Tuple, Type

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
7 changes: 2 additions & 5 deletions vllm/model_executor/layers/layernorm.py
@@ -96,17 +96,14 @@ def forward_npu(
import torch_npu

if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm(x,
residual,
self.weight,
self.variance_epsilon)
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
return x, residual

x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x


def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
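For reference, a plain-PyTorch sketch of the unfused computation that the torch_npu.npu_add_rms_norm / npu_rms_norm calls above presumably fuse (RMSNorm with an optional residual add); an editorial illustration, not code from the commit.

    from typing import Optional, Tuple, Union

    import torch


    def rms_norm_ref(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x = x + residual   # the add folded into npu_add_rms_norm
            residual = x       # updated residual is returned alongside
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        out = x * torch.rsqrt(variance + eps) * weight
        return out if residual is None else (out, residual)


    x = torch.randn(2, 8)
    w = torch.ones(8)
    print(rms_norm_ref(x, w, eps=1e-6).shape)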
3 changes: 2 additions & 1 deletion vllm/model_executor/models/commandr.py
@@ -50,10 +50,11 @@

from .interfaces import SupportsLoRA


current_backend = "inductor"
if current_platform.is_npu():
current_backend = "npu"


@torch.compile(backend=current_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
2 changes: 1 addition & 1 deletion vllm/model_executor/sampling_metadata.py
@@ -261,7 +261,7 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
padding_len: Optional[int] = 0
padding_len: int = 0
prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
if cache is not None else [])
sample_indices: List[int] = (sample_obj.sample_indices
4 changes: 2 additions & 2 deletions vllm/platforms/ascend.py
@@ -1,5 +1,5 @@
from typing import Tuple
import os
from typing import Tuple

import torch

@@ -36,7 +36,7 @@ def inference_mode(cls):
return torch.inference_mode()

@classmethod
def set_device(cls, device: torch.device) -> torch.device:
def set_device(cls, device: torch.device):
torch.npu.set_device(device)

@classmethod