[2/N] Chunked prefill data update #3538

Merged
Changes from 2 commits
Commits
127 commits
06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
4d6a05f
[2/n] scheduler changes
rkooo567 Feb 29, 2024
0831f84
[2/n] ip
rkooo567 Feb 29, 2024
f31371f
[2/n]ip
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
b9d93c5
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Feb 29, 2024
42dd362
[2/n] ip
rkooo567 Mar 1, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
e3afc25
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
6141885
[2/n] ip
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
d503a22
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
85760db
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 3, 2024
e40bc45
[2/n] update sequence data
rkooo567 Mar 3, 2024
d85670f
[2/n] add prefill range apis
rkooo567 Mar 3, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
0ca1284
add data.
rkooo567 Mar 3, 2024
2487bda
ip
rkooo567 Mar 3, 2024
81151e8
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
c5f3a0d
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9bbb04e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
5e47c1e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
c66ec36
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
c067a4c
working.
rkooo567 Mar 11, 2024
e1f244a
clean up.
rkooo567 Mar 11, 2024
d09eaf5
.
rkooo567 Mar 11, 2024
4a8ab3c
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
a08e65e
Merge branch 'main' into 1dquery
rkooo567 Mar 11, 2024
d9532f8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
93a7b90
.
rkooo567 Mar 12, 2024
b4b94c6
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
647d8cc
.
rkooo567 Mar 12, 2024
65ac6ce
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
b2f4b3e
ip
rkooo567 Mar 12, 2024
cc8419f
.
rkooo567 Mar 12, 2024
76e7ca8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
d3d0336
Merge branch 'main' into 1dquery
rkooo567 Mar 15, 2024
11ec167
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 15, 2024
3cb8093
ip addressing comments.
rkooo567 Mar 16, 2024
5391129
Alibi slopes working now.
rkooo567 Mar 18, 2024
6b04443
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
fe344f6
add new fieflds
rkooo567 Mar 18, 2024
e619c4e
Flash attn works now
rkooo567 Mar 18, 2024
9c86aa3
Linting
rkooo567 Mar 18, 2024
5b4aa09
temporary
rkooo567 Mar 18, 2024
03dd155
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
4cced78
fix tests
rkooo567 Mar 18, 2024
cdb7a2c
Fixed
rkooo567 Mar 18, 2024
276be06
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
d87b651
Pass unit tests.
rkooo567 Mar 18, 2024
2c18896
experiment
rkooo567 Mar 18, 2024
b46f902
.
rkooo567 Mar 18, 2024
07b22f8
.
rkooo567 Mar 18, 2024
9bd7ea1
.
rkooo567 Mar 18, 2024
c55402f
trial
rkooo567 Mar 18, 2024
a13cf7e
remove --fork
rkooo567 Mar 18, 2024
c5c5581
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
ec91304
fixed
rkooo567 Mar 19, 2024
4977e53
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 19, 2024
4a54688
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
2e6e919
Addressed code review.
rkooo567 Mar 19, 2024
1f6f6b0
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
ac7828c
revert removing forked
rkooo567 Mar 19, 2024
3d7f1a1
done
rkooo567 Mar 19, 2024
bcdd74a
Merge branch 'main' into 1dquery
rkooo567 Mar 20, 2024
fa3ce4e
final code review.
rkooo567 Mar 20, 2024
a83b235
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 20, 2024
7205ef9
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 21, 2024
8bc0af5
.
rkooo567 Mar 21, 2024
97bcb6f
ip
rkooo567 Mar 21, 2024
df34350
working except tests.
rkooo567 Mar 21, 2024
e70e03d
.
rkooo567 Mar 21, 2024
f89f428
ip
rkooo567 Mar 21, 2024
bf02f8e
done
rkooo567 Mar 21, 2024
ad43095
done
rkooo567 Mar 21, 2024
16b6196
Addressed code review.
rkooo567 Mar 22, 2024
916abc8
merge conflict fixed
rkooo567 Mar 25, 2024
5002e61
update
rkooo567 Mar 25, 2024
80f51ea
test fix
rkooo567 Mar 25, 2024
3cc5e99
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 25, 2024
fa7ba35
lint
rkooo567 Mar 25, 2024
51cf7f2
fix broken tests.
rkooo567 Mar 25, 2024
cdee1c6
.
rkooo567 Mar 26, 2024
16e3a7d
done
rkooo567 Mar 26, 2024
e0d301c
remove num chunked prefill from seq group metadata
rkooo567 Mar 27, 2024
5e0f87e
change apis
rkooo567 Mar 27, 2024
6e72648
cleaned
rkooo567 Mar 27, 2024
4f869be
now working
rkooo567 Mar 27, 2024
4f63c57
update with new apis
rkooo567 Mar 27, 2024
5c3abf4
working!
rkooo567 Mar 27, 2024
66f3fcf
fixed
rkooo567 Mar 27, 2024
9c12d8e
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 27, 2024
9d4b65c
Addressed code review.
rkooo567 Mar 28, 2024
54a58b2
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 28, 2024
9bdb9dc
fix tests.
rkooo567 Mar 28, 2024
88126a9
fixed a bug
rkooo567 Mar 28, 2024
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
@@ -152,7 +152,7 @@ def generate(
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
-print("SANG-TODO generate: ", prompts, prompt_token_ids)
+# print("SANG-TODO generate: ", prompts, prompt_token_ids)
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
38 changes: 8 additions & 30 deletions vllm/model_executor/layers/attention.py
@@ -141,7 +141,7 @@ def forward(
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
-print("SANG-TODO query size: ", query.size())
+# print("SANG-TODO query size: ", query.size())
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -153,12 +153,12 @@ def forward(
# profiling run.
if key_cache is not None and value_cache is not None:
if input_metadata.flash_style:
-print("SANG-TODO reshape cache flash.")
+# print("SANG-TODO reshape cache flash.")
cache_ops.reshape_and_cache_flash(
key, value, key_cache, value_cache,
input_metadata.slot_mapping.flatten())
else:
-print("SANG-TODO reshape cache.")
+# print("SANG-TODO reshape cache.")
cache_ops.reshape_and_cache(
key,
value,
@@ -173,33 +173,11 @@ def forward(
if (key_cache is None or value_cache is None
# or input_metadata.block_tables.numel() == 0):
or not input_metadata.prefix_enabled):
-print("SANG-TODO flash attn is used.")
-print(
-"SANG-TODO query size: ",
-query.view(batch_size, seq_len, self.num_heads,
-self.head_size).size())
-# if key_cache is not None and value_cache is not None:
-# output2 = flash_attn_with_kvcache_paged(
-# query.view(batch_size, seq_len, self.num_heads,
-# self.head_size),
-# key_cache,
-# value_cache,
-# self.scale,
-# input_metadata.block_tables,
-# input_metadata.context_lens + seq_len,
-# self.alibi_slopes,
-# )
-# from flash_attn import flash_attn_func
-# breakpoint()
-# output3 = flash_attn_func(
-# q=query.view(batch_size, seq_len, self.num_heads,
-# self.head_size),
-# k=key.view(batch_size, seq_len, self.num_kv_heads, self.head_size),
-# v=value.view(batch_size, seq_len, self.num_kv_heads, self.head_size),
-# softmax_scale=self.scale,
-# causal=True,
-# alibi_slopes=self.alibi_slopes,
-# )
+# print("SANG-TODO flash attn is used.")
+# print(
+# "SANG-TODO query size: ",
+# query.view(batch_size, seq_len, self.num_heads,
+# self.head_size).size())
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
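The `view` calls in the hunks above pack the query from `[batch_size, seq_len, hidden_size]` into a flat per-token layout, the "Refactor 2d query to 1d query" change from the commit list. A standalone sketch of that reshape, using NumPy and made-up sizes (`num_heads=8`, `head_size=64` are illustrative, not taken from the PR):

```python
import numpy as np

# Illustrative sizes only; the real values come from the model config.
batch_size, seq_len, num_heads, head_size = 2, 16, 8, 64
hidden_size = num_heads * head_size

query = np.random.randn(batch_size, seq_len, hidden_size)

# Collapse the batch and sequence dims into a single token axis and split
# hidden_size into per-head vectors, mirroring
# query.view(-1, self.num_heads, self.head_size) in the diff.
query = query.reshape(-1, num_heads, head_size)

print(query.shape)  # (32, 8, 64)
```

With the token axis flattened, each of the `batch_size * seq_len` rows is one token's per-head query vectors, which is the layout the attention kernels in this PR consume.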
10 changes: 5 additions & 5 deletions vllm/worker/model_runner.py
@@ -138,8 +138,8 @@ def _prepare_prompt(
context_lens: List[int] = []
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
-print("SANG-TODO # of requests (seq_group_metadata_list): ",
-len(seq_group_metadata_list))
+# print("SANG-TODO # of requests (seq_group_metadata_list): ",
+# len(seq_group_metadata_list))
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -152,7 +152,7 @@ def _prepare_prompt(
prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
-print("SANG-TODO prefix, ", prefix)
+# print("SANG-TODO prefix, ", prefix)
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
@@ -500,12 +500,12 @@ def prepare_input_tensors(
# SANG-TODO set num prompt tokens and generations?
# Prepare input tensors.
if is_prompt:
-print("SANG-TODO execute model prompt.")
+# print("SANG-TODO execute model prompt.")
(input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else:
-print("SANG-TODO execute model decode.")
+# print("SANG-TODO execute model decode.")
(input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
6 changes: 3 additions & 3 deletions vllm/worker/worker.py
@@ -117,7 +117,7 @@ def profile_num_available_blocks(
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
"""
-print("SANG-TODO profile_num_available_blocks")
+# print("SANG-TODO profile_num_available_blocks")
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
@@ -154,7 +154,7 @@ def profile_num_available_blocks(
MAX_INT_32 // cache_block_size)
num_cpu_blocks = min(num_cpu_blocks,
MAX_INT_32 // cache_block_size)
-print("SANG-TODO profile_num_available_blocks done")
+# print("SANG-TODO profile_num_available_blocks done")

return num_gpu_blocks, num_cpu_blocks

@@ -207,7 +207,7 @@ def execute_model(
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
-print("SANG-TODO execute model.")
+# print("SANG-TODO execute model.")
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)