
[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. #1940

Merged
merged 16 commits on Jul 1, 2024
Factoring cu_seqlen_q/k for better abstraction over every model.
Narsil committed Jul 1, 2024
commit 4b1364da9204f948fdfb4ad68fbb40a17c68e345
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/attention/__init__.py
@@ -1,6 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
import os

from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
31 changes: 31 additions & 0 deletions server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]

    def __init__(self, input_lengths):
        self.input_lengths = input_lengths
        if FLASH_DECODING:
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
            cu_seqlen_k[0] = 0
            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
        else:
            self.cu_seqlen_q = None
            self.cu_seqlen_k = None
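
For intuition, here is what those tensors hold for a small decode batch when FLASH_DECODING is enabled. This is a self-contained sketch with a toy input_lengths tensor on CPU (the real tensors live on the GPU); the cu_seqlen_k construction uses the equivalent torch.cat/cumsum form that the models previously inlined.

import torch

# Toy decode batch: three sequences with 3, 5 and 2 tokens of KV cache each.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# At decode time each sequence contributes exactly one query token, so the
# cumulative query offsets are simply 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Cumulative key/value lengths with a leading zero.
cu_seqlen_k = torch.cat(
    [torch.zeros((1,), dtype=input_lengths.dtype), input_lengths.cumsum(dim=-1)]
).to(dtype=torch.int32)

print(cu_seqlen_q)  # [0, 1, 2, 3]
print(cu_seqlen_k)  # [0, 3, 8, 10]
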
10 changes: 5 additions & 5 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,7 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@@ -40,8 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -66,7 +66,6 @@ def paged_attention(
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k

# NOTE(woosuk): We use a simple heuristic to decide whether to use

(Reviewer comment: NIT: this comment should move down to the paged attention version condition.)

# PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -88,8 +87,8 @@ def paged_attention(
key_cache,
value_cache,
None,
cu_seqlen_q,
cu_seqlen_k,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
None,
@@ -106,6 +105,7 @@ def paged_attention(
)
return out2[0]
else:
input_lengths = seqlen.input_lengths
from vllm._C import ops

use_v1 = max_s <= 8192 and (
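
Putting the cuda.py changes together, the decode-time dispatch now has roughly the shape sketched below. This is a condensed, hypothetical sketch rather than the real function: flash_decode_stub and vllm_paged_attention_stub are placeholders standing in for flash_attn_cuda.fwd and the vllm._C ops, the real code branches on the FLASH_DECODING flag (which is also what decides whether Seqlen populated the cumulative offsets), and it passes many more arguments (KV cache, block tables, softmax scale, partitioning buffers).

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    # Same fields as layers/attention/common.py above.
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]


def flash_decode_stub(query, cu_seqlen_q, cu_seqlen_k, max_s):
    # Placeholder for the varlen flash-attention kernel; returns the query
    # unchanged so the sketch runs without the real extension.
    return query


def vllm_paged_attention_stub(query, input_lengths, max_s):
    # Placeholder for the vllm._C paged_attention_v1/v2 ops.
    return query


def paged_attention_sketch(query: torch.Tensor, seqlen: Seqlen, max_s: int) -> torch.Tensor:
    if seqlen.cu_seqlen_q is not None:
        # Flash-decoding path: the kernel consumes the cumulative query/key
        # offsets that Seqlen precomputed once per batch.
        return flash_decode_stub(query, seqlen.cu_seqlen_q, seqlen.cu_seqlen_k, max_s)
    # Legacy path: the vLLM paged-attention kernels only need the raw
    # per-sequence lengths, now read from seqlen.input_lengths.
    return vllm_paged_attention_stub(query, seqlen.input_lengths, max_s)
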
(file name not shown)
@@ -260,8 +260,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
):
@@ -314,8 +313,7 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
)

@@ -389,8 +387,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
):
@@ -404,8 +401,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
)
@@ -469,23 +465,6 @@ def forward(
)

residual = None
if cu_seqlen_prefill is None and FLASH_DECODING:
cu_seqlen_q = torch.arange(
input_lengths.shape[0] + 1,
device=input_ids.device,
dtype=torch.int32,
)
cu_seqlen_k = torch.cat(
[
torch.zeros(
(1,), device=input_lengths.device, dtype=input_lengths.dtype
),
input_lengths.cumsum(dim=-1),
]
).to(dtype=torch.int32)
else:
cu_seqlen_q = None
cu_seqlen_k = input_lengths

for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
@@ -496,8 +475,7 @@
cu_seqlen_prefill,
kv_cache[i],
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
)
(file name not shown)
@@ -344,7 +344,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -253,7 +253,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -253,7 +253,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -173,8 +173,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
):
@@ -218,8 +217,7 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
)

@@ -356,8 +354,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
):
@@ -372,8 +369,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
)
@@ -443,23 +439,6 @@ def forward(
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
if cu_seqlen_prefill is None and FLASH_DECODING:
cu_seqlen_q = torch.arange(
input_lengths.shape[0] + 1,
device=inputs_embeds.device,
dtype=torch.int32,
)
cu_seqlen_k = torch.cat(
[
torch.zeros(
(1,), device=input_lengths.device, dtype=input_lengths.dtype
),
input_lengths.cumsum(dim=-1),
]
).to(dtype=torch.int32)
else:
cu_seqlen_q = None
cu_seqlen_k = input_lengths

residual = None
for i, layer in enumerate(self.layers):
@@ -472,8 +451,7 @@
kv_cache[i],
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
)
(file name not shown)
@@ -237,7 +237,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -299,7 +299,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -176,7 +176,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -215,7 +215,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -157,7 +157,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -225,7 +225,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
@@ -349,7 +348,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -309,7 +309,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
(file name not shown)
@@ -263,7 +263,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
10 changes: 6 additions & 4 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -31,10 +31,12 @@
from text_generation_server.models.globals import (
MEM_POOL,
FLASH_DECODING,
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
MODEL_ID,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -47,9 +49,6 @@

tracer = trace.get_tracer(__name__)

BLOCK_SIZE: int = (
256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
@@ -927,6 +926,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
"slots": slots,
"input_lengths": input_lengths,
}
input_lengths = Seqlen(input_lengths=input_lengths)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph

@@ -1086,6 +1086,7 @@ def tunableop_warmup(self, seqlen: int):

# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
seqlen = Seqlen(input_lengths=input_lengths)

# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@@ -1096,7 +1097,7 @@ def tunableop_warmup(self, seqlen: int):
),
kv_cache=self.kv_cache,
block_tables=None,
input_lengths=input_lengths,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
@@ -1172,6 +1173,7 @@ def forward(
cuda_graph = None

if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
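
On the runtime side, flash_causal_lm.py now builds the wrapper right before calling the model, while the cuda_graph_warmup hunk above keeps the raw input_lengths tensor in its static buffers and wraps it only for the captured call. Below is a minimal sketch of that call-site pattern, assuming the text_generation_server package is importable and using a stand-in model_forward in place of self.model.forward.

import torch

from text_generation_server.layers.attention import Seqlen


def model_forward(input_ids, seqlen):
    # Stand-in for self.model.forward; the real call takes many more arguments
    # (position_ids, cu_seqlen_prefill, kv_cache, block_tables, slots, max_s, ...).
    return torch.zeros((input_ids.shape[0], 32000))  # dummy logits


input_ids = torch.tensor([11, 22, 33], dtype=torch.int64)
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Prefill, or decode without a captured CUDA graph: wrap the raw lengths once,
# mirroring `input_lengths = Seqlen(input_lengths=input_lengths)` above.
seqlen = Seqlen(input_lengths=input_lengths)
logits = model_forward(input_ids, seqlen=seqlen)
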