
Commit 6f1222c

PR vllm-project#26956: Squashed commit of the following:
commit ad717d4
Author: Richard Zou <zou3519@gmail.com>
Date: Wed Oct 15 16:29:49 2025 -0700

[BugFix] Work around graph partition x torch.compile cache issue

In PyTorch 2.9, torch.compile has a bug where the graph partition is not taken into account during caching. Because vLLM's Mode.VLLM_COMPILE is the only mode that uses Inductor graph partition, and VLLM_COMPILE implies there is a PostGradPassManager, we put the list of operators to graph partition into the PostGradPassManager's uuid (which then gets incorporated into Inductor's FX graph cache key). Remove this hack whenever torch.compile fixes it.

Signed-off-by: Richard Zou <zou3519@gmail.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Parent: 3c5789f
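To make the mechanism concrete: Inductor incorporates the custom post-grad pass's uuid into its FX graph cache key, so anything folded into that uuid changes the key. The sketch below is a hypothetical stand-in, not vLLM code: hash_state, the literal state values, and the sha256-over-JSON hashing are illustrative assumptions standing in for InductorPass.hash_dict. It only shows why adding the splitting-op list to the hashed state forces separate cache entries for partitioned and unpartitioned compiles.

import hashlib
import json


def hash_state(state: dict) -> str:
    # Hypothetical stand-in: serialize the pass-manager state deterministically,
    # then hash it into a cache-key component.
    return hashlib.sha256(json.dumps(state, sort_keys=True).encode()).hexdigest()


# Unpartitioned compile: the splitting-op list stays empty.
unpartitioned = {"pass_config": "deadbeef", "passes": [], "inductor_splitting_ops": []}
# Partitioned compile: the list of operators to split is part of the state.
partitioned = {"pass_config": "deadbeef", "passes": [], "inductor_splitting_ops": ["silly::attention"]}

# Different keys mean Inductor cannot serve a graph compiled without
# partitioning when partitioning was requested.
assert hash_state(unpartitioned) != hash_state(partitioned)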

File tree

tests/compile/piecewise/test_toy_llama.py
vllm/compilation/pass_manager.py

2 files changed: +31 -12 lines


tests/compile/piecewise/test_toy_llama.py (1 addition, 11 deletions)

@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):
-    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # We disable the vLLM compile cache into a new tmp dir for 1 reason:
     # 1. To make sure we can properly track the number of Inductor compilations.
-    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -369,15 +368,6 @@ def test_toy_llama(
         cudagraph_capture_sizes=[1, 2],
     )

-    # FIXME(luka/boyuan): the graph from the previous test case
-    # (no inductor partition) gets cached by AotAutograd so then the
-    # compilation with inductor partitioning incorrectly loads an unpartitioned
-    # graph and never partitions. I think this is a bug with custom inductor
-    # partitioning but does not affect vLLM more generally as vLLM uses its own
-    # cache (which takes inductor partitioning into account).
-    if use_inductor_graph_partition:
-        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
-
     compile_config_split = deepcopy(compile_config_no_split)
     compile_config_split.splitting_ops = ["silly::attention"]

vllm/compilation/pass_manager.py (30 additions, 1 deletion)

@@ -110,6 +110,27 @@ def configure(self, config: VllmConfig):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)

+        # [HACK: Bug with Inductor graph partition and torch.compile cache]
+        # In PyTorch 2.9, torch.compile has a bug where the graph
+        # partition is not taken into account during caching.
+        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
+        # Inductor graph partition, and VLLM_COMPILE implies there
+        # is a PostGradPassManager, we put the list of operators to graph
+        # partition into the PostGradPassManager's uuid (which
+        # then gets incorporated into Inductor's FX graph cache key).
+        # Remove this hack whenever torch.compile fixes it.
+
+        # This is the list of operators that vLLM asks Inductor to split.
+        self.inductor_splitting_ops = []
+        if (
+            config.compilation_config.use_inductor_graph_partition
+            and config.compilation_config.splitting_ops is not None
+        ):
+            # Sort them so we're not dependent on the ordering.
+            self.inductor_splitting_ops = sorted(
+                config.compilation_config.splitting_ops
+            )
+
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -120,8 +141,16 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {"pass_config": self.pass_config.uuid(), "passes": []}
+        state = {
+            "pass_config": self.pass_config.uuid(),
+            "passes": [],
+            "inductor_splitting_ops": [],
+        }
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
+
+        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
+        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
+
         return InductorPass.hash_dict(state)
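As a quick sanity check of the uuid() change, here is a hypothetical stand-in, not vLLM's test suite: pass_manager_key, its constant pass_config value, and the sha256-over-JSON hashing are illustrative assumptions that only mirror the hashed-state logic in the diff above. It demonstrates the two properties the change relies on: sorted() makes the key independent of the order in which splitting ops are configured, while enabling Inductor graph partition still changes the key.

import hashlib
import json


def pass_manager_key(splitting_ops=None, use_inductor_graph_partition=False):
    # Mirrors the hashed-state idea from the diff above with stand-in values.
    state = {"pass_config": "deadbeef", "passes": [], "inductor_splitting_ops": []}
    if use_inductor_graph_partition and splitting_ops is not None:
        # Sorted, so the key does not depend on configuration order.
        state["inductor_splitting_ops"].extend(sorted(splitting_ops))
    return hashlib.sha256(json.dumps(state, sort_keys=True).encode()).hexdigest()


# The same splitting ops in a different order produce the same key...
assert pass_manager_key(["b::op", "a::op"], True) == pass_manager_key(["a::op", "b::op"], True)
# ...but turning on Inductor graph partition changes the key.
assert pass_manager_key(["a::op", "b::op"], True) != pass_manager_key()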
