Commit a039da9

Merge remote-tracking branch 'AzureGIT/main' into llava_devel
2 parents: a8b0dbc + 358c328

32 files changed: +597 -333 lines

README.md
Lines changed: 2 additions & 1 deletion

@@ -27,7 +27,7 @@ Easy, fast, and cheap LLM serving for everyone
 - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

 ---
-
+## About
 vLLM is a fast and easy-to-use library for LLM inference and serving.

 vLLM is fast with:
@@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
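
The README hunk above (and the supported_models table later in this commit) adds DeciLM. As an illustration only, not part of the diff, a minimal offline-inference sketch with vLLM's `LLM` API; the checkpoint name comes from the diff, while `trust_remote_code=True` is an assumption for loading the Deci config:

    # Minimal sketch, not from the commit: run one of the newly listed DeciLM checkpoints.
    from vllm import LLM, SamplingParams

    # Assumption: the Deci config may require trust_remote_code=True to load.
    llm = LLM(model="Deci/DeciLM-7B-instruct", trust_remote_code=True)
    params = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["Summarize what paged attention does."], params)
    print(outputs[0].outputs[0].text)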

csrc/pos_encoding_kernels.cu
Lines changed: 6 additions & 6 deletions

@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
   scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int query_stride,
-  const int key_stride,
+  const int64_t query_stride,
+  const int64_t key_stride,
   const int num_heads,
   const int num_kv_heads,
   const int head_size) {
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * query_stride + head_idx * head_size;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
@@ -89,8 +89,8 @@ void rotary_embedding(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
-  int query_stride = query.stride(-2);
-  int key_stride = key.stride(-2);
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
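
The int -> int64_t switch above matters because the element offset `token_idx * query_stride` is computed per token and can exceed the signed 32-bit range on large batches. A rough back-of-the-envelope check, plain Python with hypothetical shapes (not code from the kernel):

    # Back-of-the-envelope check for the offset arithmetic in the kernel above.
    INT32_MAX = 2**31 - 1

    def max_query_offset(num_tokens: int, num_heads: int, head_size: int) -> int:
        # Contiguous [num_tokens, num_heads * head_size] layout: stride(-2) in elements.
        query_stride = num_heads * head_size
        return (num_tokens - 1) * query_stride + (num_heads - 1) * head_size

    # 64 heads x 128 head_size = 8192 elements per token (hypothetical config).
    print(max_query_offset(100_000, 64, 128) <= INT32_MAX)  # True  -> a 32-bit int still fits
    print(max_query_offset(300_000, 64, 128) <= INT32_MAX)  # False -> needs int64_t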

csrc/quantization/gptq/q_gemm.cu
Lines changed: 10 additions & 0 deletions

@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))

 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
       zeros_tmp[tmp_k] = zero;
     }
     for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
       res2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+#endif
       res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
       res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+      res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
     }
     i += width;
     k += 4;

docs/source/getting_started/amd-installation.rst
Lines changed: 1 addition & 0 deletions

@@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from

 - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
 - `Pytorch <https://pytorch.org/>`_
+- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_

 1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

docs/source/getting_started/installation.rst
Lines changed: 4 additions & 0 deletions

@@ -42,6 +42,10 @@ You can install vLLM using pip:
    $ pip uninstall torch -y
    $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118

+   $ # Re-install xFormers with CUDA 11.8.
+   $ pip uninstall xformers -y
+   $ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
+

 .. _build_from_source:
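
Not part of the docs change, but a quick sanity check after running the commands above that both reinstalled wheels import and report a CUDA build (a sketch; assumes both packages installed cleanly):

    # Sanity check: confirm the reinstalled torch and xFormers wheels line up.
    import torch
    import xformers

    print("torch", torch.__version__, "built against CUDA", torch.version.cuda)
    print("xformers", xformers.__version__)
    assert torch.cuda.is_available(), "CUDA runtime is not visible to PyTorch"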

docs/source/models/engine_args.rst
Lines changed: 4 additions & 2 deletions

@@ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM:

     CPU swap space size (GiB) per GPU.

-.. option:: --gpu-memory-utilization <percentage>
+.. option:: --gpu-memory-utilization <fraction>

-    The percentage of GPU memory to be used for the model executor.
+    The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
+    For example, a value of 0.5 would imply 50% GPU memory utilization.
+    If unspecified, will use the default value of 0.9.

 .. option:: --max-num-batched-tokens <tokens>
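
To make the re-worded flag concrete, the same fraction is exposed as a keyword argument on the offline `LLM` API; a minimal sketch (the model choice is arbitrary and not from the diff):

    # Sketch: cap the engine at roughly half of each GPU's memory for weights + KV cache.
    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)
    print(llm.generate(["Hello, my name is"])[0].outputs[0].text)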

docs/source/models/supported_models.rst
Lines changed: 4 additions & 1 deletion

@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`ChatGLMModel`
     - ChatGLM
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
+  * - :code:`DeciLMForCausalLM`
+    - DeciLM
+    - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
@@ -90,7 +93,7 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
 If vLLM successfully generates text, it indicates that your model is supported.

 .. tip::
-    To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
+    To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:

     .. code-block:: shell
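
The shell block the tip introduces sits outside the diff context. Purely as an illustration, a Python sketch that sets the switch before importing the engine; treat both the variable name (``VLLM_USE_MODELSCOPE``) and the model id as assumptions, since neither appears in the diff:

    import os

    # Assumption: vLLM reads VLLM_USE_MODELSCOPE at import time, so set it first.
    os.environ["VLLM_USE_MODELSCOPE"] = "True"

    from vllm import LLM

    # Hypothetical ModelScope model id, for illustration only.
    llm = LLM(model="qwen/Qwen-7B-Chat", trust_remote_code=True)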

docs/source/serving/serving_with_langchain.rst
Lines changed: 1 addition & 1 deletion

@@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha

     print(llm("What is the capital of France ?"))

-Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/extras/integrations/llms/vllm.ipynb>`_ for more details.
+Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/llms/vllm.ipynb>`_ for more details.
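
For context, a hedged sketch of what the linked tutorial covers: the ``VLLM`` wrapper from ``langchain.llms``. The class path and parameters reflect LangChain releases around this commit and may differ later; the model id is arbitrary:

    # Sketch of the LangChain integration the tutorial walks through.
    from langchain.llms import VLLM

    llm = VLLM(
        model="facebook/opt-125m",   # arbitrary small model for illustration
        max_new_tokens=64,
        temperature=0.8,
        tensor_parallel_size=1,      # raise this to shard across multiple GPUs
    )
    print(llm("What is the capital of France ?"))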

setup.py
Lines changed: 1 addition & 1 deletion

@@ -219,13 +219,13 @@ def get_torch_arch_list() -> Set[str]:
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]

 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
-    vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu")

 vllm_extension = CUDAExtension(
     name="vllm._C",

tests/async_engine/test_api_server.py
Lines changed: 8 additions & 6 deletions

@@ -44,13 +44,14 @@ def test_api_server(api_server):
     """
     with Pool(32) as pool:
         # Wait until the server is ready
-        prompts = ["Hello world"] * 1
+        prompts = ["warm up"] * 1
         result = None
         while not result:
             try:
-                for _ in pool.map(_query_server, prompts):
+                for r in pool.map(_query_server, prompts):
+                    result = r
                     break
-            except Exception:
+            except requests.exceptions.ConnectionError:
                 time.sleep(1)

         # Actual tests start here
@@ -63,13 +64,14 @@ def test_api_server(api_server):
         assert num_aborted_requests == 0

         # Try with 100 prompts
-        prompts = ["Hello world"] * 100
+        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
             assert result

         # Cancel requests
+        prompts = ["canceled requests"] * 100
         pool.map_async(_query_server, prompts)
-        time.sleep(0.01)
+        time.sleep(0.001)
         pool.terminate()
         pool.join()

@@ -81,6 +83,6 @@ def test_api_server(api_server):
     # check that server still runs after cancellations
     with Pool(32) as pool:
         # Try with 100 prompts
-        prompts = ["Hello world"] * 100
+        prompts = ["test prompt after canceled"] * 100
         for result in pool.map(_query_server, prompts):
             assert result
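
The test maps ``_query_server`` over a 32-process pool; the helper itself is outside this diff, so the following is only a plausible sketch of its shape (URL, port, and payload fields are assumptions, not taken from the repository):

    # Plausible sketch of the helper the test maps across the process pool.
    import requests

    def _query_server(prompt: str) -> dict:
        response = requests.post("http://localhost:8000/generate",
                                 json={"prompt": prompt, "max_tokens": 100})
        response.raise_for_status()
        return response.json()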
