Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] directly register custom op #9896

Merged
merged 25 commits into from
Nov 1, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
parallel state
Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao committed Oct 31, 2024
commit dc23b25b5174d7b06b6d10b5530d2233ccd29cc6
25 changes: 13 additions & 12 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
from torch.library import Library

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from vllm.utils import supports_custom_op, direct_register_custom_op


@dataclass
Expand Down Expand Up @@ -110,12 +109,13 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
    """Fake (meta) implementation for the ``inplace_all_reduce`` custom op.

    Used only when tracing/compiling (registered via ``fake_impl=`` below):
    the real op mutates ``tensor`` in place and returns nothing, so the
    fake simply returns ``None`` without touching the tensor.
    """
    return

my_lib = Library("vllm", "FRAGMENT")
my_lib.define(
"inplace_all_reduce(Tensor(a0!) tensor, str group_name) -> ()" # noqa
# Register `inplace_all_reduce` as a torch custom op in the "vllm" library.
# `mutates_args=["tensor"]` tells torch.compile the input is modified in
# place; `fake_impl` provides the no-op shape function used during tracing.
direct_register_custom_op(
    library_name="vllm",
    op_name="inplace_all_reduce",
    op_func=inplace_all_reduce,
    mutates_args=["tensor"],
    fake_impl=inplace_all_reduce_fake,
)
my_lib.impl("inplace_all_reduce", inplace_all_reduce, "CUDA")
my_lib._register_fake("inplace_all_reduce", inplace_all_reduce_fake)

def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
def outplace_all_reduce_fake(tensor: torch.Tensor,
                             group_name: str) -> torch.Tensor:
    """Fake (meta) implementation for the ``outplace_all_reduce`` custom op.

    During tracing/compiling no communication happens; we only need an
    output tensor with the same shape/dtype/device as the input, so an
    uninitialized clone-shaped tensor is returned.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder

my_lib = Library("vllm", "FRAGMENT")
my_lib.define(
"outplace_all_reduce(Tensor tensor, str group_name) -> Tensor" # noqa
# Register `outplace_all_reduce` as a torch custom op in the "vllm" library.
# Unlike the in-place variant, this op returns a new tensor and mutates
# nothing (`mutates_args=[]`); `fake_impl` supplies the shape function
# used by torch.compile during tracing.
direct_register_custom_op(
    library_name="vllm",
    op_name="outplace_all_reduce",
    op_func=outplace_all_reduce,
    mutates_args=[],
    fake_impl=outplace_all_reduce_fake,
)
my_lib.impl("outplace_all_reduce", outplace_all_reduce, "CUDA")
my_lib._register_fake("outplace_all_reduce", outplace_all_reduce_fake)


class GroupCoordinator:
Expand Down