[torch.compile] reorganize the cache directory to support compiling multiple models #19064
vllm/compilation/backends.py
@@ -6,6 +6,7 @@
 import pprint
 import time
 from collections.abc import Sequence
+from contextlib import contextmanager
 from typing import Any, Callable, Optional

 import torch
@@ -65,7 +66,25 @@ def __init__(self, compilation_config: CompilationConfig):
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)

-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
+        """
+        Initialize the cache directory for the compiler.
+
+        The organization of the cache directory is as follows:
+        cache_dir=/path/to/hash_str/rank_i_j/prefix/
+        inside cache_dir, there will be:
+        - vllm_compile_cache.py
+        - computation_graph.py
+        - transformed_code.py
+
+        for multiple prefixes, they can share the same
+        base cache dir of /path/to/hash_str/rank_i_j/,
+        to store some common compilation artifacts.
+        """
+
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

Review comment: nit: This is technically a subdirectory (or a suffix to the path), not a prefix. I was prototyping something like this locally and I called this the "model_component", but up to you.
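For illustration, a minimal sketch of the on-disk layout the docstring above describes, using made-up hash and rank values: the per-prefix cache directories hang off the same per-rank base directory, so rank-level artifacts can be shared between them.

import os

base = "/path/to/hash_str/rank_0_0"         # shared base cache dir for this rank
for prefix in ("backbone", "eagle_head"):
    cache_dir = os.path.join(base, prefix)  # per-prefix cache dir
    # inside each per-prefix dir the CompilerManager keeps
    # vllm_compile_cache.py, computation_graph.py, transformed_code.py
    print(cache_dir)
# /path/to/hash_str/rank_0_0/backbone
# /path/to/hash_str/rank_0_0/eagle_head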
@@ -79,7 +98,8 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
                 self.cache = ast.literal_eval(f.read())

         self.compiler.initialize_cache(cache_dir=cache_dir,
-                                       disable_cache=disable_cache)
+                                       disable_cache=disable_cache,
+                                       prefix=prefix)

     def save_to_file(self):
         if self.disable_cache or not self.is_cache_updated:
@@ -309,6 +329,25 @@ def call_module(self, target: torch.fx.node.Target,
         return output


+# the tag for the part of model being compiled,
+# e.g. backbone/eagle_head
+model_tag: str = "backbone"
+
+
+@contextmanager
+def set_model_tag(tag: str):
+    """Context manager to set the model tag."""
+    global model_tag
+    assert tag != model_tag, \
+        f"Model tag {tag} is the same as the current tag {model_tag}."
+    old_tag = model_tag
+    model_tag = tag
+    try:
+        yield
+    finally:
+        model_tag = old_tag
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
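A hedged usage sketch of the new context manager (assuming a vLLM build that includes this change): the tag is plain module-level state, so anything constructed inside the `with` block, such as a `VllmBackend`, sees the new value and picks its cache sub-directory accordingly.

from vllm.compilation import backends
from vllm.compilation.backends import set_model_tag

print(backends.model_tag)         # "backbone" (the default)
with set_model_tag("eagle_head"):
    print(backends.model_tag)     # "eagle_head"
print(backends.model_tag)         # restored to "backbone"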
@@ -340,7 +379,17 @@ class VllmBackend:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        prefix: str = "",
     ):
+
+        # if the model is initialized with a non-empty prefix,
+        # then usually it's enough to use that prefix,
+        # e.g. language_model, vision_model, etc.
+        # when multiple parts are initialized as independent
+        # models, we need to use the model_tag to distinguish
+        # them, e.g. backbone (default), eagle_head, etc.
+        self.prefix = prefix or model_tag
+
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = current_platform.graph_pool_handle()
@@ -440,16 +489,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             )
             self.compilation_config.cache_dir = cache_dir

-        if compilation_counter.num_graphs_seen > 0:
-            cache_dir = self.compilation_config.cache_dir + \
-                f'-{compilation_counter.num_graphs_seen}'
-        else:
-            cache_dir = self.compilation_config.cache_dir
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
+                                       self.prefix)
         os.makedirs(local_cache_dir, exist_ok=True)
         self.compilation_config.local_cache_dir = local_cache_dir

@@ -461,7 +507,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
                         local_cache_dir)

-        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
+        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
+                                               self.prefix)

         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
vllm/compilation/compiler_interface.py
@@ -27,11 +27,22 @@ class CompilerInterface:
     # This is a class-level attribute.
     name: str

-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
         e.g. by re-directing its own cache directory to a sub-directory.

+        prefix can be used in combination with cache_dir to figure out the base
+        cache directory, e.g. when there are multiple parts of the model being
+        compiled but we want to share the same cache directory for all of them.
+
+        e.g.
+        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
+        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
         """
         pass

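A minimal sketch of the convention the docstring's examples imply: stripping the prefix from the end of cache_dir recovers the shared base directory (this mirrors what the InductorAdaptor change below does); the paths are illustrative.

def base_cache_dir(cache_dir: str, prefix: str) -> str:
    # drop the trailing prefix component to get the shared base directory
    return cache_dir[:-len(prefix)] if prefix else cache_dir

assert base_cache_dir("/path/to/dir/backbone", "backbone") == "/path/to/dir/"
assert base_cache_dir("/path/to/dir/eagle_head", "eagle_head") == "/path/to/dir/"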
@@ -165,7 +176,10 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
                        usedforsecurity=False).hexdigest()[:10]
         return hash_str

-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir

     def compile(
@@ -241,18 +255,23 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
                        usedforsecurity=False).hexdigest()[:10]
         return hash_str

-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
+        self.prefix = prefix
+        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
         # set flags so that Inductor and Triton store their cache
         # in the cache_dir, then users only need to copy the cache_dir
         # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
         os.makedirs(inductor_cache, exist_ok=True)
         os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
+        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
         os.makedirs(triton_cache, exist_ok=True)
         os.environ["TRITON_CACHE_DIR"] = triton_cache

Review comment: since prefix is only used to calculate the base_cache_dir, why not just pass in the base_cache_dir instead of passing in prefix?
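To show the effect of the change above, a small sketch (illustrative paths, no directories actually created): two components with different prefixes resolve to the same rank-level Inductor and Triton cache directories, so their kernel caches are shared.

import os

for cache_dir, prefix in [("/root/rank_0_0/backbone", "backbone"),
                          ("/root/rank_0_0/eagle_head", "eagle_head")]:
    base = cache_dir[:-len(prefix)] if prefix else cache_dir
    print(os.path.join(base, "inductor_cache"))  # /root/rank_0_0/inductor_cache
    print(os.path.join(base, "triton_cache"))    # /root/rank_0_0/triton_cache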
@@ -297,14 +316,14 @@ def hijack_load(*args, **kwargs):
             nonlocal file_path
             compiled_fn = inductor_compiled_graph.current_callable
             file_path = compiled_fn.__code__.co_filename  # noqa
-            if not file_path.startswith(self.cache_dir):
+            if not file_path.startswith(self.base_cache_dir):
                 # hooked in the align_inputs_from_check_idxs function
                 # in torch/_inductor/utils.py
                 for cell in compiled_fn.__closure__:
                     if not callable(cell.cell_contents):
                         continue
                     if cell.cell_contents.__code__.co_filename.startswith(
-                            self.cache_dir):
+                            self.base_cache_dir):
                         # this is the real file path compiled from Inductor
                         file_path = cell.cell_contents.__code__.co_filename
                         break
@@ -324,14 +343,15 @@ def hijacked_compile_fx_inner(*args, **kwargs):
                 nonlocal file_path
                 compiled_fn = inductor_compiled_graph.current_callable
                 file_path = compiled_fn.__code__.co_filename  # noqa
-                if not file_path.startswith(self.cache_dir):
+                if not file_path.startswith(self.base_cache_dir):
                     # hooked in the align_inputs_from_check_idxs function
                     # in torch/_inductor/utils.py
                     for cell in compiled_fn.__closure__:
                         if not callable(cell.cell_contents):
                             continue
                         code = cell.cell_contents.__code__
-                        if code.co_filename.startswith(self.cache_dir):
+                        if code.co_filename.startswith(
+                                self.base_cache_dir):
                             # this is the real file path
                             # compiled from Inductor
                             file_path = code.co_filename
vllm/config.py
@@ -4602,23 +4602,28 @@ def __str__(self):


 _current_vllm_config: Optional[VllmConfig] = None
+_current_prefix: Optional[str] = None


 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
+def set_current_vllm_config(vllm_config: VllmConfig,
+                            check_compile=False,
+                            prefix: Optional[str] = None):
     """
     Temporarily set the current vLLM config.
     Used during model initialization.
     We save the current vLLM config in a global variable,
     so that all modules can access it, e.g. custom ops
     can access the vLLM config to determine how to dispatch.
     """
-    global _current_vllm_config
+    global _current_vllm_config, _current_prefix
     old_vllm_config = _current_vllm_config
+    old_prefix = _current_prefix
     from vllm.compilation.counter import compilation_counter
     num_models_seen = compilation_counter.num_models_seen
     try:
         _current_vllm_config = vllm_config
+        _current_prefix = prefix
         yield
     except Exception:
         raise

Review comment: add a brief comment to explain the meaning of prefix?
@@ -4642,6 +4647,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
                 vllm_config.model_config.model)
     finally:
         _current_vllm_config = old_vllm_config
+        _current_prefix = old_prefix


 def get_current_vllm_config() -> VllmConfig:
@@ -4655,6 +4661,15 @@ def get_current_vllm_config() -> VllmConfig:
     return _current_vllm_config


+def get_current_model_prefix() -> str:
+    """
+    Get the prefix of the model that's currently being initialized.
+    """
+    assert _current_prefix is not None, \
+        "Current model prefix is not set. "
+    return _current_prefix
+
+
 def contains_object_print(text):
     """
     Check if the text looks like a printed Python object, e.g.
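A hedged sketch of how the prefix can be threaded through model initialization and read back. This is not one of the PR's actual call sites, and the default-constructed `VllmConfig()` is only a stand-in for a real config.

from vllm.config import (VllmConfig, get_current_model_prefix,
                         set_current_vllm_config)

vllm_config = VllmConfig()  # stand-in config for illustration
with set_current_vllm_config(vllm_config, prefix="language_model"):
    # any module initialized here can recover the component name
    assert get_current_model_prefix() == "language_model"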
vllm/v1/spec_decode/eagle.py
@@ -319,8 +319,10 @@ def load_model(self, target_model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())

-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=draft_model_config)
+        from vllm.compilation.backends import set_model_tag
+        with set_model_tag("eagle_head"):
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=draft_model_config)

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -

Review comment: nit: could we name this something like "set_compile_region" or "set_model_component" (see the other comment)? That would make it clearer that this is 1:1 with a fullgraph torch.compile region.
Review comment: Hi @youkaichao, I wonder if we can simplify the `configure_post_pass(self)` method here? I had to make some edits to make things work, but maybe they are not necessary anymore. Thanks!