
Commit 48e8e3f

zhouyuan authored and Robert Shaw committed
[CI/BUILD] enable intel queue for longer CPU tests (vllm-project#4113)
1 parent 1ebb772 commit 48e8e3f

11 files changed (+138 / -89 lines)

.buildkite/run-cpu-test.sh

Lines changed: 12 additions & 2 deletions
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
 trap remove_docker_container EXIT
 remove_docker_container
 
-# Run the image and launch offline inference
-docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
+# Run the image
+docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
+
+# offline inference
+docker exec cpu-test bash -c "python3 examples/offline_inference.py"
+
+# Run basic model test
+docker exec cpu-test bash -c "cd tests;
+  pip install pytest Pillow protobuf
+  bash ../.buildkite/download-images.sh
+  cd ../
+  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"

.buildkite/test-template.j2

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ steps:
 
   - label: "Intel Test"
     depends_on: ~
+    agents:
+      queue: intel
     command: bash .buildkite/run-cpu-test.sh
 
   {% for step in steps %}

Dockerfile.cpu

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
 
-FROM ubuntu:22.04
+FROM ubuntu:22.04 AS cpu-test-1
 
 RUN apt-get update -y \
     && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@@ -9,6 +9,8 @@ RUN apt-get update -y \
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja setuptools>=49.4.0 numpy
 
+FROM cpu-test-1 AS build
+
 COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
@@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
 
 WORKDIR /workspace/
 
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
 CMD ["/bin/bash"]

csrc/cpu/pos_encoding.cpp

Lines changed: 51 additions & 50 deletions
@@ -21,73 +21,74 @@ void rotary_embedding_impl(
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
 
   const int embed_dim = rot_dim / 2;
-  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
+  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
+  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
 
-#pragma omp parallel for
-  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    int64_t pos = positions[token_idx];
-    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
+                          scalar_t* qk) {
+    int j = 0;
+    for (; j < loop_upper; j += VEC_ELEM_NUM) {
+      const int rot_offset = j;
+      const int x_index = rot_offset;
+      const int y_index = embed_dim + rot_offset;
 
-    for (int i = 0; i < num_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head =
-          token_idx * query_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+      const int64_t out_x = token_head + x_index;
+      const int64_t out_y = token_head + y_index;
 
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
+      const scalar_vec_t cos(cache_ptr + x_index);
+      const scalar_vec_t sin(cache_ptr + y_index);
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+      const scalar_vec_t q_x(qk + out_x);
+      const scalar_vec_t q_y(qk + out_y);
 
-        const scalar_vec_t q_x(query + out_x);
-        const scalar_vec_t q_y(query + out_y);
+      vec_op::FP32Vec8 fp32_cos(cos);
+      vec_op::FP32Vec8 fp32_sin(sin);
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+      vec_op::FP32Vec8 fp32_q_x(q_x);
+      vec_op::FP32Vec8 fp32_q_y(q_y);
 
-        vec_op::FP32Vec8 fp32_q_x(q_x);
-        vec_op::FP32Vec8 fp32_q_y(q_y);
+      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+      scalar_vec_t(out1).save(qk + out_x);
 
-        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
-        scalar_vec_t(out1).save(query + out_x);
-
-        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
-        scalar_vec_t(out2).save(query + out_y);
-      }
+      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      scalar_vec_t(out2).save(qk + out_y);
     }
-
-    for (int i = 0; i < num_kv_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        const int x_index = j;
+        const int y_index = embed_dim + j;
 
         const int64_t out_x = token_head + x_index;
         const int64_t out_y = token_head + y_index;
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+        const float fp32_cos = cache_ptr[x_index];
+        const float fp32_sin = cache_ptr[y_index];
 
-        const scalar_vec_t k_x(key + out_x);
-        const scalar_vec_t k_y(key + out_y);
+        const float fp32_q_x = qk[out_x];
+        const float fp32_q_y = qk[out_y];
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      }
+    }
+  };
 
-        vec_op::FP32Vec8 fp32_k_x(k_x);
-        vec_op::FP32Vec8 fp32_k_y(k_y);
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    int64_t pos = positions[token_idx];
+    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
-        scalar_vec_t(out1).save(key + out_x);
-        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
-        scalar_vec_t(out2).save(key + out_y);
-      }
+    for (int i = 0; i < num_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head =
+          token_idx * query_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, query);
+    }
+
+    for (int i = 0; i < num_kv_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, key);
     }
   }
 }
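Note on the kernel change above: the old code required embed_dim to be a multiple of the vector width (TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0)). The rewrite folds the query and key loops into a single compute_loop lambda, runs the vectorized loop only up to loop_upper, and finishes any remainder with a scalar tail when embed_dim is not a multiple of VEC_ELEM_NUM; the rotation math itself is unchanged. A rough NumPy sketch of the per-head rotation this kernel applies (illustrative names only, not vLLM's Python API):

import numpy as np

def rotate_half_split_ref(head: np.ndarray, cos_sin_row: np.ndarray, rot_dim: int) -> np.ndarray:
    # head: one (head_size,) query or key vector; cos_sin_row: (rot_dim,) row of the cos/sin cache
    embed_dim = rot_dim // 2
    cos, sin = cos_sin_row[:embed_dim], cos_sin_row[embed_dim:rot_dim]
    x, y = head[:embed_dim], head[embed_dim:rot_dim]   # elements paired as (j, embed_dim + j)
    out = head.copy()
    out[:embed_dim] = x * cos - y * sin                # matches out1 in the kernel
    out[embed_dim:rot_dim] = y * cos + x * sin         # matches out2 in the kernel
    return out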

tests/conftest.py

Lines changed: 23 additions & 14 deletions
@@ -20,6 +20,7 @@
 from vllm.multimodal import MultiModalData
 from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 from vllm.sequence import SampleLogprobs
+from vllm.utils import is_cpu
 
 logger = init_logger(__name__)
 
@@ -60,7 +61,8 @@ def cleanup():
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
-    torch.cuda.empty_cache()
+    if not is_cpu():
+        torch.cuda.empty_cache()
 
 
 @pytest.fixture()
@@ -153,6 +155,12 @@ def example_long_prompts() -> List[str]:
 
 class HfRunner:
 
+    def wrap_device(self, input: any):
+        if not is_cpu():
+            return input.to("cuda")
+        else:
+            return input.to("cpu")
+
     def __init__(
         self,
         model_name: str,
@@ -167,17 +175,18 @@ def __init__(
         if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
-            self.model = SentenceTransformer(
-                model_name,
-                device="cpu",
-            ).to(dtype=torch_dtype).cuda()
+            self.model = self.wrap_device(
+                SentenceTransformer(
+                    model_name,
+                    device="cpu",
+                ).to(dtype=torch_dtype))
         else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-                token=access_token,
-            ).cuda()
+            self.model = self.wrap_device(
+                AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=True,
+                ))
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -218,7 +227,7 @@ def generate(
             inputs = self.processor(**processor_kwargs)
 
             output_ids = self.model.generate(
-                **inputs.to("cuda"),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
             )
@@ -275,7 +284,7 @@ def generate_greedy_logprobs(
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
@@ -310,7 +319,7 @@ def generate_greedy_logprobs_limit(
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
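The conftest.py change funnels every hard-coded .cuda()/.to("cuda") call in HfRunner through the single wrap_device helper, so the HuggingFace baseline can run on hosts without a GPU. A minimal standalone sketch of the same dispatch idea, keyed off torch.cuda.is_available() rather than vLLM's is_cpu() (the helper name here is hypothetical):

import torch

def to_best_device(obj):
    # Works for tensors, nn.Modules and HF BatchEncoding objects, all of which expose .to()
    return obj.to("cuda" if torch.cuda.is_available() else "cpu")

# usage, mirroring the diff:
#   output_ids = model.generate(**to_best_device(inputs), use_cache=True)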

tests/models/test_aqlm.py

Lines changed: 7 additions & 4 deletions
@@ -8,10 +8,13 @@
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-aqlm_not_supported = (capability <
-                      QUANTIZATION_METHODS["aqlm"].get_min_capability())
+aqlm_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    aqlm_not_supported = (capability <
+                          QUANTIZATION_METHODS["aqlm"].get_min_capability())
 
 # In this test we hardcode prompts and generations for the model so we don't
 # need to require the AQLM package as a dependency
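The same guard is repeated below in test_fp8.py, test_gptq_marlin.py, test_gptq_marlin_24.py and test_marlin.py: the "not supported" flag defaults to True so @pytest.mark.skipif skips the test on CPU-only hosts, and the compute-capability query only runs when a CUDA device is actually present. A hedged sketch of the pattern as one reusable helper (the helper is illustrative, not part of vLLM):

import torch

def meets_min_capability(min_capability: int) -> bool:
    # False on CPU-only hosts; otherwise compare (major * 10 + minor) against the threshold
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= min_capability

# e.g. aqlm_not_supported = not meets_min_capability(
#          QUANTIZATION_METHODS["aqlm"].get_min_capability())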

tests/models/test_big_models.py

Lines changed: 8 additions & 2 deletions
@@ -8,6 +8,7 @@
 import sys
 
 import pytest
+import torch
 
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
@@ -36,9 +37,14 @@
     "mosaicml/mpt-7b",
 ]
 
+#TODO: remove this after CPU float16 support ready
+target_dtype = "float"
+if torch.cuda.is_available():
+    target_dtype = "half"
+
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [32])
 def test_models(
     hf_runner,
@@ -78,7 +84,7 @@ def test_models(
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 def test_model_print(
     vllm_runner,
     model: str,

tests/models/test_fp8.py

Lines changed: 7 additions & 4 deletions
@@ -67,10 +67,13 @@
     },
 }
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-fp8_not_supported = (capability <
-                     QUANTIZATION_METHODS["fp8"].get_min_capability())
+fp8_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    fp8_not_supported = (capability <
+                         QUANTIZATION_METHODS["fp8"].get_min_capability())
 
 
 @pytest.mark.skipif(fp8_not_supported,

tests/models/test_gptq_marlin.py

Lines changed: 7 additions & 4 deletions
@@ -21,10 +21,13 @@
 
 MAX_MODEL_LEN = 1024
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-gptq_marlin_not_supported = (
-    capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
+gptq_marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    gptq_marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
 
 MODELS = [
     # act_order==False, group_size=channelwise

tests/models/test_gptq_marlin_24.py

Lines changed: 7 additions & 4 deletions
@@ -14,10 +14,13 @@
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass

tests/models/test_marlin.py

Lines changed: 9 additions & 4 deletions
@@ -23,10 +23,15 @@
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+from .utils import check_logprobs_close
+
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass
