[torch.compile] directly register custom op #9896

Merged · 25 commits · Nov 1, 2024
Changes from 1 commit
fused moe
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao committed Oct 31, 2024
commit 3c5e7dd6aa6fa571d8059a4c153c8b34ab537342
24 changes: 13 additions & 11 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -7,12 +7,12 @@
 import torch
 import triton
 import triton.language as tl
-from torch.library import Library

 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)

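Note: the diff only shows call sites of direct_register_custom_op; its definition (in vllm/utils.py per the new import) is not part of this file. A minimal sketch of what a helper with this signature could look like, assuming PyTorch 2.5+ where torch.library.infer_schema is public; the actual implementation may differ:

from typing import Callable, Dict, List, Optional

import torch
from torch.library import Library

# keep Library objects alive for the lifetime of the process so the
# registrations are not torn down when a local variable is collected
_LIBRARIES: Dict[str, Library] = {}


def direct_register_custom_op(
    library_name: str,
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
) -> None:
    # derive the schema string, e.g. "(Tensor(a0!) x, ...) -> ()",
    # from the Python type hints instead of writing it by hand
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    lib = _LIBRARIES.setdefault(library_name, Library(library_name, "FRAGMENT"))
    lib.define(op_name + schema)
    lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # shape/dtype propagation rule used when tracing with fake tensors
        lib._register_fake(op_name, fake_impl)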
@@ -498,12 +498,13 @@ def inplace_fused_experts_fake(
     pass


-my_lib = Library("vllm", "FRAGMENT")
-my_lib.define(
-    "inplace_fused_experts(Tensor(a0!) hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> ()" # noqa
-)
-my_lib.impl("inplace_fused_experts", inplace_fused_experts, "CUDA")
-my_lib._register_fake("inplace_fused_experts", inplace_fused_experts_fake)
+direct_register_custom_op(
+    library_name="vllm",
+    op_name="inplace_fused_experts",
+    op_func=inplace_fused_experts,
+    mutates_args=["hidden_states"],
+    fake_impl=inplace_fused_experts_fake,
+)


 def outplace_fused_experts(
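For context, not part of this diff: once registered, the op is reached through the torch.ops namespace, and torch.compile traces it as an opaque custom op using the fake impl rather than the Triton kernel. A hypothetical call site with illustrative placeholder shapes (assumes a CUDA device and the fused MoE weight convention, w1: [E, 2N, K], w2: [E, K, N]):

import torch

# illustrative sizes: M tokens, K hidden size, E experts,
# N intermediate size, top-k=2 routing
M, K, E, N, topk = 16, 128, 8, 256, 2
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.float16)
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.float16)
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.float16)
topk_weights = torch.rand(M, topk, device="cuda")
topk_ids = torch.randint(0, E, (M, topk), device="cuda", dtype=torch.int32)

@torch.compile
def run(hs: torch.Tensor) -> torch.Tensor:
    # traced as an opaque custom op; the fake impl supplies the
    # (mutation-only, no return) semantics during compilation
    torch.ops.vllm.inplace_fused_experts(hs, w1, w2, topk_weights, topk_ids)
    return hs  # mutated in place, per mutates_args=["hidden_states"]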
@@ -538,12 +539,13 @@ def outplace_fused_experts_fake(
     return torch.empty_like(hidden_states)


-my_lib = Library("vllm", "FRAGMENT")
-my_lib.define(
-    "outplace_fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> Tensor" # noqa
-)
-my_lib.impl("outplace_fused_experts", outplace_fused_experts, "CUDA")
-my_lib._register_fake("outplace_fused_experts", outplace_fused_experts_fake)
+direct_register_custom_op(
+    library_name="vllm",
+    op_name="outplace_fused_experts",
+    op_func=outplace_fused_experts,
+    mutates_args=[],
+    fake_impl=outplace_fused_experts_fake,
+)


 def fused_experts(hidden_states: torch.Tensor,
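The two registrations differ only in mutation semantics: the inplace op declares mutates_args=["hidden_states"], which produces the Tensor(a0!) alias annotation and () return that the deleted code spelled out by hand, while the outplace op mutates nothing and returns a fresh tensor. A small standalone illustration of that schema inference, again assuming PyTorch 2.5+ where torch.library.infer_schema is public:

import torch


def inplace_op(x: torch.Tensor) -> None:
    x.add_(1)


def outplace_op(x: torch.Tensor) -> torch.Tensor:
    return x + 1


# mutated arguments get the (a0!) alias annotation, mirroring the
# handwritten "inplace_fused_experts(Tensor(a0!) hidden_states, ...) -> ()"
print(torch.library.infer_schema(inplace_op, mutates_args=["x"]))
# (Tensor(a0!) x) -> ()

print(torch.library.infer_schema(outplace_op, mutates_args=[]))
# (Tensor x) -> Tensor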