
Commit 7ab7621

youkaichao authored and LeiWang1999 committed
[core] use forward context for flash infer (vllm-project#9097)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent e6054f9 commit 7ab7621
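
This change replaces the body of FlashInferImpl.forward with a call to a new custom op, torch.ops.vllm.unified_flash_infer. Because the op's schema cannot carry the FlashInferMetadata object as an argument, the op recovers it from the ambient forward context via get_forward_context(). The sketch below shows the forward-context pattern the op relies on; it is a minimal illustration with assumed internals (set_forward_context and the module-level variable), not the actual contents of vllm/forward_context.py, which is not part of this diff.

# Minimal sketch of the forward-context pattern (assumed internals,
# for illustration only; the real module is vllm/forward_context.py).
from contextlib import contextmanager
from typing import Any, Optional

_forward_context: Optional[Any] = None  # metadata for the current forward pass


def get_forward_context() -> Optional[Any]:
    # Read whatever the model runner stashed for the current forward pass.
    return _forward_context


@contextmanager
def set_forward_context(context: Any):
    # Install `context` for the duration of one model forward pass.
    global _forward_context
    prev = _forward_context
    _forward_context = context
    try:
        yield
    finally:
        _forward_context = prev

With such a helper, the model runner would wrap each forward pass in set_forward_context(attn_metadata), and unified_flash_infer can fetch the metadata without it appearing in the custom op's signature.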

File tree

1 file changed: +127 -67 lines changed

1 file changed

+127
-67
lines changed

vllm/attention/backends/flashinfer.py

Lines changed: 127 additions & 67 deletions
@@ -26,6 +26,7 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.forward_context import get_forward_context
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
@@ -761,73 +762,132 @@ def forward(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashInferImpl")
-        num_tokens, hidden_size = query.shape
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        if attn_metadata.num_prefill_tokens > 0:
-            assert attn_metadata.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if attn_metadata.num_decode_tokens > 0:
-            assert attn_metadata.num_prefill_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if kv_cache.numel() > 0:
-            # Use the same reshape and cache kernel as flash attention.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping.flatten(),
-                self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+        return torch.ops.vllm.unified_flash_infer(
+            query,
+            key,
+            value,
+            self.num_heads,
+            self.head_size,
+            self.num_kv_heads,
+            kv_cache,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+            self.scale,
+            self.sliding_window,
+            self.alibi_slopes,
+            self.logits_soft_cap,
+        )
+
+
+@torch.library.custom_op("vllm::unified_flash_infer",
+                         mutates_args=["kv_cache"])
+def unified_flash_infer(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+
+    current_metadata = get_forward_context()
+    assert current_metadata is not None
+    assert isinstance(current_metadata, FlashInferMetadata)
+    attn_metadata: FlashInferMetadata = current_metadata
+
+    num_tokens, hidden_size = query.shape
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    if attn_metadata.num_prefill_tokens > 0:
+        assert attn_metadata.num_decode_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if attn_metadata.num_decode_tokens > 0:
+        assert attn_metadata.num_prefill_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if kv_cache.numel() > 0:
+        # Use the same reshape and cache kernel as flash attention.
+        ops.reshape_and_cache_flash(
+            key,
+            value,
+            kv_cache[:, 0],
+            kv_cache[:, 1],
+            attn_metadata.slot_mapping.flatten(),
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+        # to process the cache when the kv_cache_dtype is fp8
+        if kv_cache_dtype.startswith("fp8"):
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
+
+    query = query.contiguous()  # Flashinfer requires query to be contiguous
+    if prefill_meta := attn_metadata.prefill_metadata:
+        # We will use flash attention for prefill
+        # when kv_cache is not provided.
+        # This happens when vllm runs the profiling to
+        # determine the number of blocks.
+        if kv_cache.numel() == 0:
+            output = flash_attn_varlen_func(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                window_size=window_size,
+                alibi_slopes=alibi_slopes,
             )
-            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype.startswith("fp8"):
-                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                    self.kv_cache_dtype)
-                kv_cache = kv_cache.view(torch_dtype)
-
-        query = query.contiguous(
-        )  # Flashinfer requires query to be contiguous
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # We will use flash attention for prefill
-            # when kv_cache is not provided.
-            # This happens when vllm runs the profiling to
-            # determine the number of blocks.
-            if kv_cache.numel() == 0:
-                output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    window_size=self.sliding_window,
-                    alibi_slopes=self.alibi_slopes,
-                )
-            else:
-                assert prefill_meta is not None
-                assert prefill_meta.prefill_wrapper is not None
-                output = prefill_meta.prefill_wrapper.forward(
-                    query,
-                    kv_cache,
-                    logits_soft_cap=self.logits_soft_cap,
-                    causal=True)
         else:
-            assert attn_metadata.decode_metadata is not None
-            assert attn_metadata.decode_metadata.decode_wrapper is not None
-            output = attn_metadata.decode_metadata.decode_wrapper.forward(
-                query,
-                kv_cache,
-                sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale)
-        return output.view(num_tokens, hidden_size)
+            assert prefill_meta is not None
+            assert prefill_meta.prefill_wrapper is not None
+            output = prefill_meta.prefill_wrapper.forward(
+                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+    else:
+        assert attn_metadata.decode_metadata is not None
+        assert attn_metadata.decode_metadata.decode_wrapper is not None
+        output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            query,
+            kv_cache,
+            sm_scale=softmax_scale,
+            logits_soft_cap=logits_soft_cap,
+            k_scale=k_scale,
+            v_scale=v_scale)
+    return output.view(num_tokens, hidden_size)
+
+
+@unified_flash_infer.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query).contiguous()
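
The new op is registered with torch.library.custom_op, declaring that it mutates kv_cache, and a shape-only fake implementation is attached via register_fake so that fake-tensor tracing (for example under torch.compile) can propagate shapes and dtypes without running FlashInfer kernels; that is why the fake simply returns torch.empty_like(query).contiguous(). The toy example below illustrates the same registration pattern with made-up names (it is not vllm code) and assumes PyTorch 2.4 or newer, where torch.library.custom_op and register_fake are available.

import torch


# Hypothetical op, shown only to illustrate the custom_op / register_fake pattern.
@torch.library.custom_op("demo::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs on actual tensors.
    return x + scale * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output's shape and dtype.
    return torch.empty_like(x)


out = torch.ops.demo.scaled_add(torch.ones(4), torch.ones(4), 2.0)
print(out)  # tensor([3., 3., 3., 3.])

A likely motivation for wrapping the attention call in a custom op like this is that the surrounding model code can be traced or compiled while the op itself stays opaque, with the attention metadata supplied out of band through the forward context rather than through the op's schema.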
