import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
                    TypeVar)

import torch

from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

T = TypeVar('T', bound="ModelRunnerInputBase")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_attn_kwargs[field.name] = val

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict
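

# Illustrative sketch (not part of the original module): how the two
# attention-metadata helpers above are intended to pair up. On the driver,
# the metadata is flattened into a plain dict for broadcast; on a receiving
# worker, it is rebuilt for whichever concrete AttentionBackend is in use.
# This function is a hypothetical helper added here only for illustration.
def _example_attn_metadata_roundtrip(
        attn_backend: "AttentionBackend",
        attn_metadata: "AttentionMetadata") -> Dict[str, Any]:
    tensor_dict: Dict[str, Any] = {}
    # Driver side: flatten broadcastable fields into the dict.
    _add_attn_metadata_broadcastable_dict(tensor_dict, attn_metadata)
    # (tensor_dict would be broadcast to the other workers here.)
    # Worker side: reconstruct an AttentionMetadata under "attn_metadata".
    return _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)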


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.
    """
    from vllm.model_executor import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    # A SamplingMetadata with no seq_groups signals that the worker should
    # skip sampling.
    if selected_token_indices is not None:
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)
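

# Illustrative sketch (not part of the original module): the matching round
# trip for sampling metadata. Only selected_token_indices survives the
# broadcast; the rebuilt SamplingMetadata carries no seq_groups, which (per
# the comment above) tells the receiving worker to skip sampling. This
# function is a hypothetical helper added here only for illustration.
def _example_sampling_metadata_roundtrip(
        sampling_metadata: "SamplingMetadata") -> Dict[str, Any]:
    tensor_dict: Dict[str, Any] = {}
    # Driver side: record the only broadcastable sampling field.
    _add_sampling_metadata_broadcastable_dict(tensor_dict, sampling_metadata)
    # (tensor_dict would be broadcast to the other workers here.)
    # Worker side: reconstruct a minimal SamplingMetadata.
    return _init_sampling_metadata_from_tensor_dict(tensor_dict)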


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different
    methods of converting from the global ExecuteModelRequest produced by
    the LLM engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify
    how to serialize/deserialize a ModelInput for broadcast between
    workers.
    """

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        ModelRunnerInputBase.
        """
        raise NotImplementedError
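

# Illustrative sketch (not part of the original module): a minimal, entirely
# hypothetical ModelRunnerInputBase subclass showing one way to satisfy the
# serialize/deserialize contract using the helpers above. The field names
# (input_tokens, input_positions) are illustrative, not part of the
# interface.
@dataclasses.dataclass(frozen=True)
class _ExampleModelRunnerInput(ModelRunnerInputBase):
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Flatten every field into a dict of broadcastable values.
        tensor_dict: Dict[str, Any] = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "_ExampleModelRunnerInput":
        # Rebuild the attention metadata first, then the dataclass itself.
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)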


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type
    of model. Model execution may communicate data with model runners in
    other processes, but it should not include control plane metadata
    communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError
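

# Illustrative sketch (not part of the original module): how a hypothetical
# worker might tie the two interfaces together. `runner` is any concrete
# ModelRunnerBase; `receive_tensor_dict` is a stand-in for the worker's
# actual control-plane receive hook, and the broadcast itself is elided.
def _example_worker_step(
    runner: "ModelRunnerBase[T]",
    is_driver: bool,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[torch.Tensor],
    receive_tensor_dict,  # hypothetical zero-arg callable returning a dict
) -> Optional[List[SamplerOutput]]:
    if is_driver:
        # The driver builds its input directly from the execution request;
        # prepare_model_input may move data to the local device but must not
        # communicate with other workers.
        model_input = runner.prepare_model_input(seq_group_metadata_list)
        # (model_input.as_broadcastable_tensor_dict() would be broadcast to
        # the other workers here.)
    else:
        # Non-driver workers rebuild the same input from the broadcast dict.
        model_input = runner.make_model_input_from_broadcasted_tensor_dict(
            receive_tensor_dict())
    return runner.execute_model(model_input, kv_caches, None)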