7 changes: 6 additions & 1 deletion docs/design/debug_vllm_compile.md
@@ -264,7 +264,12 @@ can be compiled once and then reused after they have been compiled. This
is a layer on top of [torch.compile's compiler cache](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html).

While torch.compile's compiler cache is rock-stable, vLLM's compiler cache is unfortunately
not always correct. You can disable it via setting `VLLM_DISABLE_COMPILE_CACHE=1`.
not always correct. You can disable it by either:

- Setting the config flag: `--compilation-config '{"disable_compile_cache": true}'`
zou3519 (Collaborator) commented on Dec 8, 2025:

Btw, positive flags are generally better than negative ones (something like `enable_compile_cache`), because double negation gets confusing.

I know the envvar is already negative, but we should (in the future) add a positive envvar and then deprecate the negative envvar. In that sense, I'd prefer that we have the config flag be positive.

PR author (Contributor) replied:

@zou3519 I agree that positive flags are clearer than negative ones.

In this PR I mirrored the existing env var naming to keep behavior and terminology aligned with what’s already shipped. The original issue was specifically “add a config flag for the existing env var,” so I followed the same pattern here.

I’d prefer to handle the rename as a follow-up: introduce a positive config flag and env var, then deprecate the negative env var in a controlled way. Otherwise this PR ends up mixing “add config flag” with a behavioral/UX change.

If you’d still like the positive flag in this PR, I can update it, but my inclination is to track that as a separate issue and keep this change scoped to exposing the current env var via config.

Collaborator replied:

Renames require deprecation, etc., so if we're planning to rename something we should just name it right the first time.

PR author (Contributor) replied:

@ProExpertProg Just to clarify your suggestion: are you proposing that the config flag be renamed to `enable_compile_cache` instead? I may be misunderstanding the intent of your comment; could you elaborate a bit on what you'd like the flag to be called? My concern is that the env var (which will still be present) will be `VLLM_DISABLE_COMPILE_CACHE` while the config flag will be `enable_compile_cache`. Am I missing something?

Collaborator replied:

I don't want to introduce a negative flag and then have to deprecate the negative flag. This PR can introduce both the positive flag and a positive envvar, if your worry is consistency. Then we can deprecate the negative envvar.

- Setting the environment variable: `VLLM_DISABLE_COMPILE_CACHE=1`

The environment variable takes precedence over the config flag if both are set.

You can also manually remove this cache.
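
Related to the review discussion above about positive versus negative flag naming, here is a hypothetical sketch of the migration the reviewers describe: a positive `enable_compile_cache` config flag and a positive env var, with the existing negative env var still honored but deprecated. The names `enable_compile_cache` and `VLLM_ENABLE_COMPILE_CACHE` are illustrative only and are not part of this PR.

```python
import os
import warnings


def resolve_enable_compile_cache(enable_compile_cache: bool = True) -> bool:
    """Resolve the effective cache setting from a positive flag plus env vars.

    Hypothetical: the legacy negative env var still wins (with a deprecation
    warning) so existing deployments keep their behavior, while the positive
    env var and config flag become the supported path forward.
    """
    if os.environ.get("VLLM_DISABLE_COMPILE_CACHE") == "1":
        warnings.warn(
            "VLLM_DISABLE_COMPILE_CACHE is deprecated; use "
            "VLLM_ENABLE_COMPILE_CACHE=0 or enable_compile_cache=False instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return False
    if "VLLM_ENABLE_COMPILE_CACHE" in os.environ:
        return os.environ["VLLM_ENABLE_COMPILE_CACHE"] == "1"
    return enable_compile_cache
```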

7 changes: 6 additions & 1 deletion docs/design/torch_compile.md
@@ -23,7 +23,12 @@ The factors considered include:
- PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](../../vllm/compilation/compiler_interface.py))
- The model's forward function and the relevant functions called by the forward function (see below)

With all these factors taken into consideration, usually we can guarantee that the cache is safe to use, and will not cause any unexpected behavior. Therefore, the cache is enabled by default. If you want to debug the compilation process, or if you suspect the cache is causing some issues, you can disable it by setting the environment variable `VLLM_DISABLE_COMPILE_CACHE=1`.
With all these factors taken into consideration, usually we can guarantee that the cache is safe to use, and will not cause any unexpected behavior. Therefore, the cache is enabled by default. If you want to debug the compilation process, or if you suspect the cache is causing some issues, you can disable it by either:

- Setting the config flag: `--compilation-config '{"disable_compile_cache": true}'`
- Setting the environment variable: `VLLM_DISABLE_COMPILE_CACHE=1`

The environment variable takes precedence over the config flag if both are set.
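
To check which setting wins programmatically, you can build the config objects directly. A minimal sketch, mirroring the unit test added in this PR, assuming `VLLM_DISABLE_COMPILE_CACHE=1` was exported before the config is constructed:

```python
from vllm.config import CompilationConfig, VllmConfig

# With VLLM_DISABLE_COMPILE_CACHE=1 in the environment, the resolved config
# reports the cache as disabled even though the flag was left at its default.
config = VllmConfig(
    compilation_config=CompilationConfig(disable_compile_cache=False)
)
print(config.compilation_config.disable_compile_cache)  # True while the env var is set
```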

A unique aspect of vLLM's `torch.compile` integration is that we guarantee all compilation finishes before we serve any requests. No request will trigger new compilations. Otherwise, the engine would be blocked on that request, and the response time would have unexpected spikes.

69 changes: 69 additions & 0 deletions tests/compile/test_config.py
@@ -80,6 +80,75 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_disable_compile_cache_config(vllm_runner, monkeypatch):
"""Test that disable_compile_cache config option works."""
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
# Ensure env var is not set so we test the config option
monkeypatch.delenv("VLLM_DISABLE_COMPILE_CACHE", raising=False)

compilation_config = {
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
"disable_compile_cache": True,
}
with (
compilation_counter.expect(
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass


@pytest.mark.parametrize(
"config_value,env_var_value,expected",
[
# Env var overrides config when config is False
(False, "1", True),
# Both config and env var disable cache
(True, "1", True),
# Config disables cache, env var not set
(True, None, True),
# Config enables cache, env var not set
(False, None, False),
# env var is set to 0, should be treated as not set, so config used
(True, "0", True),
(False, "0", False),
],
ids=[
"env_var_overrides_config",
"both_disable_cache",
"config_disables_without_env_var",
"config_enables_without_env_var",
"env_var_zero_config_disables",
"env_var_zero_config_enables",
],
)
def test_disable_compile_cache_config_and_env_var(
monkeypatch, config_value, env_var_value, expected
):
"""Test disable_compile_cache config and env var interaction."""
import vllm.envs as envs

# Set or unset the env var using monkeypatch
if env_var_value is not None:
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", env_var_value)
    # envs.VLLM_DISABLE_COMPILE_CACHE is a bool, so mirror the parsed value
    # of the env var rather than the raw string (a raw "0" would be truthy)
    monkeypatch.setattr(envs, "VLLM_DISABLE_COMPILE_CACHE", env_var_value == "1")
else:
    monkeypatch.delenv("VLLM_DISABLE_COMPILE_CACHE", raising=False)
    monkeypatch.setattr(envs, "VLLM_DISABLE_COMPILE_CACHE", False)

config = VllmConfig(
compilation_config=CompilationConfig(disable_compile_cache=config_value)
)
assert config.compilation_config.disable_compile_cache is expected


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
13 changes: 11 additions & 2 deletions vllm/compilation/backends.py
@@ -234,7 +234,13 @@ def compile(
assert compiled_graph is not None, "Failed to compile the graph"

# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
if (
is_compile_cache_enabled(
additional_inductor_config,
compilation_config.disable_compile_cache,
)
and handle is not None
):
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
@@ -608,7 +614,10 @@ def __call__(
self.compilation_config.local_cache_dir = local_cache_dir

# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config)
disable_cache = not is_compile_cache_enabled(
self.inductor_config,
self.compilation_config.disable_compile_cache,
)

if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
14 changes: 13 additions & 1 deletion vllm/compilation/compiler_interface.py
@@ -166,6 +166,7 @@ def get_inductor_factors() -> list[Any]:

def is_compile_cache_enabled(
vllm_additional_inductor_config: dict[str, Any],
disable_compile_cache: bool = False,
) -> bool:
vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
"force_disable_caches", False
@@ -175,7 +176,7 @@ def is_compile_cache_enabled(
# with torch.compiler.config.force_disable_caches when minimum PyTorch
# version reaches 2.10
return (
not envs.VLLM_DISABLE_COMPILE_CACHE
not disable_compile_cache
and not torch._inductor.config.force_disable_caches
and not vllm_inductor_config_disable_cache
)
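
For reference, any single opt-out is enough to keep artifacts out of the cache. A standalone illustration of the combined check (plain booleans stand in for `torch._inductor.config.force_disable_caches`; this is a sketch, not the vLLM module itself):

```python
from typing import Any


def cache_enabled(
    vllm_additional_inductor_config: dict[str, Any],
    disable_compile_cache: bool = False,
    torch_force_disable_caches: bool = False,
) -> bool:
    # Mirrors is_compile_cache_enabled: the cache stays on only if no opt-out fires.
    inductor_opt_out = vllm_additional_inductor_config.get("force_disable_caches", False)
    return (
        not disable_compile_cache
        and not torch_force_disable_caches
        and not inductor_opt_out
    )


assert cache_enabled({}) is True
assert cache_enabled({}, disable_compile_cache=True) is False
assert cache_enabled({}, torch_force_disable_caches=True) is False
assert cache_enabled({"force_disable_caches": True}) is False
```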
@@ -220,6 +221,12 @@ def compile(
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)

# Remove vllm-specific keys that are not valid inductor config options
# before passing to standalone_compile. These keys are used internally
# by vLLM for cache control but would cause AttributeError in torch._inductor.
current_config.pop("vllm_disable_compile_cache", None)

set_inductor_config(current_config, compile_range)
set_functorch_config()

@@ -325,6 +332,11 @@ def compile(
if compiler_config is not None:
current_config.update(compiler_config)

# Remove vllm-specific keys that are not valid inductor config options
# before passing to compile_fx. These keys are used internally by vLLM
# for cache control but would cause AttributeError in torch._inductor.
current_config.pop("vllm_disable_compile_cache", None)

# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
7 changes: 7 additions & 0 deletions vllm/config/compilation.py
@@ -445,6 +445,12 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
disable_compile_cache: bool = False
"""If True, disable vLLM's torch.compile cache. The compile cache stores
compiled artifacts to disk to speed up subsequent runs. Disabling it is
useful for debugging compilation issues or when running in environments
where caching is not desired. Can also be set via VLLM_DISABLE_COMPILE_CACHE
environment variable, which takes precedence over this config option."""
compile_cache_save_format: Literal["binary", "unpacked"] = field(
default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT
)
@@ -708,6 +714,7 @@ def compute_hash(self) -> str:
# Paths/dirs and runtime/metrics that don’t affect compiled graph
"debug_dump_path",
"cache_dir",
"disable_compile_cache",
"local_cache_dir",
"bs_to_padded_graph_size",
"traced_files",
9 changes: 9 additions & 0 deletions vllm/config/vllm.py
@@ -954,6 +954,15 @@ def has_blocked_weights():
)
self.compilation_config.debug_dump_path = env_path

# Handle VLLM_DISABLE_COMPILE_CACHE env var override
if envs.VLLM_DISABLE_COMPILE_CACHE:
if not self.compilation_config.disable_compile_cache:
logger.info(
"Config-specified disable_compile_cache=False is overridden"
" by VLLM_DISABLE_COMPILE_CACHE=1"
)
self.compilation_config.disable_compile_cache = True

def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):