Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API #4681

Merged
merged 49 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
c9fcb26
first checkpoint done
rkooo567 May 8, 2024
de61dbb
refactoring subquery
rkooo567 May 8, 2024
159ea2f
.
rkooo567 May 8, 2024
5833cbb
ip
rkooo567 May 8, 2024
7614ce0
working
rkooo567 May 8, 2024
6744eff
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 9, 2024
7de7f63
.
rkooo567 May 9, 2024
e8a4ea3
working with flash attn
rkooo567 May 9, 2024
64e8fd4
rocm and sdpa
rkooo567 May 9, 2024
ceec66d
working with flash infer
rkooo567 May 9, 2024
1851d59
add flash infer to pipeline
rkooo567 May 9, 2024
5cf1d3e
.
rkooo567 May 9, 2024
21a612a
working.
rkooo567 May 9, 2024
61dec37
fix spec decoding
rkooo567 May 9, 2024
6cad7bc
Fixed model runner test
rkooo567 May 9, 2024
e20a29e
fixed
rkooo567 May 9, 2024
94964ab
fix intel test
rkooo567 May 9, 2024
ff99251
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
1c77e2d
.
rkooo567 May 10, 2024
f929edd
done
rkooo567 May 10, 2024
74683a1
.
rkooo567 May 10, 2024
546735a
fix circular reference.
rkooo567 May 10, 2024
89e5df2
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
d7b2743
working
rkooo567 May 10, 2024
0ed4160
Merge branch 'circular-dep' into model-runner-refactoring-coelsce
rkooo567 May 10, 2024
f5af730
fixed spec decoding
rkooo567 May 10, 2024
7e39882
working
rkooo567 May 10, 2024
e02bc5d
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 13, 2024
dd48c00
fix embedding meta
rkooo567 May 13, 2024
bba70f1
ip
rkooo567 May 13, 2024
f76b9ea
improve assert
rkooo567 May 13, 2024
cf1dbbb
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 13, 2024
a281d97
done
rkooo567 May 13, 2024
f6afb05
lint
rkooo567 May 13, 2024
35f64ac
done
rkooo567 May 14, 2024
cc5df57
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
ccf937c
.
rkooo567 May 14, 2024
0951715
works except spec decoding
rkooo567 May 14, 2024
2b2423d
.
rkooo567 May 14, 2024
f1c12f3
.,
rkooo567 May 14, 2024
8a01746
.
rkooo567 May 14, 2024
bf7959a
.
rkooo567 May 14, 2024
b42b43d
.
rkooo567 May 14, 2024
e9a973e
ip
rkooo567 May 14, 2024
4e733e2
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
237e939
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 14, 2024
35e98a0
Merge branch 'main' into model-runner-refactoring-coelsce
rkooo567 May 15, 2024
426d99a
.
rkooo567 May 15, 2024
1271556
done
rkooo567 May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rocm and sdpa
  • Loading branch information
rkooo567 committed May 9, 2024
commit 64e8fd4eedde8962a5de6c2beea2be3fe6727de0
10 changes: 10 additions & 0 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ def asdict_zerocopy(self,
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)

@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
return self

@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
return self


class FlashInferImpl(AttentionImpl):

Expand Down
76 changes: 70 additions & 6 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
Expand All @@ -80,10 +77,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-- query_len ---|

# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch.
max_prefill_seq_len: Optional[int]
# Maximum sequence length among decode batch.
max_decode_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
Expand All @@ -100,6 +99,71 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None

if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata

assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None

self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
kv_cache_dtype=self.kv_cache_dtype,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=None,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata

@property
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None

if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None

self._cached_decode_metadata = ROCmFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
kv_cache_dtype=self.kv_cache_dtype,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=None,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata


class ROCmFlashAttentionImpl(AttentionImpl):
Expand Down
13 changes: 11 additions & 2 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def copy_blocks(


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
AttentionMetadata):
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
Expand All @@ -71,6 +70,16 @@ def __post_init__(self):
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None

@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
return self

@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
return self


class TorchSDPABackendImpl(AttentionImpl):

Expand Down
Loading