Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] directly register custom op #9896

Merged
merged 25 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
hack fix library
Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao committed Nov 1, 2024
commit d515d619a4f2f97b0bdd17d1b8652a4d82629fc4
2 changes: 1 addition & 1 deletion tests/compile/piecewise/piecewise_compilation_config.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"use_cudagraph": true,
"non_cudagraph_ops": ["silly.attention"]
"non_cudagraph_ops": ["vllm.toy_attention"]
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
}
8 changes: 4 additions & 4 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,


direct_register_custom_op(
library_name="silly",
op_name="attention",
library_name="vllm",
op_name="toy_attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
Expand All @@ -57,12 +57,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
x = x + 2
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
torch.ops.vllm.toy_attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
torch.ops.vllm.toy_attention(x, x, x, out)
x = out
x = x + 1
return x
Expand Down
8 changes: 4 additions & 4 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,


direct_register_custom_op(
library_name="silly",
op_name="attention",
library_name="vllm",
op_name="toy_attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
Expand Down Expand Up @@ -103,7 +103,7 @@ def forward(
k = k + positions.unsqueeze(1)

attn_output = torch.empty_like(q)
torch.ops.silly.attention(q, k, v, attn_output)
torch.ops.vllm.toy_attention(q, k, v, attn_output)

output = self.output_projection(attn_output)
return output
Expand Down Expand Up @@ -179,7 +179,7 @@ def run_model(llama_config,
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
non_cudagraph_ops=["vllm.toy_attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
Expand Down
7 changes: 6 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,9 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")


vllm_lib = Library("vllm", "FRAGMENT")


def direct_register_custom_op(
library_name: str,
op_name: str,
Expand All @@ -1530,7 +1533,9 @@ def direct_register_custom_op(
for more details.
"""
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
my_lib = Library(library_name, "FRAGMENT")
# FIXME after https://github.com/pytorch/pytorch/issues/139444 is resolved
assert library_name == "vllm"
my_lib = vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
Expand Down