[Core] Refactor Worker and ModelRunner to consolidate control plane communication #5408
Merged: simon-mo merged 64 commits into vllm-project:main from stephanie-wang:control-refactor-2 on Jun 26, 2024.

Changes shown below are from 1 commit of the 64.

Commits (all by stephanie-wang):
725b0b2 tmp
b74eb10 fix
38b0ddf ray and mp backends work
0d11e92 embedding model runner works
2cdc218 GPU executor works
c728512 remove comment
2bf752b use the right ModelInput class
f35a23f CPU worker
11133fe remove commented
174bdb1 lint
c0e98ca Worker.execute_model vs execute_model_local
dccec95 lint
dad94ba neuron model runner
fca606e disallow distributed comms
6ed3c2a disable communication
1803e33 Update worker.py
dde799e fix tests
0398631 update
5c41cc6 Merge branch 'control-refactor-2' of github.com:stephanie-wang/vllm i…
72f0383 Merge remote-tracking branch 'upstream/main' into control-refactor-2
eef6623 merge
3004ceb update
8d852e9 Merge remote-tracking branch 'upstream/main' into control-refactor-2
9380ed8 fix
3c4de6d fix
5053f30 fix
db38556 x
456185d rm
e860652 lint
3d4f242 add missing
11304cb revert
99f532e refactor
797a7cf doc
6ad2513 revert spec decode and doc
97ec303 Merge remote-tracking branch 'upstream/main' into control-refactor-2
e10bace typing
ce087ae fix
f851b00 Merge remote-tracking branch 'upstream/main' into control-refactor-2
0e2acc4 XPU worker and rename
d318ec8 lint
b48f783 lint
c93afc1 Merge remote-tracking branch 'upstream/main' into control-refactor-2
30ac400 fix
01688d5 x
7dbb646 fix
d2e4c41 fix
3e46253 lint
0a2890a Merge remote-tracking branch 'upstream/main' into control-refactor-2
36dfce1 merge
ea5412e Merge remote-tracking branch 'upstream/main' into control-refactor-2
dc2f103 x
fbf074d x
660a8d5 rename ModelInput -> ModelInputBase, override as_broadcastable_tensor…
8cca634 fixes
0a25c19 rename
0b26877 fix
e7052d5 do not filter Nones
df5551f dupe
6745b3b update
ebae970 lint
5763621 revert
46d5b18 Merge branch 'main' into control-refactor-2
d16d5fe rm
f6c6234 fix
Viewing changes from commit 039863143c23b5a7be881d0fcfa2b307e763287b ("update").
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>

New file:
@@ -0,0 +1,95 @@
import dataclasses
from typing import List, Tuple, Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_input import GPUModelInputWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        pass

    @staticmethod
    def get_impl_cls():
        pass

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        pass

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass


def test_gpu_model_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = GPUModelInputWithSamplingMetadata.new(
        num_seq_groups=10,
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, GPUModelInputWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = GPUModelInputWithSamplingMetadata.new(
        attn_backend=attn_backend, **tensor_dict)
    assert isinstance(received_model_input, GPUModelInputWithSamplingMetadata)

    # Broadcast should not contain empty values.
    for field in dataclasses.fields(model_input):
        if getattr(model_input, field.name) is None:
            assert field.name not in tensor_dict
    # Broadcast should contain all non-empty fields defined by the developer
    # for this input type.
    for field in GPUModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS:
        if getattr(model_input, field) is not None:
            assert field in tensor_dict

    # Check that received copy has correct values.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None
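For orientation, here is a minimal sketch of how the round trip exercised by this test is meant to run across processes. It leans on the broadcast_tensor_dict op that the ModelInput docstring below refers to; the driver_broadcast/worker_receive helpers and the exact import path are assumptions for illustration, not code from this PR.

from vllm.distributed import broadcast_tensor_dict
from vllm.model_input import GPUModelInputWithSamplingMetadata

def driver_broadcast(model_input: GPUModelInputWithSamplingMetadata) -> None:
    # Driver flattens the input into tensors/primitives and broadcasts it.
    broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)

def worker_receive(attn_backend) -> GPUModelInputWithSamplingMetadata:
    # Workers receive the flat dict and rebuild a typed ModelInput;
    # _get_init_kwargs folds the attention fields back into AttentionMetadata.
    tensor_dict = broadcast_tensor_dict(src=0)
    return GPUModelInputWithSamplingMetadata.new(
        attn_backend=attn_backend, **tensor_dict)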
New file:
@@ -0,0 +1,276 @@
"""Worker-local model inputs. These define the inputs to different model | ||
runners.""" | ||
import dataclasses | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union | ||
|
||
import torch | ||
|
||
from vllm.lora.request import LoRARequest | ||
|
||
if TYPE_CHECKING: | ||
from vllm.attention import AttentionMetadata | ||
from vllm.attention.backends.abstract import AttentionBackend | ||
from vllm.lora.layers import LoRAMapping | ||
from vllm.model_executor import SamplingMetadata | ||
from vllm.model_executor.pooling_metadata import PoolingMetadata | ||
|
||
|
||
def _init_attn_metadata_from_kwargs( | ||
attn_backend: Optional["AttentionBackend"] = None, | ||
attn_metadata: Optional["AttentionMetadata"] = None, | ||
**kwargs) -> Dict[str, Any]: | ||
if attn_metadata is None and attn_backend is not None: | ||
# Extract the fields used to create AttentionMetadata. | ||
valid_attn_kwargs = {} | ||
for field in dataclasses.fields(attn_backend.get_metadata_cls()): | ||
val = kwargs.pop(field.name, None) | ||
if val is not None: | ||
valid_attn_kwargs[field.name] = val | ||
|
||
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) | ||
if attn_metadata is not None: | ||
kwargs["attn_metadata"] = attn_metadata | ||
return kwargs | ||
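
# Illustrative note (added commentary, not from the original diff): after a
# broadcast, the worker-side kwargs arrive flattened, e.g.
#     {"num_prefills": 1, "slot_mapping": <tensor>, "input_tokens": <tensor>}
# The helper above pops the AttentionMetadata fields out of that flat dict
# and folds them back into a single object via attn_backend.make_metadata(),
# leaving the remaining kwargs (like "input_tokens") untouched.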

def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Union[int, torch.Tensor]],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_sampling_metadata_from_kwargs(  # type: ignore
        selected_token_indices: Optional[torch.Tensor] = None,
        sampling_metadata: Optional["SamplingMetadata"] = None,
        **kwargs) -> Dict[str, Any]:
    if sampling_metadata is None and selected_token_indices is not None:
        from vllm.model_executor import SamplingMetadata

        # An empty SamplingMetadata to signal that the worker should skip
        # sampling.
        sampling_metadata = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    if sampling_metadata is not None:
        kwargs["sampling_metadata"] = sampling_metadata
    return kwargs

def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Union[int, torch.Tensor]],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)


@dataclasses.dataclass(frozen=True)
class ModelInput:
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different
    methods of converting from the global ExecuteModelRequest produced by
    the LLM engine to the worker-local ModelInput objects.

    Model runners should inherit from this class and add their required
    fields. For distributed executors, any fields that should be sent
    during a broadcast op should also be added to BROADCASTABLE_FIELDS.
    During execution, these fields will be extracted from the source copy
    and broadcast to all workers using broadcast_tensor_dict.

    Some fields may have values that cannot be broadcast with this method
    because they require some special serialization/deserialization, e.g.,
    a Python class like SamplingMetadata. For these fields, override
    as_broadcastable_tensor_dict to return the custom serialized values
    and override _get_init_kwargs to perform the custom deserialization
    (see GPUModelInput for an example).
    """
    # Fields to broadcast to all workers from the driver. The value must be
    # broadcastable using broadcast_tensor_dict (i.e. either a tensor, or a
    # Python primitive like int). During the broadcast, the listed fields
    # will be extracted from the source copy and then passed to `new()` to
    # create a copy on the destination(s).
    BROADCASTABLE_FIELDS: Tuple[str, ...] = ()

    @classmethod
    def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]:
        """
        Helper method to extract all dataclass fields from the given kwargs.
        Override for fields that require some custom deserialization.
        """
        return kwargs

    @classmethod
    def new(cls,
            clone: Optional["ModelInput"] = None,
            **kwargs) -> "ModelInput":
        """
        Create a new instance of this class. Copy fields from `clone` if
        provided. Populate the new instance with the given kwargs.
        """
        clone_kwargs = {}
        if clone is not None:
            for field in dataclasses.fields(clone):
                val = getattr(clone, field.name)
                if val is not None:
                    clone_kwargs[field.name] = val
            clone_kwargs = cls._get_init_kwargs(**clone_kwargs)

        kwargs = cls._get_init_kwargs(**kwargs)
        return cls(**clone_kwargs, **kwargs)

    def replace(self, **kwargs) -> "ModelInput":
        """
        Replace current fields with fields in kwargs.
        """
        valid_kwargs = self.__class__._get_init_kwargs(**kwargs)
        return dataclasses.replace(self, **valid_kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom serialization.
        """
        tensor_dict: Dict[str, Union[int, torch.Tensor]] = {}
        for field in self.BROADCASTABLE_FIELDS:
            val = getattr(self, field, None)
            if val is not None:
                tensor_dict[field] = val

        return tensor_dict
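
# Illustrative usage of the base-class API (added commentary, not from the
# original diff); the field names below match GPUModelInput, defined later
# in this file:
#     src = GPUModelInput.new(num_seq_groups=2)
#     dst = GPUModelInput.new(clone=src, input_tokens=torch.zeros(8))
#     dst2 = dst.replace(num_seq_groups=4)  # frozen dataclass, so a new copy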

@dataclasses.dataclass(frozen=True)
class CPUModelInput(ModelInput):
    """
    Used by the CPUModelRunner.
    """
    num_seq_groups: Optional[int] = None
    blocks_to_copy: Optional[torch.Tensor] = None

    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "num_seq_groups",
        "blocks_to_copy",
        "input_tokens",
        "input_positions",
        "multi_modal_kwargs",
    )

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_attn_metadata_from_kwargs(**kwargs)
        kwargs = _init_sampling_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

@dataclasses.dataclass(frozen=True)
class GPUModelInput(ModelInput):
    """
    This base class contains metadata needed for the base model forward
    pass but not metadata for possible additional steps, e.g., sampling.
    Model runners that run additional steps should subclass this class to
    add the additional fields.
    """
    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None

    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    attn_metadata: Optional["AttentionMetadata"] = None

    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "num_seq_groups",
        "blocks_to_swap_in",
        "blocks_to_swap_out",
        "blocks_to_copy",
        "input_tokens",
        "input_positions",
        "lora_requests",
        "lora_mapping",
        "multi_modal_kwargs",
    )

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_attn_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

@dataclasses.dataclass(frozen=True)
class GPUModelInputWithPoolingMetadata(GPUModelInput):
    """
    Used by the EmbeddingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


@dataclasses.dataclass(frozen=True)
class GPUModelInputWithSamplingMetadata(GPUModelInput):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_sampling_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict


@dataclasses.dataclass(frozen=True)
class ModelInputForNeuron(ModelInput):
    """
    Used by the NeuronModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
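
Per the ModelInput docstring, new model runners extend the base class with their own fields. A minimal sketch of such a subclass for a hypothetical backend follows; the MyBackendModelInput name and its field choices are illustrative, not part of this PR.

# Hypothetical example (not from this PR): a new backend defines its own
# ModelInput subclass. Flat fields go in BROADCASTABLE_FIELDS; nested
# objects like AttentionMetadata are rebuilt through _get_init_kwargs.
@dataclasses.dataclass(frozen=True)
class MyBackendModelInput(ModelInput):
    num_seq_groups: Optional[int] = None
    input_tokens: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None

    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "num_seq_groups",
        "input_tokens",
    )

    @classmethod
    def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]:
        # Rebuild AttentionMetadata from the flattened broadcast kwargs.
        kwargs = _init_attn_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict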
Review comment: I think this file should live in worker/?

Reply: Good idea, will move.