-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
221 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from typing import List, Optional | ||
|
||
import torch | ||
|
||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, | ||
ModelConfig, ParallelConfig, SchedulerConfig, | ||
VisionLanguageConfig) | ||
from vllm.logger import init_logger | ||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata | ||
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, | ||
ModelRunner) | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class DraftModelRunner(ModelRunner): | ||
|
||
def __init__( | ||
self, | ||
model_config: ModelConfig, | ||
parallel_config: ParallelConfig, | ||
scheduler_config: SchedulerConfig, | ||
device_config: DeviceConfig, | ||
cache_config: CacheConfig, | ||
load_config: LoadConfig, | ||
lora_config: Optional[LoRAConfig], | ||
kv_cache_dtype: Optional[str] = "auto", | ||
is_driver_worker: bool = False, | ||
vision_language_config: Optional[VisionLanguageConfig] = None, | ||
return_hidden_states: bool = False, | ||
): | ||
if return_hidden_states: | ||
raise ValueError( | ||
"return_hidden_states is not supported for DraftModelRunner.") | ||
|
||
super().__init__( | ||
model_config=model_config, | ||
parallel_config=parallel_config, | ||
scheduler_config=scheduler_config, | ||
device_config=device_config, | ||
cache_config=cache_config, | ||
load_config=load_config, | ||
lora_config=lora_config, | ||
kv_cache_dtype=kv_cache_dtype, | ||
is_driver_worker=is_driver_worker, | ||
vision_language_config=vision_language_config, | ||
return_hidden_states=return_hidden_states, | ||
) | ||
|
||
# TODO: Remove this cache when we are able to update model_input | ||
# directly in advance_step. | ||
self.cached_seq_group_metadata_list: Optional[ | ||
List[SequenceGroupMetadata]] = None | ||
|
||
def prepare_model_input( | ||
self, | ||
seq_group_metadata_list: List[SequenceGroupMetadata], | ||
) -> ModelInputForGPUWithSamplingMetadata: | ||
"""A temporary solution that caches the seq_group_metadata_list | ||
for multi-step execution. | ||
TODO: In-place update model_input and remove this function. | ||
""" | ||
self.cached_seq_group_metadata_list = seq_group_metadata_list | ||
return super().prepare_model_input(seq_group_metadata_list) | ||
|
||
def advance_step( | ||
self, model_input: ModelInputForGPUWithSamplingMetadata, | ||
last_output: SamplerOutput | ||
) -> ModelInputForGPUWithSamplingMetadata: | ||
"""Prepare the model inputs for the next step. | ||
TODO: In-place update model_input instead of calling | ||
prepare_model_input. | ||
""" | ||
|
||
# Append the output token to the sequence data. | ||
assert self.cached_seq_group_metadata_list is not None | ||
for seq_group_metadata, sequence_group_outputs in zip( | ||
self.cached_seq_group_metadata_list, last_output.outputs): | ||
seq_group_metadata.is_prompt = False | ||
|
||
for seq_output in sequence_group_outputs.samples: | ||
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] | ||
|
||
token_id = seq_output.output_token | ||
token_logprob = seq_output.logprobs[token_id] | ||
|
||
seq.append_token_id(token_id, token_logprob.logprob) | ||
seq.update_num_computed_tokens(1) | ||
|
||
return self.prepare_model_input(self.cached_seq_group_metadata_list) | ||
|
||
@torch.inference_mode() | ||
def execute_model( | ||
self, | ||
model_input: ModelInputForGPUWithSamplingMetadata, | ||
kv_caches: List[torch.Tensor], | ||
num_steps: int = 1, | ||
) -> Optional[List[SamplerOutput]]: | ||
# Since we do not broadcast data inside execute_model anymore, | ||
# we need to figure out the best way to support TP > 1 in this | ||
# case, because we will at least need to broadcast the sampled | ||
# tokens to all workers. | ||
if not self.is_driver_worker: | ||
raise ValueError("DraftModelRunner only supports TP=1 for now.") | ||
|
||
if self.lora_config: | ||
assert model_input.lora_requests is not None | ||
assert model_input.lora_mapping is not None | ||
self.set_active_loras(model_input.lora_requests, | ||
model_input.lora_mapping) | ||
|
||
outputs: List[SamplerOutput] = [] | ||
for step in range(num_steps): | ||
# Currently cuda graph is only supported by the decode phase. | ||
assert model_input.attn_metadata is not None | ||
prefill_meta = model_input.attn_metadata.prefill_metadata | ||
decode_meta = model_input.attn_metadata.decode_metadata | ||
if prefill_meta is None and decode_meta.use_cuda_graph: | ||
assert model_input.input_tokens is not None | ||
graph_batch_size = model_input.input_tokens.shape[0] | ||
model_executable = self.graph_runners[graph_batch_size] | ||
else: | ||
model_executable = self.model | ||
|
||
multi_modal_kwargs = model_input.multi_modal_kwargs or {} | ||
hidden_states = model_executable( | ||
input_ids=model_input.input_tokens, | ||
positions=model_input.input_positions, | ||
kv_caches=kv_caches, | ||
attn_metadata=model_input.attn_metadata, | ||
**multi_modal_kwargs, | ||
) | ||
|
||
# Compute the logits. | ||
logits = self.model.compute_logits(hidden_states, | ||
model_input.sampling_metadata) | ||
|
||
# Sample the next token. | ||
outputs.append( | ||
self.model.sample( | ||
logits=logits, | ||
sampling_metadata=model_input.sampling_metadata, | ||
)) | ||
|
||
# Prepare the inputs for the next step. | ||
if step != num_steps - 1: | ||
model_input = self.advance_step(model_input, outputs[-1]) | ||
|
||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.