[WIP] Speculative Decoding #1797
Conversation
Just some tiny comments.
@@ -408,11 +416,24 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
                       last_child_sample.logprobs)
if last_child_sample.accepted_tokens:
nit: if FLAGS.ENABLE_SD: ?
self.propose_cnt = config.propose_cnt
self.draft_model_config = config.draft_model_config
nit: self.config = config ?
# propose draft tokens
# the function will run the draft model and set draft_tokens and draft_token_probs of each seq
def set_draft_tokens(self, seq_group_list: List[SequenceGroupMetadata],
propose() might be a better name.
)
if FLAGS.ENABLE_SD:
    output = _multi_query_cached_kv_attention(
        query, key, value, key_cache, value_cache, input_metadata)
Why do we need to pass key and value? I think those two vars have already been copied into key_cache and value_cache by cache_ops.reshape_and_cache(). Maybe I am missing something?
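To make the question concrete: as I understand it, reshape_and_cache scatters the new K/V into the caches, roughly like this pure-PyTorch stand-in (simplified shapes for illustration, not vLLM's actual cache layout):

```python
import torch

def reshape_and_cache_reference(key: torch.Tensor, value: torch.Tensor,
                                key_cache: torch.Tensor, value_cache: torch.Tensor,
                                slot_mapping: torch.Tensor) -> None:
    """Illustrative stand-in: scatter the new K/V vectors into the paged caches.

    Assumed simplified shapes: key/value are [num_tokens, num_heads, head_size],
    the caches are [num_slots, num_heads, head_size], and slot_mapping gives the
    destination slot of each token. After this, a kernel that walks the block
    table can read K/V back from the caches alone.
    """
    key_cache[slot_mapping] = key
    value_cache[slot_mapping] = value
```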
Nice work, congrats!
for seq_group_metadata in seq_group_metadata_list:
    assert len(
        seq_group_metadata.seq_data
    ) == 1, f"Speculative Decoding does nor beam search for now: {len(seq_group_metadata.seq_data)}"
a little typo in the assert message
@@ -573,6 +594,11 @@ def step(self) -> List[RequestOutput]:
    if scheduler_outputs.is_empty():
        return ignored

    # only enable speculative decoding for generation run
    if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
        self.spec_dec_worker.set_draft_tokens(seq_group_metadata_list,
In a multi-GPU inference scenario, will this method be called by all the workers? Do you think it would be better to run it only on rank 0 and broadcast the tokens to the other ranks?
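Something along these lines could work; a minimal sketch that assumes a torch.distributed process group is already initialized and that the proposals are stored as draft_tokens / draft_token_probs on each sequence's data, as described in this PR:

```python
import torch.distributed as dist

def propose_on_rank0(spec_dec_worker, seq_group_metadata_list, rank: int) -> None:
    """Sketch: run the draft model only on rank 0, then share the proposals.

    Assumptions: seq_data is a dict of per-sequence data objects, and the
    proposals live in seq.draft_tokens / seq.draft_token_probs.
    """
    seqs = [seq for meta in seq_group_metadata_list
            for seq in meta.seq_data.values()]

    if rank == 0:
        # Run the draft model (this PR's set_draft_tokens; extra args elided).
        spec_dec_worker.set_draft_tokens(seq_group_metadata_list)
        payload = [(s.draft_tokens, s.draft_token_probs) for s in seqs]
    else:
        payload = [None] * len(seqs)

    # The payload is tiny (a few tokens and their probs per sequence),
    # so pickling it via broadcast_object_list is cheap.
    dist.broadcast_object_list(payload, src=0)

    if rank != 0:
        for seq, (tokens, probs) in zip(seqs, payload):
            seq.draft_tokens, seq.draft_token_probs = tokens, probs
```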
logger.setLevel("WARNING")


class SpecDecWorker(Worker):
This worker is too tightly coupled with assisted decoding. Do you think it's a good idea to abstract a base class for SpD and move these specific implementations into a concrete class like AssistedSpecDecWorker? But I believe we could refactor this later.
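Roughly the split I have in mind; a sketch only, with illustrative names, built around the propose/accept interface this PR already has:

```python
from abc import ABC, abstractmethod
from typing import List


class SpecDecWorkerBase(ABC):
    """Interface any speculative-decoding proposer would implement (sketch)."""

    @abstractmethod
    def propose(self, seq_group_metadata_list: List["SequenceGroupMetadata"]) -> None:
        """Attach draft tokens / draft token probs to each sequence."""

    @abstractmethod
    def accept(self, target_outputs) -> None:
        """Verify the draft tokens against the target model's outputs."""


class AssistedSpecDecWorker(SpecDecWorkerBase):
    """Concrete proposer backed by a smaller draft model, as in this PR."""

    def __init__(self, draft_model_runner, propose_cnt: int):
        # draft_model_runner is an assumed wrapper around the draft model.
        self.draft_model_runner = draft_model_runner
        self.propose_cnt = propose_cnt

    def propose(self, seq_group_metadata_list) -> None:
        # Run the draft model propose_cnt steps and store the proposals,
        # following this PR's set_draft_tokens.
        raise NotImplementedError

    def accept(self, target_outputs) -> None:
        raise NotImplementedError
```

That would also leave room for other kinds of proposers behind the same interface later.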
@@ -69,7 +69,7 @@ def __init__(
    revision: Optional[str] = None,
    tokenizer_revision: Optional[str] = None,
    seed: int = 0,
-   gpu_memory_utilization: float = 0.9,
+   gpu_memory_utilization: float = 0.8,
This is a little hacky to me. What if the sequence is long and could take more than 0.2 of GPU memory? Do you think it would be better to actually run the assisted model in profile_num_available_blocks?
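For instance, something like the following (a sketch only; the function and argument names are placeholders, not vLLM's actual profiling code): run a worst-case forward pass of both the target and the draft model, then size the KV cache from the measured peak instead of reserving a fixed fraction.

```python
import torch

def profile_available_gpu_blocks(target_model, draft_model, dummy_inputs,
                                 block_bytes: int,
                                 gpu_memory_utilization: float) -> int:
    """Sketch: measure peak activation memory of target *and* draft model,
    then give the KV cache whatever fits under the utilization cap."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    with torch.inference_mode():
        target_model(**dummy_inputs)  # worst-case prompt batch
        draft_model(**dummy_inputs)   # the assisted model needs activations too

    peak = torch.cuda.max_memory_allocated()
    total = torch.cuda.get_device_properties(0).total_memory
    free_for_cache = int(total * gpu_memory_utilization) - peak
    return max(free_for_cache // block_bytes, 0)
```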
pass


if triton.__version__ >= "2.1.0":
Maybe we should assert that the version is greater than or equal to 2.1.0?
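For example (a sketch; it uses packaging for a robust version comparison, assuming that package is available in the environment):

```python
import triton
from packaging import version

# Fail loudly instead of silently skipping the kernel definitions.
assert version.parse(triton.__version__) >= version.parse("2.1.0"), (
    f"these attention kernels require triton>=2.1.0, got {triton.__version__}")
```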
offs_d[:, None] // x) * stride_k_cache_d + (
(start_n + offs_n[None, :]) %
block_size) * stride_k_cache_bl + (
offs_d[:, None] % x) * stride_k_cache_x
good job! this would be faster than my version! 👍
block_mask = tl.where(
    block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
I wonder what's special about the K and V of the draft tokens; why do we need to process these tokens separately?
self.scale,
self.alibi_slopes,
)
if FLAGS.ENABLE_SD:
Correct me if I'm wrong, but for assisted decoding the propose_cnt is usually small (maybe around 4?), which makes the first dimension of q small, so the q@k GEMM and the qk@v GEMM are both small. For such cases, is it really worth using Tensor Cores for the GEMMs?
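Rough numbers to illustrate the concern (all values are assumed for the example, not taken from this PR's defaults):

```python
# Per-head GEMM work for the verification step, assuming propose_cnt = 4
# (so ~5 query positions per sequence), head_dim = 128 and a 2048-token context.
q_len, ctx_len, head_dim = 5, 2048, 128

qk_flops = 2 * q_len * ctx_len * head_dim   # q @ k^T
pv_flops = 2 * q_len * ctx_len * head_dim   # softmax(qk) @ v
print(f"{(qk_flops + pv_flops) / 1e6:.1f} MFLOPs per head")  # ~5.2 -> tiny, likely memory-bound
```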
@@ -573,6 +594,11 @@ def step(self) -> List[RequestOutput]:
    if scheduler_outputs.is_empty():
        return ignored

    # only enable speculative decoding for generation run
    if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
Reporting a bug here: when we start the vllm api server with python3 -m vllm.entrypoints.api_server --model=/path/to/tgt_model/ --draft-model=/path/to/draft/model/ --propose-cnt=5, the server errors out. It looks like you forgot set_draft_tokens and accept_tokens in AsyncLLMEngine.
This is an attempt to implement speculative decoding (paper) in vLLM. It is not optimized and not tested (please avoid using it for now). The current design:
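For reviewers unfamiliar with the technique, here is a minimal standalone sketch of the accept/reject rule from the paper that the verification step is based on; it illustrates the algorithm only and is not the code in this PR:

```python
import torch

def accept_draft_tokens(draft_tokens: torch.Tensor,   # [k] proposed token ids
                        draft_probs: torch.Tensor,    # [k, vocab] q(x) from the draft model
                        target_probs: torch.Tensor    # [k, vocab] p(x) from the target model
                        ) -> list:
    """Rejection sampling from the speculative decoding paper: accept draft
    token x with probability min(1, p(x)/q(x)); on the first rejection,
    resample from norm(max(p - q, 0)) and stop. (The bonus token sampled
    when every draft token is accepted is omitted for brevity.)"""
    accepted = []
    for i, tok in enumerate(draft_tokens.tolist()):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted.append(tok)
        else:
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            break
    return accepted
```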