[Fix][torch.compile] Enable custom ops by default when Inductor off #20102

Merged · 2 commits · Jun 27, 2025
45 changes: 26 additions & 19 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
-        ("", 0, [True] * 4, True),
-        ("", 1, [True] * 4, True),
-        ("", 2, [True] * 4, True),  # All by default
-        ("", 3, [False] * 4, False),
-        ("", 4, [False] * 4, False),  # None by default
+        # - All by default (no Inductor compilation)
+        ("", 0, False, [True] * 4, True),
+        ("", 1, True, [True] * 4, True),
+        ("", 2, False, [True] * 4, True),
+        # - None by default (with Inductor)
+        ("", 3, True, [False] * 4, False),
+        ("", 4, True, [False] * 4, False),
+        # - All by default (without Inductor)
+        ("", 3, False, [True] * 4, True),
+        ("", 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
-        # GeluAndMul and SiluAndMul
-        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+        # RMSNorm and SiluAndMul
+        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 2, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ])
-def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
-                     default_on: bool):
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=torch_level, custom_ops=env.split(",")))
+def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+                     ops_enabled: list[int], default_on: bool):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
+                                             level=torch_level,
+                                             custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on

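For readers skimming the test change above, here is a minimal standalone sketch of the behavior it pins down. It is not part of the PR; the import paths are assumed from the modules touched in this diff.

```python
# Hedged sketch: exercising the new default outside pytest. Import paths are
# assumed from the files touched in this PR (vllm/config.py and
# vllm/model_executor/custom_op.py).
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp

# PIECEWISE-level compilation, but with Inductor turned off: after this PR,
# custom ops stay enabled by default instead of being silently disabled.
config = VllmConfig(compilation_config=CompilationConfig(
    level=3, use_inductor=False, custom_ops="".split(",")))

with set_current_vllm_config(config):
    assert CustomOp.default_on()  # True: no Inductor, so custom ops default on
```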
27 changes: 9 additions & 18 deletions vllm/config.py
@@ -3929,7 +3929,8 @@ class CompilationConfig:
     - 'none,+op1,+op2' to enable only op1 and op2
 
     By default, all custom ops are enabled when running without Inductor and
-    disabled when running with Inductor (compile_level >= Inductor)."""
+    disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
+    Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: list[str] = field(default_factory=list)
     """A list of ops to split the full graph into subgraphs, used in piecewise
     compilation."""
@@ -3938,10 +3939,13 @@
     use_inductor: bool = True
     """Whether to use inductor compilation:
 
-    - False: inductor compilation is not used. graph runs in eager.
-    - True: inductor compilation is used. one graph for symbolic shape
-      is compiled. In addition, compile for compile_sizes,
-      using configurations in inductor_compile_config."""
+    - False: inductor compilation is not used. graph runs in eager
+      (custom_ops enabled by default).
+    - True: inductor compilation is used (custom_ops disabled by default).
+      One graph for symbolic shape and one graph per size in compile_sizes
+      are compiled using configurations in inductor_compile_config.
+
+    This setting is ignored if level<PIECEWISE."""
     compile_sizes: Optional[list[Union[int, str]]] = None
     """Sizes to compile for inductor. In addition
     to integers, it also supports "cudagraph_capture_sizes" to
@@ -4469,19 +4473,6 @@ def __post_init__(self):
             self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()
 
-        # The behavior of custom ops with inductor depends on the config:
-        # - If use_inductor=True and custom_ops is empty:
-        #   Inductor generates Triton kernels for all registered custom ops
-        #   (default behavior)
-        # - If use_inductor=True and custom_ops is non-empty:
-        #   Custom CUDA kernels are used for specified ops while inductor
-        #   generates Triton kernels for remaining ops, including misc torch
-        #   ops in the model.
-        if (not self.compilation_config.custom_ops
-                and self.compilation_config.use_inductor):
-            # Let inductor generate Triton kernels for the custom ops.
-            self.compilation_config.custom_ops = ["none"]
-
         self._set_cudagraph_sizes()
 
         if self.cache_config.cpu_offload_gb > 0 and \
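The updated docstrings above describe how `custom_ops` interacts with `use_inductor`. A small illustrative config (not taken from the PR; values are assumptions for the example, field names follow the docstrings) might look like:

```python
# Illustrative only: selectively re-enabling one custom op while Inductor
# generates Triton kernels for the rest.
from vllm.config import CompilationConfig, CompilationLevel

cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,  # Inductor path: custom ops off by default
    use_inductor=True,
    custom_ops=["+rms_norm", "-silu_and_mul"],  # keep RMSNorm custom, disable SiluAndMul
)

# With use_inductor=False (or level < PIECEWISE), the same config would leave
# all custom ops enabled by default, since this PR stops forcing
# custom_ops=["none"] in __post_init__.
```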
12 changes: 6 additions & 6 deletions vllm/model_executor/custom_op.py
@@ -141,16 +141,16 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        On by default if level < CompilationLevel.PIECEWISE
+        On by default if PyTorch Inductor is not used.
         Specifying 'all' or 'none' in custom_op takes precedence.
         """
         from vllm.config import CompilationLevel
         compilation_config = get_current_vllm_config().compilation_config
-        custom_ops = compilation_config.custom_ops
-        count_none = custom_ops.count("none")
-        count_all = custom_ops.count("all")
-        return compilation_config.level < CompilationLevel.PIECEWISE and \
-            not count_none > 0 or count_all > 0
+        default_on = (compilation_config.level < CompilationLevel.PIECEWISE
+                      or not compilation_config.use_inductor)
+        count_none = compilation_config.custom_ops.count("none")
+        count_all = compilation_config.custom_ops.count("all")
+        return default_on and not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.
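The one-line `return` above relies on Python operator precedence. The following standalone sketch (not vLLM code; `PIECEWISE` assumed to be 3, consistent with the test table) spells out the rule and replays a few rows from the updated test:

```python
# Standalone sketch of the rule in default_on() above (not vLLM code).
# Assumption: CompilationLevel.PIECEWISE == 3, consistent with the test table.
PIECEWISE = 3

def default_on(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
    # Custom ops default to on unless Inductor is actually used
    # (level >= PIECEWISE and use_inductor=True).
    default = level < PIECEWISE or not use_inductor
    # Precedence: (default and no "none" entry) or ("all" entry) -- an explicit
    # "all"/"none" overrides the default, as the docstring states.
    return default and "none" not in custom_ops or "all" in custom_ops

# Replaying a few rows of the updated test table:
assert default_on(2, False, "".split(",")) is True    # no Inductor -> on
assert default_on(3, True, "".split(",")) is False    # Inductor -> off
assert default_on(3, False, "".split(",")) is True    # Inductor disabled -> on
assert default_on(4, True, "all,-rms_norm".split(",")) is True            # "all" wins
assert default_on(1, False, "none,-rms_norm,+relu3".split(",")) is False  # "none" wins
```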