Commit a039da9

Merge remote-tracking branch 'AzureGIT/main' into llava_devel
AzureSilent committed Dec 29, 2023
2 parents a8b0dbc + 358c328 commit a039da9
Showing 32 changed files with 597 additions and 333 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -27,7 +27,7 @@ Easy, fast, and cheap LLM serving for everyone
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

---

## About
vLLM is a fast and easy-to-use library for LLM inference and serving.

vLLM is fast with:
@@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
12 changes: 6 additions & 6 deletions csrc/pos_encoding_kernels.cu
@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int query_stride,
const int key_stride,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
@@ -89,8 +89,8 @@ void rotary_embedding(
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int query_stride = query.stride(-2);
int key_stride = key.stride(-2);
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
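Note: widening `query_stride` / `key_stride` (and the derived `token_head` offset) to `int64_t` guards against 32-bit overflow when indexing the flattened query/key tensors for very large batches. A rough back-of-the-envelope check in Python, using hypothetical shapes that are not taken from this diff:

    num_tokens   = 200_000                 # batch_size * seq_len, flattened
    num_heads    = 64
    head_size    = 256
    query_stride = num_heads * head_size   # 16_384 elements per token

    # Same indexing formula as the kernel: token_idx * stride + head_idx * head_size
    max_offset = (num_tokens - 1) * query_stride + (num_heads - 1) * head_size
    print(max_offset)                      # 3276799744
    print(max_offset > 2**31 - 1)          # True: a signed 32-bit index would wrap
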
10 changes: 10 additions & 0 deletions csrc/quantization/gptq/q_gemm.cu
@@ -28,6 +28,7 @@ namespace gptq {
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
zeros_tmp[tmp_k] = zero;
}
for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
res2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
#endif
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
#ifndef USE_ROCM
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
}
i += width;
k += 4;
1 change: 1 addition & 0 deletions docs/source/getting_started/amd-installation.rst
@@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from

- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
- `Pytorch <https://pytorch.org/>`_
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_

1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

4 changes: 4 additions & 0 deletions docs/source/getting_started/installation.rst
@@ -42,6 +42,10 @@ You can install vLLM using pip:
$ pip uninstall torch -y
$ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
$ # Re-install xFormers with CUDA 11.8.
$ pip uninstall xformers -y
$ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
.. _build_from_source:

6 changes: 4 additions & 2 deletions docs/source/models/engine_args.rst
@@ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM:

CPU swap space size (GiB) per GPU.

.. option:: --gpu-memory-utilization <percentage>
.. option:: --gpu-memory-utilization <fraction>

The percentage of GPU memory to be used for the model executor.
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
For example, a value of 0.5 would imply 50% GPU memory utilization.
If unspecified, will use the default value of 0.9.

.. option:: --max-num-batched-tokens <tokens>

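Note: the reworded ``--gpu-memory-utilization`` entry above now spells out the valid range (0 to 1) and the 0.9 default. A minimal sketch of the equivalent Python usage, assuming the ``LLM`` constructor exposes the same knob as the engine argument:

    from vllm import LLM

    # Cap the model executor at 50% of each GPU's memory (the default is 0.9).
    llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)
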
5 changes: 4 additions & 1 deletion docs/source/models/supported_models.rst
@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`ChatGLMModel`
- ChatGLM
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
* - :code:`DeciLMForCausalLM`
- DeciLM
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
@@ -90,7 +93,7 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
If vLLM successfully generates text, it indicates that your model is supported.

.. tip::
To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:

.. code-block:: shell
2 changes: 1 addition & 1 deletion docs/source/serving/serving_with_langchain.rst
@@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha
print(llm("What is the capital of France ?"))
Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/extras/integrations/llms/vllm.ipynb>`_ for more details.
Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/llms/vllm.ipynb>`_ for more details.
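Note: only the tutorial link changes here; the integration itself is untouched. For quick reference, a minimal LangChain usage sketch (the keyword arguments are assumptions about the integration's API, not taken from this diff):

    from langchain.llms import VLLM

    llm = VLLM(model="facebook/opt-125m", max_new_tokens=64)
    print(llm("What is the capital of France ?"))
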
2 changes: 1 addition & 1 deletion setup.py
@@ -219,13 +219,13 @@ def get_torch_arch_list() -> Set[str]:
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
]

if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
14 changes: 8 additions & 6 deletions tests/async_engine/test_api_server.py
@@ -44,13 +44,14 @@ def test_api_server(api_server):
"""
with Pool(32) as pool:
# Wait until the server is ready
prompts = ["Hello world"] * 1
prompts = ["warm up"] * 1
result = None
while not result:
try:
for _ in pool.map(_query_server, prompts):
for r in pool.map(_query_server, prompts):
result = r
break
except Exception:
except requests.exceptions.ConnectionError:
time.sleep(1)

# Actual tests start here
@@ -63,13 +64,14 @@ def test_api_server(api_server):
assert num_aborted_requests == 0

# Try with 100 prompts
prompts = ["Hello world"] * 100
prompts = ["test prompt"] * 100
for result in pool.map(_query_server, prompts):
assert result

# Cancel requests
prompts = ["canceled requests"] * 100
pool.map_async(_query_server, prompts)
time.sleep(0.01)
time.sleep(0.001)
pool.terminate()
pool.join()

@@ -81,6 +83,6 @@ def test_api_server(api_server):
# check that server still runs after cancellations
with Pool(32) as pool:
# Try with 100 prompts
prompts = ["Hello world"] * 100
prompts = ["test prompt after canceled"] * 100
for result in pool.map(_query_server, prompts):
assert result
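Note: the readiness loop now retries only on ``requests.exceptions.ConnectionError`` instead of a bare ``Exception``, so genuine failures surface instead of being swallowed. A standalone sketch of the same wait-for-ready pattern (the function name and timeout handling are illustrative, not part of the test):

    import time
    import requests

    def wait_until_ready(url: str, timeout: float = 60.0) -> None:
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                requests.get(url, timeout=1)
                return          # server accepted the connection
            except requests.exceptions.ConnectionError:
                time.sleep(1)   # not listening yet; only connection errors retry
        raise TimeoutError(f"server at {url} did not become ready")
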
9 changes: 5 additions & 4 deletions tests/conftest.py
@@ -8,8 +8,9 @@
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = ["prompts/example.txt"]
_LONG_PROMPTS = ["prompts/summary.txt"]
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


def _read_prompts(filename: str) -> str:
@@ -24,15 +25,15 @@ def _read_prompts(filename: str) -> str:
def example_prompts() -> List[str]:
prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
prompts += _read_prompts(filename)
return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
for filename in _LONG_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
prompts += _read_prompts(filename)
return prompts


2 changes: 1 addition & 1 deletion tests/distributed/test_comm_ops.py
@@ -8,7 +8,7 @@
import torch

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather,
1 change: 1 addition & 0 deletions tests/models/test_models.py
@@ -8,6 +8,7 @@
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"Deci/DeciLM-7b",
"tiiuae/falcon-7b",
"gpt2",
"bigcode/tiny_starcoder_py",
2 changes: 1 addition & 1 deletion vllm/__init__.py
@@ -9,7 +9,7 @@
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.2.5"
__version__ = "0.2.6"

__all__ = [
"LLM",
33 changes: 15 additions & 18 deletions vllm/config.py
@@ -112,24 +112,20 @@ def _verify_load_format(self) -> None:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
rocm_not_supported_load_format = ["safetensors"]
rocm_not_supported_load_format = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
if is_hip():
if load_format in ["safetensors"]:
rocm_supported_load_format = [
f for f in supported_load_format
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")
# Force ROCm to load from pt weights if nothing specific is set
if load_format == "auto":
load_format = "pt"
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in supported_load_format
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")

# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", [])
@@ -149,7 +145,7 @@ def _verify_tokenizer_mode(self) -> None:

def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq", "gptq"]
rocm_not_supported_quantization = ["awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

@@ -185,10 +181,11 @@ def _verify_cuda_graph(self) -> None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
self.max_model_len)
if self.quantization == "gptq" and not self.enforce_eager:
if (self.quantization in ["gptq", "squeezellm"]
and not self.enforce_eager):
# Related issue: https://github.com/vllm-project/vllm/issues/2147
logger.warning("GPTQ does not support CUDA graph yet. Disabling "
"CUDA graph.")
logger.warning(f"{self.quantization} does not support CUDA graph "
"yet. Disabling CUDA graph.")
self.enforce_eager = True

def verify_with_parallel_config(
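Note: the ``_verify_cuda_graph`` change above extends the CUDA-graph fallback from GPTQ to SqueezeLLM: either quantization method now forces eager mode and logs a warning. A minimal sketch of what that means for callers (the checkpoint name is purely hypothetical):

    from vllm import LLM

    # enforce_eager=True is what vLLM falls back to anyway for GPTQ/SqueezeLLM;
    # passing it explicitly just avoids the warning.
    llm = LLM(model="some-org/some-model-gptq",   # hypothetical quantized checkpoint
              quantization="gptq",
              enforce_eager=True)
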
6 changes: 3 additions & 3 deletions vllm/core/block_manager.py
@@ -103,7 +103,7 @@ def __init__(
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
@@ -122,7 +122,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs()[0]
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
@@ -137,7 +137,7 @@ def allocate(self, seq_group: SequenceGroup) -> None:
block_table.append(block)

# Assign the block table for each sequence.
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()

def can_append_slot(self, seq_group: SequenceGroup) -> bool:
12 changes: 7 additions & 5 deletions vllm/core/scheduler.py
@@ -139,15 +139,17 @@ def _schedule(self) -> SchedulerOutputs:
while self.waiting:
seq_group = self.waiting[0]

assert seq_group.num_seqs() == 1, (
waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
@@ -161,7 +163,7 @@ def _schedule(self) -> SchedulerOutputs:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager")
for seq in seq_group.get_seqs():
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
@@ -317,7 +319,7 @@ def free_finished_seq_groups(self) -> None:

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs():
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING

def _append_slot(