[CI/BUILD] enable intel queue for longer CPU tests #4113

Merged · 25 commits · Jun 3, 2024
14 changes: 12 additions & 2 deletions .buildkite/run-cpu-test.sh
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
# Run the image
docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test

# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
bash ../.buildkite/download-images.sh
cd ../
pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
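For reference, the offline inference step the container executes is essentially vLLM's quick-start flow. The sketch below is illustrative rather than a copy of examples/offline_inference.py; the model name and sampling values are assumptions.

```python
from vllm import LLM, SamplingParams

# Prompts and sampling values are illustrative, not copied from the script.
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Inside the CPU image, VLLM_CPU_KVCACHE_SPACE=4 reserves 4 GB for the KV cache.
llm = LLM(model="facebook/opt-125m")

for output in llm.generate(prompts, sampling_params):
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
```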
2 changes: 2 additions & 0 deletions .buildkite/test-template.j2
@@ -40,6 +40,8 @@ steps:

- label: "Intel Test"
depends_on: ~
agents:
queue: intel
command: bash .buildkite/run-cpu-test.sh

{% for step in steps %}
6 changes: 5 additions & 1 deletion Dockerfile.cpu
@@ -1,6 +1,6 @@
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.

FROM ubuntu:22.04
FROM ubuntu:22.04 AS cpu-test-1

RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@@ -9,6 +9,8 @@ RUN apt-get update -y \
RUN pip install --upgrade pip \
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy

FROM cpu-test-1 AS build

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm
@@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

WORKDIR /workspace/

RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

CMD ["/bin/bash"]
101 changes: 51 additions & 50 deletions csrc/cpu/pos_encoding.cpp
@@ -21,73 +21,74 @@ void rotary_embedding_impl(
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
bool flag = (embed_dim % VEC_ELEM_NUM == 0);
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;

#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
scalar_t* qk) {
int j = 0;
for (; j < loop_upper; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;

for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;

const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);

const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(qk + out_x);
const scalar_vec_t q_y(qk + out_y);

const scalar_vec_t q_x(query + out_x);
const scalar_vec_t q_y(query + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);

vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);

vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(qk + out_x);

auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(query + out_x);

auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(query + out_y);
}
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(qk + out_y);
}

for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
if (!flag) {
for (; j < embed_dim; ++j) {
const int x_index = j;
const int y_index = embed_dim + j;

const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;

const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const float fp32_cos = cache_ptr[x_index];
const float fp32_sin = cache_ptr[y_index];

const scalar_vec_t k_x(key + out_x);
const scalar_vec_t k_y(key + out_y);
const float fp32_q_x = qk[out_x];
const float fp32_q_y = qk[out_y];

vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
}
}
};

vec_op::FP32Vec8 fp32_k_x(k_x);
vec_op::FP32Vec8 fp32_k_y(k_y);
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
scalar_vec_t(out1).save(key + out_x);
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
scalar_vec_t(out2).save(key + out_y);
}
for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, query);
}

for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
}
}
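The refactor above folds the per-head rotation into a reusable compute_loop lambda with a vectorized main loop and a scalar tail for the case where embed_dim is not a multiple of VEC_ELEM_NUM. The rotation itself is unchanged; a minimal NumPy sketch of the same neox-style math for one head, assuming the cache row is laid out as [cos | sin] per position as in the kernel:

```python
import numpy as np

def rotate_neox(qk: np.ndarray, cache_row: np.ndarray) -> np.ndarray:
    """Apply the rotation from pos_encoding.cpp to one token/head slice.

    qk:        [head_size] slice of query or key for one token and head.
    cache_row: [rot_dim] row of cos_sin_cache for this position,
               laid out as [cos_0..cos_{d-1} | sin_0..sin_{d-1}].
    """
    embed_dim = cache_row.shape[0] // 2
    cos, sin = cache_row[:embed_dim], cache_row[embed_dim:]
    x = qk[:embed_dim].copy()
    y = qk[embed_dim:2 * embed_dim].copy()
    out = qk.copy()
    out[:embed_dim] = x * cos - y * sin                # out_x = q_x*cos - q_y*sin
    out[embed_dim:2 * embed_dim] = y * cos + x * sin   # out_y = q_y*cos + q_x*sin
    return out

# Tiny usage example with random data (head_size == rot_dim here).
head_size, rot_dim = 8, 8
qk = np.random.randn(head_size).astype(np.float32)
theta = np.random.rand(rot_dim // 2).astype(np.float32)
cache_row = np.concatenate([np.cos(theta), np.sin(theta)])
print(rotate_neox(qk, cache_row))
```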
36 changes: 23 additions & 13 deletions tests/conftest.py
@@ -18,6 +18,7 @@
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu

logger = init_logger(__name__)

@@ -58,7 +59,8 @@ def cleanup():
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
if not is_cpu():
torch.cuda.empty_cache()


@pytest.fixture()
@@ -151,6 +153,12 @@ def example_long_prompts() -> List[str]:

class HfRunner:

def wrap_device(self, input: any):
if not is_cpu():
return input.to("cuda")
else:
return input.to("cpu")

def __init__(
self,
model_name: str,
@@ -164,16 +172,18 @@ def __init__(
if model_name in _EMBEDDING_MODELS:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype).cuda()
self.model = self.wrap_device(
SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype))
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.model = self.wrap_device(
AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
))

self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
@@ -214,7 +224,7 @@ def generate(
inputs = self.processor(**processor_kwargs)

output_ids = self.model.generate(
**inputs.to("cuda"),
**self.wrap_device(inputs),
use_cache=True,
**kwargs,
)
@@ -271,7 +281,7 @@ def generate_greedy_logprobs(
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output = self.model.generate(
input_ids.cuda(),
self.wrap_device(input_ids),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
@@ -306,7 +316,7 @@ def generate_greedy_logprobs_limit(
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output = self.model.generate(
input_ids.cuda(),
self.wrap_device(input_ids),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
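The conftest change routes the fixtures' explicit .cuda() and .to("cuda") calls through wrap_device so the same HfRunner works on the CPU queue. In isolation the dispatch looks like this; a minimal sketch where is_cpu() is a stand-in keyed off torch.cuda.is_available() rather than the actual vllm.utils helper:

```python
import torch

def is_cpu() -> bool:
    # Stand-in for vllm.utils.is_cpu(); here it simply checks for CUDA.
    return not torch.cuda.is_available()

def wrap_device(tensors):
    """Move model inputs (or a module) to CUDA when available, else keep them on CPU."""
    return tensors.to("cpu" if is_cpu() else "cuda")

# Usage: on GPU agents this returns a CUDA tensor, on the intel queue a CPU
# tensor, with no changes needed at the call sites.
print(wrap_device(torch.randn(2, 3)).device)
```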
11 changes: 7 additions & 4 deletions tests/models/test_aqlm.py
@@ -8,10 +8,13 @@

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
aqlm_not_supported = (capability <
QUANTIZATION_METHODS["aqlm"].get_min_capability())
aqlm_not_supported = True

if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
aqlm_not_supported = (capability <
QUANTIZATION_METHODS["aqlm"].get_min_capability())

# In this test we hardcode prompts and generations for the model so we don't
# need to require the AQLM package as a dependency
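The same guard, default to "not supported" and probe the compute capability only when CUDA is present, is repeated in the fp8, gptq_marlin, gptq_marlin_24, and marlin tests below. A possible shared helper (hypothetical, not part of this PR) could factor it out:

```python
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

def quant_method_not_supported(method: str) -> bool:
    """True when there is no CUDA device, or its compute capability is below
    the minimum required by the given quantization method."""
    if not torch.cuda.is_available():
        return True
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor < QUANTIZATION_METHODS[method].get_min_capability()

# Usage in a test module:
# @pytest.mark.skipif(quant_method_not_supported("aqlm"),
#                     reason="aqlm is not supported on this GPU type.")
```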
10 changes: 8 additions & 2 deletions tests/models/test_big_models.py
@@ -5,6 +5,7 @@
Run `pytest tests/models/test_big_models.py`.
"""
import pytest
import torch

MODELS = [
"meta-llama/Llama-2-7b-hf",
@@ -16,9 +17,14 @@
# "Qwen/Qwen1.5-0.5B" # Broken,
]

#TODO: remove this after CPU float16 support ready
target_dtype = "float"
if torch.cuda.is_available():
target_dtype = "half"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
hf_runner,
@@ -46,7 +52,7 @@ def test_models(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", [target_dtype])
def test_model_print(
vllm_runner,
model: str,
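Because the parametrize decorator is evaluated at collection time, the CPU queue collects these tests with dtype="float" while GPU agents keep "half". A minimal self-contained sketch of the same pattern:

```python
import pytest
import torch

# Evaluated once at import/collection time: CPU-only agents fall back to
# float32 because CPU float16 support is not ready yet.
target_dtype = "half" if torch.cuda.is_available() else "float"

@pytest.mark.parametrize("dtype", [target_dtype])
def test_dtype_selection(dtype: str) -> None:
    assert dtype in ("half", "float")
```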
11 changes: 7 additions & 4 deletions tests/models/test_fp8.py
@@ -67,10 +67,13 @@
},
}

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
fp8_not_supported = (capability <
QUANTIZATION_METHODS["fp8"].get_min_capability())
fp8_not_supported = True

if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
fp8_not_supported = (capability <
QUANTIZATION_METHODS["fp8"].get_min_capability())


@pytest.mark.skipif(fp8_not_supported,
11 changes: 7 additions & 4 deletions tests/models/test_gptq_marlin.py
@@ -22,10 +22,13 @@

MAX_MODEL_LEN = 1024

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
gptq_marlin_not_supported = (
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
gptq_marlin_not_supported = True

if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
gptq_marlin_not_supported = (
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())

MODELS = [
# act_order==False, group_size=channelwise
11 changes: 7 additions & 4 deletions tests/models/test_gptq_marlin_24.py
@@ -14,10 +14,13 @@
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (capability <
QUANTIZATION_METHODS["marlin"].get_min_capability())
marlin_not_supported = True

if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
capability < QUANTIZATION_METHODS["marlin"].get_min_capability())


@dataclass
11 changes: 7 additions & 4 deletions tests/models/test_marlin.py
@@ -19,10 +19,13 @@

from .utils import check_logprobs_close

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (capability <
QUANTIZATION_METHODS["marlin"].get_min_capability())
marlin_not_supported = True

if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
capability < QUANTIZATION_METHODS["marlin"].get_min_capability())


@dataclass