Add FlashInfer support #2354

Merged (1 commit, Aug 9, 2024)
Add FlashInfer support
This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:

* A wrapper class needs to be constructed (which we just call *state*).
  Since this is fairly expensive (due to pinned host memory allocation),
  we only do this once per `FlashCausalLM` instance and once for each CUDA
  Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
  `end_forward`. This sets up data structures that can be reused for all
  calls to attention for that forward call.

When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.

Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since, for example,
CUDA Graphs of different sizes each have their own state. A minimal sketch
of the pattern follows below.
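To make the bookkeeping concrete, here is a minimal, self-contained sketch of the pattern (illustrative only: `FakeState` stands in for a FlashInfer wrapper object, and the real helpers live in the `flash_infer.py` module added by this change):

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Context variable holding the active state for the current forward call.
decode_state: ContextVar = ContextVar("decode_state")


class FakeState:
    """Stand-in for a FlashInfer wrapper such as BatchDecodeWithPagedKVCacheWrapper."""

    def begin_forward(self, **kwargs):
        print("begin_forward", kwargs)

    def end_forward(self):
        print("end_forward")

    def forward(self, query):
        return f"attention({query})"


@contextmanager
def use_decode_state(state, **begin_forward_kwargs):
    # 1. Set the context variable to this forward call's state.
    token = decode_state.set(state)
    try:
        # 2. Set up data structures reused by every attention call in this forward.
        state.begin_forward(**begin_forward_kwargs)
        # 3. Run the wrapped model forward call.
        yield
    finally:
        # 4./5. Tear down and reset the context variable.
        state.end_forward()
        decode_state.reset(token)


def attention(query):
    # Attention fetches the state from the context variable instead of
    # receiving it as an argument threaded through every model.
    return decode_state.get().forward(query)


with use_decode_state(FakeState(), page_size=16):
    print(attention("q"))
```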
danieldk committed Aug 2, 2024
commit c9916107b73c87d2fec73162fd78025b43faf6fe
4 changes: 2 additions & 2 deletions server/text_generation_server/layers/attention/common.py
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
 import torch
 from typing import Optional
 
 
-if FLASH_DECODING:
+if FLASH_DECODING or FLASH_INFER:
 
     @dataclass
     class Seqlen:
46 changes: 42 additions & 4 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,10 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    FLASH_DECODING,
+    BLOCK_SIZE,
+    FLASH_INFER,
+)
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
 
@@ -23,7 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
+    if FLASH_DECODING or FLASH_INFER:
         shape = key_cache.shape
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
@@ -72,7 +76,16 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
+    if FLASH_INFER:
+        from text_generation_server.layers.attention.flash_infer import decode_state
+
+        return decode_state.get().forward(
+            query.contiguous(),
+            paged_kv_cache=(key_cache, value_cache),
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+    elif FLASH_DECODING:
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
@@ -202,7 +215,32 @@ def paged_attention(
 
 SUPPORTS_WINDOWING = V2
 
-if V2:
+if FLASH_INFER:
+
+    def attention(
+        q,
+        k,
+        v,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+        causal=True,
+        softcap=0.0,
+    ):
+        from text_generation_server.layers.attention.flash_infer import prefill_state
+
+        return prefill_state.get().forward(
+            q,
+            k,
+            v,
+            causal=causal,
+            window_left=window_size_left,
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+
+elif V2:
 
     def attention(
         q,
164 changes: 164 additions & 0 deletions server/text_generation_server/layers/attention/flash_infer.py
@@ -0,0 +1,164 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
    if workspace is None:
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_state(
    *,
    device: torch.device,
):
    """Create a prefill state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.
    """

    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_state.reset(token)


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Create a decode state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Create a decode state for use with CUDA Graphs. `block_tables`,
    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
    therefore stored as part of the state.
    """
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.
    """
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1
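    # Worked example (illustrative numbers, not part of this diff): with
    # input_lengths = [17, 5] and page_size = 16, the sequences need
    # ceil(17 / 16) = 2 and ceil(5 / 16) = 1 pages, so indptr becomes
    # [0, 2, 3]; their last pages hold (17 - 1) % 16 + 1 = 1 and
    # (5 - 1) % 16 + 1 = 5 tokens, so last_page_len becomes [1, 5].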

    token = decode_state.set(state)

    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            decode_state.reset(token)
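For orientation, here is a hedged sketch of how these helpers might be wired around a decode forward pass. The call site below is an assumption for illustration only; the actual integration lives in `FlashCausalLM` and is not part of this file, and the head counts, page size, and sequence lengths are made up:

```python
import torch

from text_generation_server.layers.attention.flash_infer import (
    create_decode_state,
    use_decode_state,
)

# Illustrative values; the real ones come from the model config and the batch.
device = torch.device("cuda")
num_heads, num_kv_heads, head_size, page_size = 32, 8, 128, 16

# Two sequences with 17 and 5 cached tokens need 2 + 1 = 3 KV-cache pages.
input_lengths = torch.tensor([17, 5], dtype=torch.int32, device=device)
block_tables = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)

state = create_decode_state(
    device=device, num_heads=num_heads, num_kv_heads=num_kv_heads
)

with use_decode_state(
    state=state,
    input_lengths=input_lengths,
    block_tables=block_tables,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_size=head_size,
    page_size=page_size,
):
    # Every paged_attention call issued inside this block picks up `state`
    # through the decode_state context variable; the model forward would
    # run here in FlashCausalLM.
    pass
```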