[Bug]: use_inductor_partition + splitting_ops results in AssertionError #26678

Description

@angelayi

Your current environment

vllm main, torch 2.9 RC, B200

🐛 Describe the bug

With #25845, we can now use splitting_ops to specify which ops to split on when use_inductor_graph_partition=True. However, the repro below hits an AssertionError:
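
For reference, the combination named in the title looks roughly like this (a minimal sketch only; the op listed in splitting_ops is an assumed example, and the repro below hits the error even with splitting_ops left at its default):

# Hypothetical illustration of the title's flag combination; not part of the repro.
from vllm.config import CompilationConfig, CompilationLevel

config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor_graph_partition=True,
    splitting_ops=["vllm.unified_attention_with_output"],  # assumed example op
)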

Repro:

import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode


os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "1"
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
os.environ["VLLM_USE_STANDALONE_COMPILE"] = "1"

config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    cudagraph_mode=CUDAGraphMode.FULL,
    # splitting_ops=[],
    custom_ops=['+quant_fp8'],
    use_inductor_graph_partition=True,
)

llm = LLM(
    model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    gpu_memory_utilization=0.6,
    max_model_len=3000,
    compilation_config=config,
    tensor_parallel_size=2,
    enforce_eager=False,
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))

# Print the outputs.
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt:    {prompt!r}")
    print(f"Output:    {generated_text!r}")
    print("-" * 60)

Error:

Traceback (most recent call last):
  File "/home/ProExpertProg/git/vllm/vllm/v1/engine/core.py", line 781, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/vllm/v1/engine/core.py", line 553, in __init__
    super().__init__(
  File "/home/ProExpertProg/git/vllm/vllm/v1/engine/core.py", line 110, in __init__
    num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
  File "/home/ProExpertProg/git/vllm/vllm/v1/engine/core.py", line 221, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
  File "/home/ProExpertProg/git/vllm/vllm/v1/executor/abstract.py", line 88, in determine_available_memory
    return self.collective_rpc("determine_available_memory")
  File "/home/ProExpertProg/git/vllm/vllm/executor/uniproc_executor.py", line 74, in collective_rpc
    return [run_method(self.driver_worker, method, args, kwargs)]
  File "/home/ProExpertProg/git/vllm/vllm/utils/__init__.py", line 2977, in run_method
    return func(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/vllm/v1/worker/gpu_worker.py", line 280, in determine_available_memory
    self.model_runner.profile_run()
  File "/home/ProExpertProg/git/vllm/vllm/v1/worker/gpu_model_runner.py", line 3703, in profile_run
    hidden_states, last_hidden_states = self._dummy_run(
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/vllm/v1/worker/gpu_model_runner.py", line 3456, in _dummy_run
    outputs = self.model(
  File "/home/ProExpertProg/git/vllm/vllm/compilation/cuda_graph.py", line 126, in __call__
    return self.runnable(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/vllm/model_executor/models/llama.py", line 631, in forward
    model_output = self.model(
  File "/home/ProExpertProg/git/vllm/vllm/compilation/decorators.py", line 407, in __call__
    output = self.compiled_callable(*args, **kwargs)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 990, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 974, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1695, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1505, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2319, in compile_to_module
    return self._compile_to_module()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2325, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2264, in codegen
    self.scheduler.codegen()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 5205, in codegen
    self._codegen_partitions()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 5332, in _codegen_partitions
    partitions, signatures = self.graph_partition()
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 5182, in graph_partition
    should_partition = self.should_partition(node, should_log=True)
  File "/home/ProExpertProg/git/vllm/.venv29/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4663, in should_partition
    assert fx_node is not None#, f"{ir_node=}, {operator=}"
torch._inductor.exc.InductorError: AssertionError: 

This occurs with the operator vllm.unified_attention_with_output, which performs an in-place mutation; as a result, the corresponding Inductor IR node's origin_node is not set properly, so fx_node is None when should_partition asserts. Here's a smaller repro for the PyTorch side:

# Intended to live inside an Inductor TestCase (which provides self.assertEqual);
# also requires: import re; from torch._inductor.utils import run_and_get_code
@torch._inductor.config.patch("graph_partition", True)
@torch._inductor.config.patch("implicit_fallbacks", True)
def test_graph_partition_custom_rule_inplace(self):
    def get_num_partitions(code):
        # Count the partitions listed in the generated wrapper code.
        code = "".join(code)
        found = re.search(r"partitions=\[(.*)\]", code)
        assert found is not None
        partitions = found.group(1)
        num_partitions = len([p for p in partitions.split(",") if p])
        return num_partitions

    x = torch.randn(2, device="cuda")

    # Custom op that mutates its `output` argument in place, mirroring
    # vllm.unified_attention_with_output.
    @torch.library.custom_op("mylib::baz", mutates_args=("output",))
    def baz(x: torch.Tensor, flag: int, output: torch.Tensor) -> None:
        output.copy_(x)

    @baz.register_fake
    def _(x, flag, output):
        return None

    def should_partition(x, flag):
        return flag

    torch._inductor.scheduler.register_should_partition_rule(
        torch.ops.mylib.baz.default, should_partition
    )

    def f(x, flag):
        x = x + 1
        baz(x, flag, x)
        x = x + 1
        return x

    f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
    _, code = run_and_get_code(f_compiled, x, True)
    num_partitions = get_num_partitions(code)
    self.assertEqual(num_partitions, 2)
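
For illustration only, a hedged sketch (not the actual torch._inductor.scheduler code and not a proposed patch) of the kind of defensive lookup in should_partition that would avoid the bare AssertionError when an in-place custom op's IR node carries no origin fx node:

# Hypothetical sketch; the registry, attribute names, and fallback behavior
# here are illustrative assumptions, not the real Inductor internals.
def should_partition_sketch(ir_node, registered_rules):
    fx_node = getattr(ir_node, "origin_node", None)
    if fx_node is None:
        # In-place custom ops (mylib::baz above, or
        # vllm.unified_attention_with_output in the full repro) can reach
        # this point without an origin_node; fall back instead of asserting.
        return False
    rule = registered_rules.get(fx_node.target)
    if rule is None:
        return False
    return bool(rule(*fx_node.args, **fx_node.kwargs))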

cc @ProExpertProg @zou3519 @BoyuanFeng
