31 changes: 3 additions & 28 deletions tests/compile/piecewise/test_multiple_graphs.py
@@ -6,47 +6,26 @@
"""
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops

BATCH_SIZE = 32
MLP_SIZE = 128
HIDDEN_SIZE = 1024
RANDOM_SEED = 0


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)


@support_torch_compile
class ParentModel(nn.Module):

@@ -277,9 +256,5 @@ def test_multi_graph_piecewise_compile_outputs_equal():
        outputs.append(
            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))

    # Generally don't expect outputs with and without inductor
    # to be bitwise equivalent
    assert torch.allclose(outputs[0], outputs[1])

    # Expect bitwise equivalence using inductor w/ and w/o cudagraph
    assert torch.equal(outputs[0], outputs[2])
43 changes: 8 additions & 35 deletions tests/compile/piecewise/test_simple.py
@@ -7,42 +7,17 @@
import pytest
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

global_counter = 0

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
    global_counter += 1
    print(f"{global_counter=}")
    out.copy_(q)
    out[0] += 1


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)
# This import also automatically registers torch ops for testing (like silly.attention)
from tests.compile.testing_ops import (
    get_global_counter, reset_global_counter
)


@@ -58,9 +33,8 @@ def __init__(self,

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        Overall effect with unified attention implementation:
        input [0., 0.] -> final output [19., 19.]
        global_counter += 2
        """
        x = x + 1
@@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
        model(torch.randn(1).cuda())

    input = torch.zeros(2).cuda()
    global global_counter
    global_counter = 0
    reset_global_counter()
    with set_forward_context(
            None,
            vllm_config=vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            batch_descriptor=BatchDescriptor(num_tokens=2, )):
        output = model(input)
    assert global_counter == 2
    assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
    assert get_global_counter() == 2
    assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))
27 changes: 2 additions & 25 deletions tests/compile/piecewise/test_toy_llama.py
@@ -14,38 +14,15 @@
import pytest
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)
# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops


@dataclass
28 changes: 2 additions & 26 deletions tests/compile/test_decorator.py
@@ -2,44 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from torch.library import Library

# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                         CUDAGraphMode, VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa

BATCH_SIZE = 32
MLP_SIZE = 128


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
              cudagraph_runtime_mode: CUDAGraphMode):
62 changes: 62 additions & 0 deletions tests/compile/testing_ops.py
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared PyTorch custom operations for compilation tests.

Centralizes custom operation definitions to avoid duplicate registrations.
"""

import torch
from torch.library import Library

from vllm.utils import direct_register_custom_op

# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
silly_lib = Library("silly", "FRAGMENT")


# Global counter that counts the number of times attention is invoked
_global_counter = 0


def get_global_counter():
    """Get the current global counter value"""
    return _global_counter


def reset_global_counter():
    """Reset the global counter to 0"""
    global _global_counter
    _global_counter = 0


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """
    Unified attention implementation that depends on all inputs and
    affects the output. Always increments a global counter that tests
    can use or ignore.
    """
    global _global_counter

    # Always increment the global counter
    _global_counter += 1

    # Unified implementation that depends on all inputs
    out.copy_(q + k + v)


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation for testing"""
    return


# Register the unified attention operation
direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)
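
Usage sketch: a minimal example of how a test can exercise the shared op once tests.compile.testing_ops has been imported. It assumes a CUDA device is available, since direct_register_custom_op registers the implementation under the CUDA dispatch key by default; the shapes and values are illustrative, not taken from the tests above.

# Importing the module runs direct_register_custom_op as a side effect,
# which exposes the op as torch.ops.silly.attention.
import torch

import tests.compile.testing_ops as testing_ops

q = torch.full((4,), 1.0, device="cuda")
k = torch.full((4,), 2.0, device="cuda")
v = torch.full((4,), 3.0, device="cuda")
out = torch.empty(4, device="cuda")

testing_ops.reset_global_counter()
torch.ops.silly.attention(q, k, v, out)  # mutates `out` in place

assert torch.equal(out.cpu(), torch.full((4,), 6.0))  # out = q + k + v
assert testing_ops.get_global_counter() == 1  # one invocation recorded

Because every test file now shares this single registration, re-importing the module from multiple tests no longer risks the duplicate-registration errors that per-file Library("silly", "FRAGMENT") instances could cause.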