
Commit f44b619

fix tests

Signed-off-by: jiang.li <jiang1.li@intel.com>
1 parent 413ef08

File tree

5 files changed: +35, -13 lines changed

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 5 additions & 0 deletions

@@ -35,11 +35,13 @@ function cpu_tests() {
   # offline inference
   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
     set -e
+    export VLLM_USE_V1=1
     python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"

   # Run basic model test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
+    export VLLM_USE_V1=1
     pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
     pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
     pytest -v -s tests/models/language/generation -m cpu_model
@@ -49,6 +51,7 @@ function cpu_tests() {
   # Run compressed-tensor test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
+    export VLLM_USE_V1=1
     pytest -s -v \
       tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
       tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
@@ -62,6 +65,7 @@ function cpu_tests() {
   # Run chunked-prefill and prefix-cache test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
+    export VLLM_USE_V1=1
     pytest -s -v -k cpu_model \
       tests/basic_correctness/test_chunked_prefill.py"

@@ -70,6 +74,7 @@ function cpu_tests() {
     set -e
     export VLLM_CPU_KVCACHE_SPACE=10
     export VLLM_CPU_OMP_THREADS_BIND=$1
+    export VLLM_USE_V1=1
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
     python3 benchmarks/benchmark_serving.py \
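
The only change in this script is exporting VLLM_USE_V1=1 before each test group so the whole CPU suite runs on the V1 engine. As a hedged illustration (not part of the commit; the prompt and token budget are arbitrary), the same switch can be exercised in a small offline run:

    import os

    # VLLM_USE_V1 must be set before vllm is imported for engine selection to see it.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    # Same small model the CI script exercises.
    llm = LLM(model="facebook/opt-125m", dtype="half")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)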

tests/kernels/attention/test_attention_selector.py

Lines changed: 4 additions & 1 deletion

@@ -84,7 +84,10 @@ def test_env(
                        CpuPlatform()):
             backend = get_attn_backend(16, torch.float16, torch.float16,
                                        block_size, False)
-            assert backend.get_name() == "TORCH_SDPA"
+            if use_v1:
+                assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+            else:
+                assert backend.get_name() == "TORCH_SDPA"

     elif device == "hip":
         with patch("vllm.attention.selector.current_platform",

vllm/compilation/wrapper.py

Lines changed: 6 additions & 1 deletion

@@ -40,11 +40,16 @@ def __init__(self,
         # compiling the forward method

         backend = vllm_config.compilation_config.init_backend(vllm_config)
+        options = None
+        if isinstance(backend, str) and backend == "inductor":
+            options = get_current_vllm_config(
+            ).compilation_config.inductor_compile_config

         compiled_callable = torch.compile(
             self.forward,
             fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-            backend=backend)
+            backend=backend,
+            options=options)

         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
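
torch.compile only honors an options dict for the inductor backend, which is why the wrapper guards on backend == "inductor" before forwarding vLLM's inductor_compile_config. A standalone sketch of the same pattern, using the option keys that vllm/platforms/cpu.py writes in this commit (their exact effect is an inductor implementation detail):

    import torch

    def toy_forward(x):
        return torch.relu(x) * 2 + 1

    # Keys mirror the inductor_compile_config set for CPU in this commit.
    inductor_options = {
        "dce": True,
        "size_asserts": False,
        "nan_asserts": False,
        "memory_planning": True,
        "epilogue_fusion": True,
    }

    compiled = torch.compile(toy_forward, backend="inductor", options=inductor_options)
    print(compiled(torch.randn(4)))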

vllm/platforms/cpu.py

Lines changed: 18 additions & 9 deletions

@@ -155,8 +155,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # Note: workaround for v1 gpu_model_runner
             from vllm.config import CompilationLevel
             vllm_config.compilation_config.cudagraph_capture_sizes = []
-            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION
-            vllm_config.compilation_config.custom_ops = []
+
+            compilation_config = vllm_config.compilation_config
+            if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
+                compilation_config.level = CompilationLevel.DYNAMO_ONCE
+                compilation_config.backend = "inductor"
+                compilation_config.custom_ops += ["none"]
+                compilation_config.inductor_compile_config.update({
+                    "dce": True,
+                    "size_asserts": False,
+                    "nan_asserts": False,
+                    "memory_planning": True,
+                    "epilogue_fusion": True,
+                })

         assert vllm_config.device_config.device_type == "cpu"

@@ -192,13 +208,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # To hint IPEX uses shared memory based AllReduce
             os.environ["LOCAL_WORLD_SIZE"] = str(
                 vllm_config.parallel_config.tensor_parallel_size)
-        if sys.platform == "darwin" and \
-                envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
-            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
-                logger.warning(
-                    "Default to spawn method on MacOS. If this is not desired,"
-                    " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
-            os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

         if vllm_config.model_config and vllm_config.model_config.use_mla:
             logger.info(
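
The rewritten branch no longer forces NO_COMPILATION; a PIECEWISE request is instead downgraded to DYNAMO_ONCE with the inductor backend, custom ops disabled via "none", and a CPU-friendly set of inductor options. A hedged sketch of that rewrite with a stand-in config class (not vLLM's CompilationConfig; enum values are illustrative):

    from dataclasses import dataclass, field
    from enum import IntEnum

    class CompilationLevel(IntEnum):  # names from the diff; values illustrative
        NO_COMPILATION = 0
        DYNAMO_ONCE = 2
        PIECEWISE = 3

    @dataclass
    class FakeCompilationConfig:  # stand-in for illustration only
        level: CompilationLevel = CompilationLevel.PIECEWISE
        backend: str = ""
        custom_ops: list = field(default_factory=list)
        inductor_compile_config: dict = field(default_factory=dict)

    cfg = FakeCompilationConfig()
    if cfg.level == CompilationLevel.PIECEWISE:
        cfg.level = CompilationLevel.DYNAMO_ONCE
        cfg.backend = "inductor"
        cfg.custom_ops += ["none"]  # "none" disables custom-op replacements
        cfg.inductor_compile_config.update({
            "dce": True,
            "size_asserts": False,
            "nan_asserts": False,
            "memory_planning": True,
            "epilogue_fusion": True,
        })
    print(cfg)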

vllm/v1/worker/cpu_model_runner.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)

         assert device == torch.device("cpu")
-        assert not self.use_spec_decode, "spec decode is not supported."
+        assert self.speculative_config is None, "spec decode is not supported."
         assert not self.model_config.uses_mrope, "mrope is not supported."
         assert self.lora_config is None, "lora is not supported."

@@ -58,7 +58,7 @@ def warming_up_model(self) -> None:
         logger.info("Warming up model for the compilation...")
         # Only generate graph for the generic shape
         with _set_global_compilation_settings():
-            self._dummy_run(self.max_num_tokens)
+            self._dummy_run(max(16, self.max_num_reqs))
         logger.info("Warming up done.")

     def _init_device_properties(self) -> None:
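
The warm-up now sizes the dummy run from the request limit rather than max_num_tokens, with a floor of 16. A trivial hedged sketch of the sizing rule (the floor value is taken from the diff; everything else is illustrative):

    def warmup_num_reqs(max_num_reqs: int) -> int:
        # Floor of 16, as in the diff above.
        return max(16, max_num_reqs)

    assert warmup_num_reqs(4) == 16
    assert warmup_num_reqs(256) == 256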
