Revert "[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass… #14317

Status: Closed · wants to merge 1 commit

13 changes: 4 additions & 9 deletions tests/compile/backend.py
@@ -13,26 +13,21 @@ class TestBackend:
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.

Inductor config can be modified directly by editing the inductor_config
property. This can be helpful for adding passes like the
'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
"""

def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.inductor_config = config.shallow_copy_dict()
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass

def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.inductor_config)
config_patches=self.current_config)

def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
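
For orientation, a minimal sketch of how this backend is typically driven (the dummy pass and toy model are hypothetical placeholders; the actual tests below use RedundantReshapesPass and FusionPass):

import torch
from torch import fx

from .backend import TestBackend  # relative import, as in the test modules below


def dummy_pass(graph: fx.Graph):
    # A custom pass: receives the post-grad fx.Graph and may mutate it in place.
    pass


backend = TestBackend(dummy_pass)
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(4, 16))
# backend.graph_pre_pass and backend.graph_post_pass now hold the graph as
# captured before and after the custom passes, ready for assertions.
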
8 changes: 4 additions & 4 deletions tests/compile/test_functionalization.py
@@ -9,7 +9,7 @@
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig

from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")

config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)

passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
func_pass = FixFunctionalizationPass(config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
127 changes: 58 additions & 69 deletions tests/compile/test_fusion.py
@@ -5,25 +5,23 @@
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
apply_fp8_linear)

from .backend import TestBackend


class TestModel(torch.nn.Module):

def __init__(self, hidden_size: int, eps: float, static: bool,
cutlass_fp8_enabled: bool, *args, **kwargs):
def __init__(self, hidden_size: int, eps: float, static: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
if static:
@@ -43,17 +41,15 @@ def forward(self, x):
self.w[0],
self.wscale[0],
self.scale[0],
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
use_per_token_if_dynamic=True)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = apply_fp8_linear(y2,
self.w[1],
self.wscale[1],
self.scale[1],
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
use_per_token_if_dynamic=True)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

@@ -63,67 +59,60 @@ def forward(self, x):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths

vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)

backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)

# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]

# In pre-nodes, fp8 quant should be there and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)

# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)

backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps, static)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)

# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]

# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)

# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
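
The find_auto_fn / find_auto_fn_maybe helpers imported above look up auto_functionalized nodes that wrap a given op. A rough sketch of that matching idea, assuming the standard torch.ops.higher_order.auto_functionalized wrapper (this is not the actual vllm.compilation.fx_utils implementation):

from typing import Optional

import torch
from torch import fx


def find_auto_fn_maybe_sketch(nodes, op) -> Optional[fx.Node]:
    # Scan for a call to the auto_functionalized higher-order op whose
    # first argument is the target op (e.g. a fused rms_norm + quant kernel).
    for node in nodes:
        if (node.op == "call_function"
                and node.target is torch.ops.higher_order.auto_functionalized
                and node.args[0] is op):
            return node
    return None  # op not present in this graph
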
135 changes: 0 additions & 135 deletions vllm/compilation/noop_elimination.py

This file was deleted.

8 changes: 4 additions & 4 deletions vllm/compilation/pass_manager.py
@@ -11,7 +11,7 @@
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .noop_elimination import NoOpEliminationPass
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)

@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):

The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (NoopEliminationPass, FusionPass)
2. default passes (RedundantReshapesPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ def __call__(self, graph: fx.Graph):

def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]

if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
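
With the reverted flag name (enable_reshape instead of enable_noop), a PassConfig that turns on both options registers the reshape and fusion passes in the order documented above. A hedged sketch of that wiring; the no-argument PostGradPassManager construction is an assumption:

from vllm.config import CompilationConfig
from vllm.compilation.pass_manager import PostGradPassManager

pass_config = CompilationConfig.PassConfig(enable_fusion=True,
                                           enable_reshape=True)
pass_manager = PostGradPassManager()  # constructor arguments assumed
pass_manager.configure(pass_config)
# Per the docstring above, the post-grad passes then run as:
# RedundantReshapesPass, FusionPass, any post_grad_custom_post_pass,
# and finally FixFunctionalizationPass.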