
[WIP][1/N] Chunked Prefill #3106

Closed · wants to merge 38 commits

Changes from 1 commit

Commits (38)
06fe872  [1/n] Support efficient reshape caching. (rkooo567, Feb 28, 2024)
9a0b6be  [2/n] support flash attention kernel (rkooo567, Feb 28, 2024)
6947167  oss flash attention works (rkooo567, Feb 28, 2024)
4769a26  in progress (rkooo567, Feb 28, 2024)
963db44  flash attn enabled. (rkooo567, Feb 29, 2024)
2b9c36b  ip (rkooo567, Feb 29, 2024)
2c1bb6c  support every model (rkooo567, Feb 29, 2024)
2bb5e62  Fixed broken tests. (rkooo567, Feb 29, 2024)
78bb887  ip (rkooo567, Feb 29, 2024)
74ac900  seems to work. (rkooo567, Mar 1, 2024)
71bdada  . (rkooo567, Mar 1, 2024)
d4c3b5d  ip? (rkooo567, Mar 1, 2024)
baef7c6  block tables updated correctly (rkooo567, Mar 1, 2024)
a12ec68  hopefully tests pass (rkooo567, Mar 1, 2024)
0d8785f  Merge branch 'main' into chunked-prefill-3 (rkooo567, Mar 3, 2024)
08c8541  . (rkooo567, Mar 3, 2024)
3bac9af  ip (rkooo567, Mar 3, 2024)
31aa920  ip (rkooo567, Mar 4, 2024)
2049b35  . (rkooo567, Mar 4, 2024)
ef679d7  . (rkooo567, Mar 4, 2024)
71bda97  . (rkooo567, Mar 4, 2024)
4e00e7f  done? (rkooo567, Mar 4, 2024)
7fd70f2  Merge branch 'main' into chunked-prefill-3 (rkooo567, Mar 5, 2024)
9177d54  Merge branch 'main' into chunked-prefill-3 (rkooo567, Mar 6, 2024)
c0384a4  Refactor 2d query to 1d query (rkooo567, Mar 6, 2024)
6032edf  ., (rkooo567, Mar 6, 2024)
c1ab0b0  done (rkooo567, Mar 6, 2024)
f48dc72  Addressed code review. (rkooo567, Mar 7, 2024)
769b2b4  working (rkooo567, Mar 7, 2024)
4a20f4a  Merge branch 'main' into 1dquery (rkooo567, Mar 7, 2024)
f7347b8  working (rkooo567, Mar 7, 2024)
d931725  Merge branch 'main' into 1dquery (rkooo567, Mar 7, 2024)
f91d73e  fix lora (rkooo567, Mar 8, 2024)
f7d79da  fixed (rkooo567, Mar 8, 2024)
851c018  Merge branch 'main' into 1dquery (rkooo567, Mar 8, 2024)
406f1d4  fix (rkooo567, Mar 8, 2024)
9442e8f  Merge branch 'main' into chunked-prefill-3 (rkooo567, Mar 8, 2024)
3da31eb  Merge branch '1dquery' into chunked-prefill-3 (rkooo567, Mar 8, 2024)
Changes below are from commit c1ab0b0bedf0e25f3d35c998f5216eb33b4275d1 ("done"), committed by rkooo567 on Mar 6, 2024.
2 changes: 1 addition & 1 deletion tests/models/test_models.py
@@ -26,7 +26,7 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
     vllm_runner,
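Note on this hunk: widening the enforce_eager axis makes every model in MODELS run both with CUDA-graph capture (the default) and in eager mode. As a rough illustration of what the flag means at the API level (a sketch only; the model name here is an arbitrary example, not taken from the test's MODELS list):

from vllm import LLM, SamplingParams

# Greedy sampling so the two configurations can be compared token-for-token.
params = SamplingParams(temperature=0.0, max_tokens=16)
prompt = "Hello, my name is"

for enforce_eager in (False, True):
    # enforce_eager=True skips CUDA-graph capture and runs decode steps in
    # eager PyTorch; False uses the captured-graph fast path.
    llm = LLM(model="facebook/opt-125m", enforce_eager=enforce_eager)
    text = llm.generate([prompt], params)[0].outputs[0].text
    print(f"enforce_eager={enforce_eager}: {text!r}")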
1 change: 0 additions & 1 deletion vllm/engine/llm_engine.py
@@ -798,7 +798,6 @@ def _process_model_outputs(
         # Log stats.
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs))
-        # breakpoint()
         return request_outputs

     def step(self) -> List[RequestOutput]:
53 changes: 19 additions & 34 deletions vllm/model_executor/layers/attention.py
@@ -141,29 +141,17 @@ def forward(
         # If key_cache and value_cache are not provided, the new key and value
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
-        num_valid_tokens = input_metadata.num_valid_tokens
-        if (num_valid_tokens > 0 and key_cache is not None
-                and value_cache is not None):
-            key_to_cache = key[:num_valid_tokens]
-            value_to_cache = value[:num_valid_tokens]
+        if (key_cache is not None and value_cache is not None):
             cache_ops.reshape_and_cache(
-                key_to_cache,
-                value_to_cache,
+                key,
+                value,
                 key_cache,
                 value_cache,
                 input_metadata.slot_mapping.flatten(),
                 input_metadata.kv_cache_dtype,
             )

-        num_prompt_tokens = input_metadata.num_prompt_tokens
-        num_generation_tokens = input_metadata.num_generation_tokens
-        print(num_generation_tokens)
-
-        if num_prompt_tokens > 0:
-            assert num_generation_tokens == 0
-            query = query[:num_prompt_tokens]
-            key = key[:num_prompt_tokens]
-            value = value[:num_prompt_tokens]
+        if input_metadata.is_prompt:
             # normal attention
             if (key_cache is None or value_cache is None
                     or input_metadata.block_tables.numel() == 0):
@@ -202,7 +190,7 @@ def forward(
                                 input_metadata)

                 if self.use_ref_attention:
-                    output[:num_prompt_tokens] = self.ref_masked_attention(
+                    output = self.ref_masked_attention(
                         query,
                         key,
                         value,
@@ -222,18 +210,17 @@
                     key = key.unflatten(0, (num_tokens))
                     value = value.unflatten(0, (num_tokens))

-                output[:
-                       num_prompt_tokens] = xops.memory_efficient_attention_forward(
-                           query,
-                           key,
-                           value,
-                           attn_bias=input_metadata.attn_bias,
-                           p=0.0,
-                           scale=self.scale,
-                           op=xops.fmha.
-                           MemoryEfficientAttentionFlashAttentionOp[0] if
-                           (is_hip()) else None,
-                       ).view_as(query)
+                out = xops.memory_efficient_attention_forward(
+                    query,
+                    key,
+                    value,
+                    attn_bias=input_metadata.attn_bias,
+                    p=0.0,
+                    scale=self.scale,
+                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                    (is_hip()) else None,
+                )
+                output = out.view_as(query)
             else:
                 # prefix-enabled attention
                 output = torch.empty_like(query)
@@ -252,13 +239,11 @@ def forward(
                     getattr(self, "alibi_slopes", None),
                 )

-        if num_generation_tokens > 0:
-            breakpoint()
-            assert num_prompt_tokens == 0
+        else:
+            # Decoding run.
             output = _paged_attention(
-                output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens],
+                output,
+                query,
                 key_cache,
                 value_cache,
                 input_metadata,
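Context for this file: the deleted slicing (query[:num_prompt_tokens], output[num_prompt_tokens:num_valid_tokens]) belonged to the earlier padded, 2D token layout; after the "Refactor 2d query to 1d query" commit the kernels receive all tokens flattened into a single dimension and the forward pass simply branches on input_metadata.is_prompt. A self-contained sketch of the packing idea (illustrative only, not vLLM's actual InputMetadata fields):

import torch

# Three prompts of different lengths, each token embedded in 8 dims.
seq_lens = [3, 5, 2]
hidden = 8
seqs = [torch.randn(n, hidden) for n in seq_lens]

# 2D (padded) layout: [batch, max_len, hidden], with wasted padding slots that
# the old code had to slice away via num_prompt_tokens / num_valid_tokens.
max_len = max(seq_lens)
padded = torch.zeros(len(seqs), max_len, hidden)
for i, s in enumerate(seqs):
    padded[i, :s.shape[0]] = s

# 1D (packed) layout: [total_tokens, hidden] plus per-sequence offsets, which
# is what a flattened-query attention path consumes.
packed = torch.cat(seqs, dim=0)           # shape [10, hidden]
cu_seqlens = torch.tensor([0, 3, 8, 10])  # cumulative start offsets

# Any sequence can be recovered from the packed layout without padding.
i = 1
start, end = cu_seqlens[i], cu_seqlens[i + 1]
assert torch.equal(packed[start:end], seqs[i])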
2 changes: 0 additions & 2 deletions vllm/worker/model_runner.py
@@ -611,7 +611,6 @@ def execute_model(
             kv_caches=kv_caches,
             input_metadata=input_metadata,
         )
-        breakpoint()

         # Sample the next token.
         output = self.model.sample(
@@ -878,7 +877,6 @@ def forward(
                                           non_blocking=True)
         self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                  non_blocking=True)
-        breakpoint()
         # Run the graph.
         self.graph.replay()

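The second hunk is in the CUDA-graph replay path, where inputs are copied into fixed buffers before self.graph.replay(); the removed breakpoint() was leftover debugging. For readers unfamiliar with the pattern, a generic PyTorch capture-and-replay sketch (needs a CUDA device; this is not vLLM's actual runner):

import torch

# A captured graph always reads the same tensor storage, so fresh inputs must
# be copied into "static" buffers before every replay, which is why the runner
# above calls input_buffers[...].copy_(..., non_blocking=True).
static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

# Warm up once so lazy initialization happens outside of graph capture.
static_out.copy_(static_in * 2 + 1)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in * 2 + 1)

for step in range(3):
    static_in.copy_(torch.full((4,), float(step), device="cuda"))
    graph.replay()                    # re-runs only the recorded kernels
    print(step, static_out.tolist())  # each entry equals 2 * step + 1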