                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.forward_context import get_forward_context
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
@@ -761,73 +762,132 @@ def forward(
                                        "encoder/decoder cross-attention "
                                        "are not implemented for "
                                        "FlashInferImpl")
-        num_tokens, hidden_size = query.shape
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        if attn_metadata.num_prefill_tokens > 0:
-            assert attn_metadata.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if attn_metadata.num_decode_tokens > 0:
-            assert attn_metadata.num_prefill_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
-        if kv_cache.numel() > 0:
-            # Use the same reshape and cache kernel as flash attention.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping.flatten(),
-                self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+        return torch.ops.vllm.unified_flash_infer(
+            query,
+            key,
+            value,
+            self.num_heads,
+            self.head_size,
+            self.num_kv_heads,
+            kv_cache,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+            self.scale,
+            self.sliding_window,
+            self.alibi_slopes,
+            self.logits_soft_cap,
+        )
+
+
+@torch.library.custom_op("vllm::unified_flash_infer",
+                         mutates_args=["kv_cache"])
+def unified_flash_infer(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+
+    current_metadata = get_forward_context()
+    assert current_metadata is not None
+    assert isinstance(current_metadata, FlashInferMetadata)
+    attn_metadata: FlashInferMetadata = current_metadata
+
+    num_tokens, hidden_size = query.shape
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    if attn_metadata.num_prefill_tokens > 0:
+        assert attn_metadata.num_decode_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if attn_metadata.num_decode_tokens > 0:
+        assert attn_metadata.num_prefill_tokens == 0, (
+            "Chunked prefill is not supported with flashinfer yet.")
+    if kv_cache.numel() > 0:
+        # Use the same reshape and cache kernel as flash attention.
+        ops.reshape_and_cache_flash(
+            key,
+            value,
+            kv_cache[:, 0],
+            kv_cache[:, 1],
+            attn_metadata.slot_mapping.flatten(),
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+        # to process the cache when the kv_cache_dtype is fp8
+        if kv_cache_dtype.startswith("fp8"):
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
+
+    query = query.contiguous()  # Flashinfer requires query to be contiguous
+    if prefill_meta := attn_metadata.prefill_metadata:
+        # We will use flash attention for prefill
+        # when kv_cache is not provided.
+        # This happens when vllm runs the profiling to
+        # determine the number of blocks.
+        if kv_cache.numel() == 0:
+            output = flash_attn_varlen_func(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                window_size=window_size,
+                alibi_slopes=alibi_slopes,
            )
-            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype.startswith("fp8"):
-                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                    self.kv_cache_dtype)
-                kv_cache = kv_cache.view(torch_dtype)
-
-        query = query.contiguous(
-        )  # Flashinfer requires query to be contiguous
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # We will use flash attention for prefill
-            # when kv_cache is not provided.
-            # This happens when vllm runs the profiling to
-            # determine the number of blocks.
-            if kv_cache.numel() == 0:
-                output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    window_size=self.sliding_window,
-                    alibi_slopes=self.alibi_slopes,
-                )
-            else:
-                assert prefill_meta is not None
-                assert prefill_meta.prefill_wrapper is not None
-                output = prefill_meta.prefill_wrapper.forward(
-                    query,
-                    kv_cache,
-                    logits_soft_cap=self.logits_soft_cap,
-                    causal=True)
        else:
-            assert attn_metadata.decode_metadata is not None
-            assert attn_metadata.decode_metadata.decode_wrapper is not None
-            output = attn_metadata.decode_metadata.decode_wrapper.forward(
-                query,
-                kv_cache,
-                sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale)
-        return output.view(num_tokens, hidden_size)
+            assert prefill_meta is not None
+            assert prefill_meta.prefill_wrapper is not None
+            output = prefill_meta.prefill_wrapper.forward(
+                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+    else:
+        assert attn_metadata.decode_metadata is not None
+        assert attn_metadata.decode_metadata.decode_wrapper is not None
+        output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            query,
+            kv_cache,
+            sm_scale=softmax_scale,
+            logits_soft_cap=logits_soft_cap,
+            k_scale=k_scale,
+            v_scale=v_scale)
+    return output.view(num_tokens, hidden_size)
+
+
+@unified_flash_infer.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query).contiguous()
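The refactor above follows a general PyTorch 2.4+ pattern: register the attention kernel as a torch.library.custom_op, give it a register_fake shape-only implementation so torch.compile can trace it, and hand non-tensor per-step state (here the FlashInferMetadata) to the op through a forward context rather than the op signature, since custom ops only accept tensors and plain scalars. The sketch below illustrates that pattern in isolation. The demo::scaled_add op, the set_forward_context helper, and the _forward_context variable are illustrative assumptions, not vLLM or PyTorch APIs; only torch.library.custom_op, register_fake, and the get_forward_context name are taken from the diff itself.

# Minimal sketch of the custom-op + forward-context pattern.
# Assumes PyTorch >= 2.4; the "demo" namespace and the context helpers
# below are hypothetical stand-ins, not vLLM code.
from contextlib import contextmanager
from typing import Any, Optional

import torch

_forward_context: Optional[Any] = None  # hypothetical per-step side channel


def get_forward_context() -> Optional[Any]:
    # Return whatever the caller stashed for the current forward pass.
    return _forward_context


@contextmanager
def set_forward_context(ctx: Any):
    # Temporarily expose `ctx` to code running inside the `with` block.
    global _forward_context
    prev, _forward_context = _forward_context, ctx
    try:
        yield
    finally:
        _forward_context = prev


@torch.library.custom_op("demo::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # The eager implementation may read side-channel state that cannot be
    # passed through the op signature (only tensors and scalars are allowed).
    meta = get_forward_context()
    assert meta is not None, "forward context must be set"
    return x + alpha * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Shape/dtype-only implementation used when torch.compile traces the op,
    # mirroring the register_fake stub for unified_flash_infer above.
    return torch.empty_like(x)


if __name__ == "__main__":
    x, y = torch.randn(4), torch.randn(4)
    with set_forward_context({"step": 0}):
        # Invoked through torch.ops, just as the diff calls
        # torch.ops.vllm.unified_flash_infer.
        out = torch.ops.demo.scaled_add(x, y, 2.0)
    print(out)

Fetching the metadata from a context instead of an argument is presumably what keeps the op signature compile-friendly while the attention metadata remains an arbitrary Python object.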