[torch.compile] directly register custom op #9896

Merged (25 commits, Nov 1, 2024)

Changes from 1 commit
flash attn
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao committed Oct 31, 2024
commit eca0f28a05b45794724237ef9263692ecdcd4d36
15 changes: 8 additions & 7 deletions vllm/attention/backends/flash_attn.py
@@ -3,7 +3,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
-from torch.library import Library

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -15,7 +14,8 @@
                                             compute_slot_mapping_start_idx,
                                             is_block_tables_empty)
 from vllm.forward_context import get_forward_context
-from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
+                        make_tensor_with_pad)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -773,9 +773,10 @@ def unified_flash_attention_fake(
     return torch.empty_like(query)


-my_lib = Library("vllm", "FRAGMENT")
-my_lib.define(
-    "unified_flash_attention(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor"  # noqa
+direct_register_custom_op(
+    library_name="vllm",
+    op_name="unified_flash_attention",
+    op_func=unified_flash_attention,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_attention_fake,
 )
-my_lib.impl("unified_flash_attention", unified_flash_attention, "CUDA")
-my_lib._register_fake("unified_flash_attention", unified_flash_attention_fake)
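
For context: the manual registration removed above (a Library("vllm", "FRAGMENT") object, a hand-written schema string, my_lib.impl(..., "CUDA"), and my_lib._register_fake) is exactly what the new direct_register_custom_op helper folds into a single call. The sketch below is an illustration of what such a helper could look like, inferred from the removed code and from the keyword arguments at the call site; it is not the actual implementation in vllm/utils.py, and it assumes a PyTorch version that exposes torch.library.infer_schema (older releases keep an equivalent in torch._custom_op.impl).

from typing import Callable, Dict, List, Optional

import torch
from torch.library import Library

# Keep Library objects alive at module scope; if they were garbage collected,
# the registered ops could be torn down again.
_libraries: Dict[str, Library] = {}


def direct_register_custom_op(
    library_name: str,
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
) -> None:
    """Illustrative sketch: register op_func as {library_name}::{op_name}."""
    # Derive the schema string (the long hand-written signature in the old
    # code) from op_func's type annotations.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    lib = _libraries.setdefault(library_name, Library(library_name, "FRAGMENT"))
    lib.define(op_name + schema)
    # Register the real kernel under the CUDA dispatch key, matching the
    # removed my_lib.impl(..., "CUDA") call.
    lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # The fake (meta) implementation lets torch.compile trace output
        # shapes without launching the real kernel.
        lib._register_fake(op_name, fake_impl)

Once registered, the kernel is invoked as torch.ops.vllm.unified_flash_attention(...), and fake_impl is what torch.compile uses to infer output shapes during tracing; the helper also removes the need to keep the long schema string in sync with the Python signature by hand.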