[torch.compile] reorganize the cache directory to support compiling multiple models #19064

Merged (6 commits, Jun 13, 2025)

65 changes: 56 additions & 9 deletions vllm/compilation/backends.py
Comment from @luyuzhe111 (Contributor), Jun 3, 2025:

Hi @youkaichao, I wonder if we can simplify the `configure_post_pass(self)` method here? I had to make some edits to make things work here, but maybe they are not necessary anymore? Thanks!

@@ -6,6 +6,7 @@
import pprint
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional

import torch
@@ -65,7 +66,25 @@ def __init__(self, compilation_config: CompilationConfig):
def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)

def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
Comment from a Collaborator:

nit: This is technically a subdirectory (or a suffix to the path), not a prefix. I was prototyping something like this locally and I called this the "model_component", but up to you.

"""
Initialize the cache directory for the compiler.

The organization of the cache directory is as follows:
cache_dir=/path/to/hash_str/rank_i_j/prefix/
inside cache_dir, there will be:
- vllm_compile_cache.py
- computation_graph.py
- transformed_code.py

multiple prefixes can share the same
base cache dir /path/to/hash_str/rank_i_j/
to store common compilation artifacts.
"""

self.disable_cache = disable_cache
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
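
For illustration, a small runnable sketch of the layout this docstring describes; the temporary directory, the rank_0_0 name, and the backbone/eagle_head prefixes are stand-ins, not values computed by vLLM:

import os
import tempfile

# stands in for /path/to/hash_str/rank_i_j (the shared per-rank base dir)
base = os.path.join(tempfile.mkdtemp(), "rank_0_0")
for prefix in ("backbone", "eagle_head"):
    part_dir = os.path.join(base, prefix)   # one sub-directory per compiled part
    os.makedirs(part_dir, exist_ok=True)
    # each part keeps its own vllm_compile_cache.py, computation_graph.py
    # and transformed_code.py in here, while common artifacts stay in `base`
    print(part_dir)
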
@@ -79,7 +98,8 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
self.cache = ast.literal_eval(f.read())

self.compiler.initialize_cache(cache_dir=cache_dir,
disable_cache=disable_cache)
disable_cache=disable_cache,
prefix=prefix)

def save_to_file(self):
if self.disable_cache or not self.is_cache_updated:
@@ -309,6 +329,25 @@ def call_module(self, target: torch.fx.node.Target,
return output


# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
"""Context manager to set the model tag."""
global model_tag
assert tag != model_tag, \
f"Model tag {tag} is the same as the current tag {model_tag}."
old_tag = model_tag
model_tag = tag
try:
yield
finally:
model_tag = old_tag


class VllmBackend:
"""The compilation backend for `torch.compile` with vLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -340,7 +379,17 @@ class VllmBackend:
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):

# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag

global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
@@ -440,16 +489,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
)
self.compilation_config.cache_dir = cache_dir

if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
self.prefix)
os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir
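
To make the effect of this hunk concrete, a sketch with made-up values contrasting the removed numeric-suffix layout with the rank/prefix layout that replaces it; the /tmp path and rank values are hypothetical:

import os

cache_dir = "/tmp/vllm_cache/hash_str"   # hypothetical hash-level directory
rank, dp_rank = 0, 0

# old: a second compiled graph got its own "-1"-suffixed tree
old_style = os.path.join(cache_dir + "-1", f"rank_{rank}_{dp_rank}")
# new: every compiled part nests under the same rank directory by prefix
new_style = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", "eagle_head")

print(old_style)   # /tmp/vllm_cache/hash_str-1/rank_0_0
print(new_style)   # /tmp/vllm_cache/hash_str/rank_0_0/eagle_head
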

@@ -461,7 +507,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
logger.info("Using cache directory: %s for vLLM's torch.compile",
local_cache_dir)

self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
self.prefix)

# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
38 changes: 29 additions & 9 deletions vllm/compilation/compiler_interface.py
@@ -27,11 +27,22 @@ class CompilerInterface:
# This is a class-level attribute.
name: str

def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.

prefix can be used in combination with cache_dir to figure out the base
cache directory, e.g. when there are multiple parts of the model being
compiled, but we want to share the same cache directory for all of them.

e.g.
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
"""
pass
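
A minimal sketch of the relationship the docstring spells out: given the (cache_dir, prefix) pair from the example, a compiler backend can strip the prefix to recover the shared base directory, which is what the Inductor adaptor does further down in this file:

cache_dir = "/path/to/dir/backbone"
prefix = "backbone"
base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
print(base_cache_dir)   # "/path/to/dir/", also the base for /path/to/dir/eagle_head
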

@@ -165,7 +176,10 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
usedforsecurity=False).hexdigest()[:10]
return hash_str

def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
self.cache_dir = cache_dir

def compile(
@@ -241,18 +255,23 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
usedforsecurity=False).hexdigest()[:10]
return hash_str

def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
def initialize_cache(self,
cache_dir: str,
disable_cache: bool = False,
prefix: str = ""):
Comment from a Collaborator:

Since prefix is only used to calculate the base_cache_dir, why not pass in the base_cache_dir instead of passing in prefix?

self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
if disable_cache:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache

@@ -297,14 +316,14 @@ def hijack_load(*args, **kwargs):
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir):
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
if cell.cell_contents.__code__.co_filename.startswith(
self.cache_dir):
self.base_cache_dir):
# this is the real file path compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename
break
@@ -324,14 +343,15 @@ def hijacked_compile_fx_inner(*args, **kwargs):
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir):
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.cache_dir):
if code.co_filename.startswith(
self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename
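
For readers unfamiliar with these two hooks, here is a toy sketch (not vLLM code; the name find_cached_file is made up) of the closure inspection they perform: locate the Inductor-generated file either from the callable's own __code__ or from a callable captured in its __closure__, matching against the shared base_cache_dir:

from typing import Callable, Optional

def find_cached_file(fn: Callable, base_cache_dir: str) -> Optional[str]:
    path = fn.__code__.co_filename
    if path.startswith(base_cache_dir):
        return path                     # the callable itself lives in the cache
    for cell in fn.__closure__ or ():   # otherwise inspect captured callables
        if not callable(cell.cell_contents):
            continue
        code = cell.cell_contents.__code__
        if code.co_filename.startswith(base_cache_dir):
            return code.co_filename     # the real file compiled by Inductor
    return None
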
19 changes: 17 additions & 2 deletions vllm/config.py
@@ -4602,23 +4602,28 @@ def __str__(self):


_current_vllm_config: Optional[VllmConfig] = None
_current_prefix: Optional[str] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
def set_current_vllm_config(vllm_config: VllmConfig,
check_compile=False,
prefix: Optional[str] = None):
Comment from a Collaborator:

Could you add a brief comment to explain the meaning of prefix?

"""
Temporarily set the current vLLM config.
Used during model initialization.
We save the current vLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the vLLM config to determine how to dispatch.
"""
global _current_vllm_config
global _current_vllm_config, _current_prefix
old_vllm_config = _current_vllm_config
old_prefix = _current_prefix
from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
_current_prefix = prefix
yield
except Exception:
raise
@@ -4642,6 +4647,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
vllm_config.model_config.model)
finally:
_current_vllm_config = old_vllm_config
_current_prefix = old_prefix


def get_current_vllm_config() -> VllmConfig:
@@ -4655,6 +4661,15 @@ def get_current_vllm_config() -> VllmConfig:
return _current_vllm_config


def get_current_model_prefix() -> str:
"""
Get the prefix of the model that's currently being initialized.
"""
assert _current_prefix is not None, \
"Current model prefix is not set. "
return _current_prefix
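
A usage sketch of the new helper, assuming a vLLM build that already contains this PR; build_vision_tower and the "vision_model" prefix are made-up names for illustration:

from vllm.config import set_current_vllm_config, get_current_model_prefix

def build_vision_tower(vllm_config):
    with set_current_vllm_config(vllm_config, prefix="vision_model"):
        # any code running during model construction can now see which
        # part of the model is being initialized
        assert get_current_model_prefix() == "vision_model"
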


def contains_object_print(text):
"""
Check if the text looks like a printed Python object, e.g.
8 changes: 6 additions & 2 deletions vllm/model_executor/model_loader/utils.py
@@ -57,7 +57,9 @@ def initialize_model(
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True):
with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(vllm_config=vllm_config, prefix=prefix)

msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@@ -85,7 +87,9 @@ def initialize_model(
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True):
with set_current_vllm_config(vllm_config,
check_compile=True,
prefix=prefix):
return model_class(**kwargs)


6 changes: 4 additions & 2 deletions vllm/v1/spec_decode/eagle.py
@@ -319,8 +319,10 @@ def load_model(self, target_model: nn.Module) -> None:
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())

self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
Comment from @zou3519 (Collaborator), Jun 3, 2025:

nit: could we name this something like "set_compile_region" or "set_model_component" (see the other comment)? That would make it clearer that this is 1:1 with a fullgraph torch.compile region.

self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
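
A minimal usage sketch, assuming a vLLM build that contains this PR: anything constructed inside the block (here just a placeholder) is compiled under the "eagle_head" prefix and therefore cached in its own sub-directory of the shared rank_i_j base:

from vllm.compilation.backends import set_model_tag

with set_model_tag("eagle_head"):
    pass  # get_model(...) for the draft head would run here
# on exit the tag reverts to the default "backbone"
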

draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
8 changes: 5 additions & 3 deletions vllm/v1/spec_decode/medusa.py
@@ -47,9 +47,11 @@ def propose(
return [list(row) for row in zip(*draft_tokens)]

def load_model(self, target_model: nn.Module) -> None:
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
from vllm.compilation.backends import set_model_tag
with set_model_tag("medusa_head"):
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)

@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None: