
[Core] Refactor Worker and ModelRunner to consolidate control plane communication #5408

Merged: 64 commits, Jun 26, 2024
Changes from 1 commit
Commits (64)
725b0b2  tmp (stephanie-wang, Jun 11, 2024)
b74eb10  fix (stephanie-wang, Jun 11, 2024)
38b0ddf  ray and mp backends work (stephanie-wang, Jun 11, 2024)
0d11e92  embedding model runner works (stephanie-wang, Jun 11, 2024)
2cdc218  GPU executor works (stephanie-wang, Jun 11, 2024)
c728512  remove comment (stephanie-wang, Jun 11, 2024)
2bf752b  use the right ModelInput class (stephanie-wang, Jun 11, 2024)
f35a23f  CPU worker (stephanie-wang, Jun 11, 2024)
11133fe  remove commented (stephanie-wang, Jun 11, 2024)
174bdb1  lint (stephanie-wang, Jun 11, 2024)
c0e98ca  Worker.execute_model vs execute_model_local (stephanie-wang, Jun 11, 2024)
dccec95  lint (stephanie-wang, Jun 11, 2024)
dad94ba  neuron model runner (stephanie-wang, Jun 11, 2024)
fca606e  disallow distributed comms (stephanie-wang, Jun 11, 2024)
6ed3c2a  disable communication (stephanie-wang, Jun 12, 2024)
1803e33  Update worker.py (stephanie-wang, Jun 12, 2024)
dde799e  fix tests (stephanie-wang, Jun 12, 2024)
0398631  update (stephanie-wang, Jun 12, 2024)
5c41cc6  Merge branch 'control-refactor-2' of github.com:stephanie-wang/vllm i… (stephanie-wang, Jun 12, 2024)
72f0383  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 12, 2024)
eef6623  merge (stephanie-wang, Jun 12, 2024)
3004ceb  update (Jun 13, 2024)
8d852e9  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 13, 2024)
9380ed8  fix (stephanie-wang, Jun 13, 2024)
3c4de6d  fix (stephanie-wang, Jun 13, 2024)
5053f30  fix (stephanie-wang, Jun 13, 2024)
db38556  x (stephanie-wang, Jun 14, 2024)
456185d  rm (stephanie-wang, Jun 14, 2024)
e860652  lint (stephanie-wang, Jun 14, 2024)
3d4f242  add missing (stephanie-wang, Jun 14, 2024)
11304cb  revert (stephanie-wang, Jun 14, 2024)
99f532e  refactor (stephanie-wang, Jun 15, 2024)
797a7cf  doc (stephanie-wang, Jun 15, 2024)
6ad2513  revert spec decode and doc (stephanie-wang, Jun 15, 2024)
97ec303  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 15, 2024)
e10bace  typing (stephanie-wang, Jun 15, 2024)
ce087ae  fix (stephanie-wang, Jun 18, 2024)
f851b00  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 18, 2024)
0e2acc4  XPU worker and rename (stephanie-wang, Jun 18, 2024)
d318ec8  lint (stephanie-wang, Jun 18, 2024)
b48f783  lint (stephanie-wang, Jun 18, 2024)
c93afc1  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 18, 2024)
30ac400  fix (stephanie-wang, Jun 18, 2024)
01688d5  x (stephanie-wang, Jun 18, 2024)
7dbb646  fix (stephanie-wang, Jun 18, 2024)
d2e4c41  fix (stephanie-wang, Jun 19, 2024)
3e46253  lint (stephanie-wang, Jun 19, 2024)
0a2890a  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 21, 2024)
36dfce1  merge (stephanie-wang, Jun 21, 2024)
ea5412e  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 21, 2024)
dc2f103  x (stephanie-wang, Jun 21, 2024)
fbf074d  x (stephanie-wang, Jun 22, 2024)
660a8d5  rename ModelInput -> ModelInputBase, override as_broadcastable_tensor… (stephanie-wang, Jun 23, 2024)
8cca634  fixes (stephanie-wang, Jun 23, 2024)
0a25c19  rename (stephanie-wang, Jun 23, 2024)
0b26877  fix (stephanie-wang, Jun 24, 2024)
e7052d5  do not filter Nones (stephanie-wang, Jun 24, 2024)
df5551f  dupe (stephanie-wang, Jun 24, 2024)
6745b3b  update (stephanie-wang, Jun 25, 2024)
ebae970  lint (stephanie-wang, Jun 25, 2024)
5763621  revert (stephanie-wang, Jun 25, 2024)
46d5b18  Merge branch 'main' into control-refactor-2 (stephanie-wang, Jun 25, 2024)
d16d5fe  rm (stephanie-wang, Jun 25, 2024)
f6c6234  fix (stephanie-wang, Jun 25, 2024)
update
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
stephanie-wang committed Jun 12, 2024
commit 039863143c23b5a7be881d0fcfa2b307e763287b
95 changes: 95 additions & 0 deletions tests/worker/test_model_input.py
@@ -0,0 +1,95 @@
import dataclasses
from typing import List, Tuple, Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_input import GPUModelInputWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        pass

    @staticmethod
    def get_impl_cls():
        pass

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        pass

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass


def test_gpu_model_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = GPUModelInputWithSamplingMetadata.new(
        num_seq_groups=10,
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, GPUModelInputWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = GPUModelInputWithSamplingMetadata.new(
        attn_backend=attn_backend, **tensor_dict)
    assert isinstance(received_model_input, GPUModelInputWithSamplingMetadata)

    # Broadcast should not contain empty values.
    for field in dataclasses.fields(model_input):
        if getattr(model_input, field.name) is None:
            assert field.name not in tensor_dict
    # Broadcast should contain all non-empty fields defined by the developer
    # for this input type.
    for field in GPUModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS:
        if getattr(model_input, field) is not None:
            assert field in tensor_dict

    # Check that received copy has correct values.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None
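
Note on what this round trip is for: in the distributed executors, the driver calls as_broadcastable_tensor_dict() and sends the result to the other workers, which rebuild a local copy via new(). A minimal sketch of that flow, assuming vLLM's existing broadcast_tensor_dict collective; the driver/worker wiring below is illustrative, not code from this PR:

# Hedged sketch of the intended driver/worker flow; the import path and
# the surrounding control flow are assumptions, not part of this diff.
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.model_input import GPUModelInputWithSamplingMetadata


def driver_broadcast(model_input: GPUModelInputWithSamplingMetadata) -> None:
    # Driver side: serialize only the broadcastable fields and send them.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    broadcast_tensor_dict(tensor_dict, src=0)


def worker_receive(attn_backend) -> GPUModelInputWithSamplingMetadata:
    # Worker side: receive the payload and rebuild a local ModelInput,
    # reconstructing AttentionMetadata via the worker's own attn backend.
    tensor_dict = broadcast_tensor_dict(src=0)
    return GPUModelInputWithSamplingMetadata.new(attn_backend=attn_backend,
                                                 **tensor_dict)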
276 changes: 276 additions & 0 deletions vllm/model_input.py
@@ -0,0 +1,276 @@
"""Worker-local model inputs. These define the inputs to different model
runners."""

Review thread on this file:
Collaborator: I think this file should live in worker/?
Contributor Author: Good idea, will move.
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch

from vllm.lora.request import LoRARequest

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.lora.layers import LoRAMapping
    from vllm.model_executor import SamplingMetadata
    from vllm.model_executor.pooling_metadata import PoolingMetadata


def _init_attn_metadata_from_kwargs(
        attn_backend: Optional["AttentionBackend"] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
        **kwargs) -> Dict[str, Any]:
    if attn_metadata is None and attn_backend is not None:
        # Extract the fields used to create AttentionMetadata.
        valid_attn_kwargs = {}
        for field in dataclasses.fields(attn_backend.get_metadata_cls()):
            val = kwargs.pop(field.name, None)
            if val is not None:
                valid_attn_kwargs[field.name] = val

        attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    if attn_metadata is not None:
        kwargs["attn_metadata"] = attn_metadata
    return kwargs


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Union[int, torch.Tensor]],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_sampling_metadata_from_kwargs(  # type: ignore
        selected_token_indices: Optional[torch.Tensor] = None,
        sampling_metadata: Optional["SamplingMetadata"] = None,
        **kwargs) -> Dict[str, Any]:
    if sampling_metadata is None and selected_token_indices is not None:
        from vllm.model_executor import SamplingMetadata

        # An empty SamplingMetadata to signal that the worker should skip
        # sampling.
        sampling_metadata = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    if sampling_metadata is not None:
        kwargs["sampling_metadata"] = sampling_metadata
    return kwargs


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Union[int, torch.Tensor]],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelInput:
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different
    methods of converting from the global ExecuteModelRequest produced by
    the LLM engine to the worker-local ModelInput objects.

    Model runners should inherit from this class and add their required
    fields. For distributed executors, any fields that should be sent
    during a broadcast op should also be added to BROADCASTABLE_FIELDS.
    During execution, these fields will be extracted from the source copy
    and broadcast to all workers using broadcast_tensor_dict.

    Some fields may have values that cannot be broadcast with this method
    because they require some special serialization/deserialization, e.g.,
    a Python class like SamplingMetadata. For these fields, override
    as_broadcastable_tensor_dict to return the custom serialized values
    and override _get_init_kwargs to perform the custom deserialization
    (see GPUModelInput for an example).
    """
    # Fields to broadcast to all workers from the driver. The value must be
    # broadcastable using broadcast_tensor_dict (i.e. either a tensor or a
    # Python primitive like int). During the broadcast, the listed fields
    # will be extracted from the source copy and then passed to `new()` to
    # create a copy on the destination(s).
    BROADCASTABLE_FIELDS: Tuple[str, ...] = ()

    @classmethod
    def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]:
        """
        Helper method to extract all dataclass fields from the given kwargs.
        Override for fields that require some custom deserialization.
        """
        return kwargs

    @classmethod
    def new(cls,
            clone: Optional["ModelInput"] = None,
            **kwargs) -> "ModelInput":
        """
        Create a new instance of this class. Copy fields from `clone` if
        provided. Populate the new instance with the given kwargs.
        """
        clone_kwargs = {}
        if clone is not None:
            for field in dataclasses.fields(clone):
                val = getattr(clone, field.name)
                if val is not None:
                    clone_kwargs[field.name] = val
            clone_kwargs = cls._get_init_kwargs(**clone_kwargs)

        kwargs = cls._get_init_kwargs(**kwargs)
        return cls(**clone_kwargs, **kwargs)

    def replace(self, **kwargs) -> "ModelInput":
        """
        Replace current fields with fields in kwargs.
        """
        valid_kwargs = self.__class__._get_init_kwargs(**kwargs)
        return dataclasses.replace(self, **valid_kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom serialization.
        """
        tensor_dict: Dict[str, Union[int, torch.Tensor]] = {}
        for field in self.BROADCASTABLE_FIELDS:
            val = getattr(self, field, None)
            if val is not None:
                tensor_dict[field] = val

        return tensor_dict


@dataclasses.dataclass(frozen=True)
class CPUModelInput(ModelInput):
    """
    Used by the CPUModelRunner.
    """
    num_seq_groups: Optional[int] = None
    blocks_to_copy: Optional[torch.Tensor] = None

    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "num_seq_groups",
        "blocks_to_copy",
        "input_tokens",
        "input_positions",
        "multi_modal_kwargs",
    )

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_attn_metadata_from_kwargs(**kwargs)
        kwargs = _init_sampling_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict


@dataclasses.dataclass(frozen=True)
class GPUModelInput(ModelInput):
    """
    This base class contains metadata needed for the base model forward
    pass but not metadata for possible additional steps, e.g., sampling.
    Model runners that run additional steps should subclass this class to
    add the additional fields.
    """
    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None

    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    attn_metadata: Optional["AttentionMetadata"] = None

    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "num_seq_groups",
        "blocks_to_swap_in",
        "blocks_to_swap_out",
        "blocks_to_copy",
        "input_tokens",
        "input_positions",
        "lora_requests",
        "lora_mapping",
        "multi_modal_kwargs",
    )

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_attn_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict


@dataclasses.dataclass(frozen=True)
class GPUModelInputWithPoolingMetadata(GPUModelInput):
    """
    Used by the EmbeddingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


@dataclasses.dataclass(frozen=True)
class GPUModelInputWithSamplingMetadata(GPUModelInput):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    @classmethod
    def _get_init_kwargs(  # type: ignore
            cls, **kwargs) -> Dict[str, Any]:
        kwargs = _init_sampling_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict


@dataclasses.dataclass(frozen=True)
class ModelInputForNeuron(ModelInput):
    """
    Used by the NeuronModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
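
To make the extension contract concrete, here is a minimal hypothetical subclass that follows the pattern described in the ModelInput docstring. ToyModelInput and its fields are invented for illustration, and the import assumes the file stays at vllm/model_input.py (the review thread above suggests it may move to worker/):

import dataclasses
from typing import Optional, Tuple

import torch

from vllm.model_input import ModelInput


@dataclasses.dataclass(frozen=True)
class ToyModelInput(ModelInput):
    # Hypothetical fields, for illustration only.
    input_tokens: Optional[torch.Tensor] = None
    num_seq_groups: Optional[int] = None
    # A worker-local field that should never go over the wire.
    scratch_buffer: Optional[torch.Tensor] = None

    # Only these fields are extracted by as_broadcastable_tensor_dict().
    BROADCASTABLE_FIELDS: Tuple[str, ...] = (
        "input_tokens",
        "num_seq_groups",
    )


# Driver side: extract the broadcastable fields.
src = ToyModelInput.new(input_tokens=torch.zeros(4),
                        num_seq_groups=2,
                        scratch_buffer=torch.empty(16))
tensor_dict = src.as_broadcastable_tensor_dict()
assert "scratch_buffer" not in tensor_dict

# Worker side: rebuild a local copy from the broadcast payload.
dst = ToyModelInput.new(**tensor_dict)
assert dst.num_seq_groups == 2 and dst.scratch_buffer is None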