[WIP][1/N] Chunked Prefill #3106

Closed
wants to merge 38 commits into from
Changes from 1 commit
38 commits
06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
9442e8f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 8, 2024
3da31eb
Merge branch '1dquery' into chunked-prefill-3
rkooo567 Mar 8, 2024
Viewing changes from commit 769b2b491939f9e461486fecd2ac97e80e15eb0c ("working")
rkooo567 committed Mar 7, 2024
9 changes: 9 additions & 0 deletions vllm/model_executor/input_metadata.py
@@ -45,6 +45,13 @@ def __init__(
self.start_loc = start_loc
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
# Index: The batched sequence's index.
# Value: The length of the attention context.
# NOTE(sang): The definition differs between prefill
# and decoding. For prefill, it is the length of the
# KVs that are already cached, excluding the new KVs.
# For decoding, it includes the new KV.
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
@@ -53,6 +60,8 @@ def __init__(
# Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack.
self.attn_bias = None
# Number of valid tokens (this count includes padding).
# See attention.py for the precise definition.
self.num_valid_tokens = slot_mapping.shape[0]

def __repr__(self) -> str:
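To make the context_lens convention above concrete, here is a small illustrative sketch (not part of the diff; the batch composition and values are made up):

import torch

# Hypothetical batch of three sequences:
#   seq 0: prefill chunk with 16 KVs already cached (new KVs excluded)
#   seq 1: first prefill chunk, nothing cached yet
#   seq 2: decoding step with 40 cached KVs plus the 1 new KV
context_lens = torch.tensor([16, 0, 41], dtype=torch.int32)

# The tensor is indexed by the batched sequence's position,
# so context_lens[2] == 41 refers to the decoding sequence.
assert int(context_lens[2]) == 41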
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/attention.py
@@ -242,7 +242,6 @@ def forward(
else:
# Decoding run.
output = _paged_attention(
output,
query,
key_cache,
value_cache,
@@ -289,15 +288,17 @@ def _make_alibi_bias(


def _paged_attention(
output: torch.Tensor, # [num_tokens, num_heads, head_size]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key_cache: torch.Tensor,
value_cache: torch.Tensor,
key_cache: torch.Tensor,  # [num_total_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,  # [num_total_blocks, block_size, num_heads, head_size]
input_metadata: InputMetadata,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
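For reference, a minimal sketch of the new _paged_attention convention, where the output buffer is allocated inside the helper with torch.empty_like(query) instead of being passed in by the caller. The kernel call is elided and the shapes simply follow the comments in the diff; this is not the actual vLLM implementation:

import torch

def _paged_attention_sketch(
    query: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_total_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,  # [num_total_blocks, block_size, num_heads, head_size]
) -> torch.Tensor:
    # Allocate the result here rather than threading an output
    # tensor through every call site.
    output = torch.empty_like(query)
    # ... the paged attention kernel would write into `output` ...
    return output

query = torch.randn(8, 32, 128)           # 8 tokens, 32 heads, head_size 128
key_cache = torch.randn(64, 16, 32, 128)  # 64 blocks of block_size 16
value_cache = torch.randn(64, 16, 32, 128)
out = _paged_attention_sketch(query, key_cache, value_cache)
assert out.shape == query.shape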
14 changes: 6 additions & 8 deletions vllm/worker/model_runner.py
@@ -31,8 +31,6 @@
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
# Note that cuda graph is only used for decoding because it speeds up
# the performance when num_tokens < 200. Batch here means a single token.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
@@ -215,6 +213,9 @@ def _prepare_prompt(

max_prompt_len = max(subquery_lens)
num_prompt_tokens = len(input_tokens)

# Pad tokens to better utilize tensor cores even
# when CUDA graphs are not enabled.
input_tokens = _make_tensor_with_pad_for_alignment(input_tokens,
pad=0,
dtype=torch.long,
@@ -347,7 +348,8 @@ def _prepare_decode(
block_tables.append([])
batch_size = graph_batch_size

# Q: should we not pad when cuda graph is disabled?
# Pad tokens to better utilize tensor cores even
# when CUDA graphs are not enabled.
input_tokens = _make_tensor_with_pad_for_alignment(input_tokens,
pad=0,
dtype=torch.long,
@@ -599,9 +601,6 @@ def execute_model(

# Execute the model.
if input_metadata.use_cuda_graph:
# NOTE: We use cuda graph only when there are only
# decoding requests, which means the number of batch
# size is equivalent to number of input tokens.
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
@@ -719,7 +718,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
# deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

# assert not self.model_config.enforce_eager
assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
"unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or "
@@ -915,7 +914,6 @@ def _make_tensor_with_pad_for_alignment(
"""Create a tensor of a given list x with padding.
It adds paddings to align with graph batch size. See
_get_graph_batch_size for more details.
# NOTE: This API is only for decoding requests.
"""
batch_size = len(x)
batch_size = _get_graph_batch_size(batch_size)
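The padding helpers in this file round the batch size up to one of the captured sizes. Below is a hedged sketch of that rounding, reconstructed from the _BATCH_SIZE_ALIGNMENT and _BATCH_SIZES_TO_CAPTURE constants shown above; the helper body approximates _get_graph_batch_size and is not the exact implementation:

_BATCH_SIZE_ALIGNMENT = 8
# Captured sizes: 1, 2, 4, 8, 16, 24, ..., 256.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]

def _get_graph_batch_size_sketch(batch_size: int) -> int:
    """Round a batch size up to the nearest captured size (approximation)."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Round up to the next multiple of _BATCH_SIZE_ALIGNMENT (8).
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

# Inputs are then padded to this size, which keeps shapes aligned for
# tensor cores even when CUDA graphs are not used.
for bs in (1, 3, 5, 9, 17, 250):
    assert _get_graph_batch_size_sketch(bs) in _BATCH_SIZES_TO_CAPTURE
print([_get_graph_batch_size_sketch(b) for b in (1, 3, 5, 9, 17, 250)])
# -> [1, 4, 8, 16, 24, 256]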