[CPU] Refine default config for the CPU backend #19539

Merged (1 commit, Jun 13, 2025)
15 changes: 12 additions & 3 deletions .buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

# Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

function cpu_tests() {
set -e
export NUMA_NODE=$2

+# list packages
+docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
+set -e
+pip list"
+
+docker exec cpu-test-"$NUMA_NODE" bash -c "
+set -e
+pip list"
+
# offline inference
docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
set -e
@@ -72,7 +81,7 @@ function cpu_tests() {
set -e
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
-python3 benchmarks/benchmark_serving.py \
+VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
Review comment (Contributor, severity: medium):

Setting VLLM_CPU_CI_ENV=0 here seems to override the environment variable set in the docker run command. Is this intentional?

Suggested change:
-VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
+python3 benchmarks/benchmark_serving.py \

--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
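For context on the reviewer's question above: in shell, a `VAR=value command` prefix sets the variable only for that one process, shadowing whatever the container inherited from `--env` in `docker run`; every other process in the container still sees the `--env` value. Whether that shadowing is intended for the benchmark client is exactly what the comment asks. A minimal Python sketch of the precedence (the variable name comes from this PR; the toy subprocess calls are purely illustrative):

```python
import os
import subprocess

# Like `--env VLLM_CPU_CI_ENV=1` in docker run: the value every child
# process inherits by default.
base_env = dict(os.environ, VLLM_CPU_CI_ENV="1")

# Inherits VLLM_CPU_CI_ENV=1 from the surrounding environment -> prints "1".
subprocess.run(
    ["python3", "-c", "import os; print(os.environ.get('VLLM_CPU_CI_ENV'))"],
    env=base_env,
)

# A per-command override, like `VLLM_CPU_CI_ENV=0 python3 ...` in the script:
# only this child process sees "0"; the rest of the container is unchanged.
subprocess.run(
    ["python3", "-c", "import os; print(os.environ.get('VLLM_CPU_CI_ENV'))"],
    env=dict(base_env, VLLM_CPU_CI_ENV="0"),
)
```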
26 changes: 22 additions & 4 deletions vllm/engine/arg_utils.py
@@ -1562,14 +1562,20 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192,
}
-default_max_num_seqs = 1024
+default_max_num_seqs = {
+UsageContext.LLM_CLASS: 1024,
+UsageContext.OPENAI_API_SERVER: 1024,
+}
else:
# TODO(woosuk): Tune the default values for other hardware.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
-default_max_num_seqs = 256
+default_max_num_seqs = {
+UsageContext.LLM_CLASS: 256,
+UsageContext.OPENAI_API_SERVER: 256,
+}

# tpu specific default values.
if current_platform.is_tpu():
@@ -1586,6 +1592,17 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
}
}

+# cpu specific default values.
+if current_platform.is_cpu():
+default_max_num_batched_tokens = {
+UsageContext.LLM_CLASS: 4096,
+UsageContext.OPENAI_API_SERVER: 2048,
+}
+default_max_num_seqs = {
+UsageContext.LLM_CLASS: 128,
+UsageContext.OPENAI_API_SERVER: 32,
+}

use_context_value = usage_context.value if usage_context else None
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
@@ -1606,8 +1623,9 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, use_context_value)

-if self.max_num_seqs is None:
-self.max_num_seqs = default_max_num_seqs
+if (self.max_num_seqs is None
+and usage_context in default_max_num_seqs):
+self.max_num_seqs = default_max_num_seqs[usage_context]

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
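The net effect of the arg_utils.py change above: max_num_seqs is now defaulted per usage context, mirroring max_num_batched_tokens, and a default is applied only when the user left the value unset and the context has an entry. A standalone sketch of that lookup using the CPU defaults added in this PR (the UsageContext enum here is a simplified stand-in for vLLM's, and resolve_max_num_seqs is a hypothetical helper, not vLLM API):

```python
from enum import Enum
from typing import Optional

class UsageContext(Enum):
    # Simplified stand-in for vLLM's UsageContext; only the two contexts
    # used by the defaults in this PR are modeled.
    LLM_CLASS = "LLM_CLASS"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"

# CPU-specific defaults introduced by this PR.
DEFAULT_MAX_NUM_SEQS = {
    UsageContext.LLM_CLASS: 128,
    UsageContext.OPENAI_API_SERVER: 32,
}

def resolve_max_num_seqs(user_value: Optional[int],
                         usage_context: Optional[UsageContext]) -> Optional[int]:
    # Apply a default only when the user did not set the value *and* the
    # usage context is present in the table; otherwise keep the value as-is.
    if user_value is None and usage_context in DEFAULT_MAX_NUM_SEQS:
        return DEFAULT_MAX_NUM_SEQS[usage_context]
    return user_value

assert resolve_max_num_seqs(None, UsageContext.OPENAI_API_SERVER) == 32  # default applied
assert resolve_max_num_seqs(64, UsageContext.OPENAI_API_SERVER) == 64    # user value wins
assert resolve_max_num_seqs(None, None) is None                          # unknown context: untouched
```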
20 changes: 14 additions & 6 deletions vllm/platforms/cpu.py
@@ -89,10 +89,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config
-# Reminder: Please update docs/features/compatibility_matrix.md
-# If the feature combo become valid
-if not model_config.enforce_eager:
-model_config.enforce_eager = True

model_config.disable_cascade_attn = True

@@ -171,9 +167,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
compilation_config = vllm_config.compilation_config
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE):

+# Note: vLLM V1 is using PIECEWISE level compilation, which will
+# take time to compile kernels just-in-time with the inductor
+# backend. For CPU CI tests, most of them are executed fast and
+# compilations consume too much time, even with torch compile
+# cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
+# and just execute model with dynamo + eager mode to save time.
+# VLLM_CPU_CI_ENV is only used as an internal variable.
+if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
+backend = "eager"
+else:
+backend = "inductor"
Review comment on lines +178 to +181 (Contributor, severity: medium):

The line compilation_config.custom_ops += ["none"] was removed. Clarify the expected behavior of custom ops with this change, for both "eager" and "inductor" backends under CompilationLevel.DYNAMO_ONCE. Ensure custom ops are correctly handled.

compilation_config.level = CompilationLevel.DYNAMO_ONCE
compilation_config.backend = "eager"
compilation_config.custom_ops += ["none"]
compilation_config.backend = backend
compilation_config.inductor_compile_config.update({
"dce":
True,
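The note in the diff above explains the intent: under the PIECEWISE-to-DYNAMO_ONCE downgrade for CPU, the torch.compile backend is now chosen from VLLM_CPU_CI_ENV, so CI runs skip Inductor's just-in-time kernel compilation. A minimal sketch of that selection outside vLLM (assumes a recent PyTorch; choose_cpu_compile_backend and toy_model are hypothetical names, not vLLM API):

```python
import os

import torch

def choose_cpu_compile_backend() -> str:
    # Same check as the diff: any non-"0" value of VLLM_CPU_CI_ENV means
    # "CI environment", so fall back to dynamo + eager instead of inductor.
    if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
        return "eager"
    return "inductor"

def toy_model(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) * 2.0

# backend="eager" still traces the function with dynamo but runs the captured
# graph eagerly, which is the time-saving behavior described in the note.
compiled = torch.compile(toy_model, backend=choose_cpu_compile_backend(), dynamic=True)
print(compiled(torch.randn(4, 8)))
```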
24 changes: 12 additions & 12 deletions vllm/v1/worker/cpu_model_runner.py
@@ -60,7 +60,8 @@ def load_model(self) -> None:
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
-self._dummy_run(max(16, self.max_num_reqs))
+with _set_global_compilation_settings(self.vllm_config):
+self._dummy_run(max(16, self.max_num_reqs))
logger.info("Warming up done.")

def _init_device_properties(self) -> None:
@@ -71,16 +72,15 @@ def _sync_device(self) -> None:


@contextmanager
-def _set_global_compilation_settings():
+def _set_global_compilation_settings(config: VllmConfig):
import torch._inductor.config

-# Note: The CPPGEMM backend requires freezing parameters.
-freezing_value = torch._inductor.config.freezing
-torch._inductor.config.freezing = True
-# Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
-# including object type dict"
-force_disable_caches = torch._inductor.config.force_disable_caches
-torch._inductor.config.force_disable_caches = True
-yield
-torch._inductor.config.freezing = freezing_value
-torch._inductor.config.force_disable_caches = force_disable_caches
+inductor_config = config.compilation_config.inductor_compile_config
+try:
+# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
+freezing_value = torch._inductor.config.freezing
+if inductor_config.get("max_autotune", False):
+torch._inductor.config.freezing = True
+yield
+finally:
+torch._inductor.config.freezing = freezing_value
Review comment on lines +75 to +86 (Contributor, severity: critical):

The logic for setting torch._inductor.config.freezing has changed. It's now conditional on inductor_config.get("max_autotune", False). Verify if freezing=False is safe and intended for the default CPU Inductor path, or if freezing=True should be set by default for CPU, independently of max_autotune if required by the underlying Inductor backends.
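For reference on the comment above: the new context manager follows a save/flip/restore pattern around the global Inductor flag, and whether freezing should also be enabled on the default (non-max_autotune) CPU path is exactly the open question, not something settled here. A minimal sketch of the pattern (assumes PyTorch with Inductor available; the helper name is hypothetical):

```python
from contextlib import contextmanager

import torch._inductor.config as inductor_cfg

@contextmanager
def freezing_enabled(enable: bool):
    # Save the global flag, optionally flip it for the duration of the block,
    # and always restore it in `finally` so a failed compile cannot leak the
    # setting into later compilations.
    saved = inductor_cfg.freezing
    if enable:
        inductor_cfg.freezing = True
    try:
        yield
    finally:
        inductor_cfg.freezing = saved

# Usage, mirroring warming_up_model() above: enable freezing only when the
# inductor compile config requests max_autotune (per this PR's logic).
# with freezing_enabled(inductor_compile_config.get("max_autotune", False)):
#     compiled_model(dummy_inputs)
```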