from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                              detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus


class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """

    def __init__(
        self,
@@ -27,7 +54,7 @@ def __init__(
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
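
The class docstring above describes the offline path that the `LLM` wrapper builds on; the sketch below drives an `LLMServer` directly in that style. The `LLMServer` import path, the `ServerArgs` and `SamplingParams` arguments, and the `request_id` attribute on `RequestOutput` are assumptions for illustration, not taken from this diff.

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer  # assumed module path

# Assumed ServerArgs field; see ServerArgs for the real argument list.
server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))

# Queue a few requests; the scheduler decides when each one actually runs.
prompts = ["Hello, my name is", "The capital of France is"]
for i, prompt in enumerate(prompts):
    server.add_request(
        request_id=str(i),
        prompt=prompt,
        sampling_params=SamplingParams(temperature=0.8, max_tokens=32),  # assumed fields
    )

# Each step() performs one decoding iteration and returns the requests that
# made progress; keep the latest RequestOutput per request until done.
latest = {}
while server.has_unfinished_requests():
    for output in server.step():
        latest[output.request_id] = output  # request_id attribute assumed

for request_id, output in sorted(latest.items()):
    print(request_id, output)
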
@@ -83,6 +110,7 @@ def _verify_args(self) -> None:
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
@@ -108,6 +136,7 @@ def _init_cache(self) -> None:

    @classmethod
    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
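
Based on the imports and the constructor signature earlier in this diff, the classmethod appears to unpack the configs, set up the (possibly Ray-backed) cluster, and construct the server. The sketch below spells that flow out; the tuple layout beyond index 2, the return values of `initialize_cluster`, and the `LLMServer` import path are assumptions.

from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster

def build_server(server_args: ServerArgs):
    """Sketch of the assumed flow inside LLMServer.from_server_args."""
    # Assumed tuple layout: (model_config, cache_config, parallel_config, scheduler_config).
    server_configs = server_args.create_server_configs()
    parallel_config = server_configs[2]
    # Assumed: returns the torch.distributed init method plus the per-stage
    # (rank, node_resource, device) lists described in the class docstring.
    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
    from cacheflow.server.llm_server import LLMServer  # assumed module path
    return LLMServer(*server_configs, distributed_init_method, stage_devices,
                     log_stats=True)
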
@@ -126,6 +155,22 @@ def add_request(
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
+        """Adds a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompt to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
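
The docstring above notes that `prompt` may be omitted when token IDs are supplied, and that `arrival_time` defaults to the submission time. A short sketch of both options follows; `server` is an `LLMServer` instance as in the earlier sketch, and the tokenizer attribute name and `SamplingParams` arguments are assumptions.

import time
from cacheflow.sampling_params import SamplingParams

params = SamplingParams(temperature=0.0, max_tokens=16)  # assumed fields

# Option 1: pass the raw prompt and let the server tokenize it.
server.add_request(request_id="0", prompt="Explain KV caching.", sampling_params=params)

# Option 2: pre-tokenize (the `server.tokenizer` attribute name is an assumption)
# and record the true arrival time, e.g. when the HTTP request was received.
token_ids = server.tokenizer.encode("Explain KV caching.")
server.add_request(
    request_id="1",
    prompt=None,
    sampling_params=params,
    prompt_token_ids=token_ids,
    arrival_time=time.time(),
)
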
@@ -148,15 +193,30 @@ def add_request(
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copied. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
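
Since `step()` does exactly one scheduling-plus-execution iteration and returns only the newly generated results, an online server is essentially a loop that keeps calling it and forwards each `RequestOutput` to whoever is waiting (roughly what `AsyncLLMServer` automates). A minimal synchronous sketch, with the callback and idle handling as assumptions:

import time

def serve(server, on_update, poll_interval=0.01):
    """Drive the server and forward per-iteration results (sketch)."""
    while True:
        if not server.has_unfinished_requests():
            time.sleep(poll_interval)  # nothing scheduled; wait for new requests
            continue
        for output in server.step():
            on_update(output)  # e.g. stream partial text back to the client
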
@@ -188,7 +248,7 @@ def step(self) -> List[RequestOutput]:
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stops the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
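
The checks that typically decide when a running sequence is finished are stop strings, a token budget, and EOS. The helper below illustrates that logic; the `SamplingParams` field names (`stop`, `max_tokens`, `ignore_eos`) are assumptions about this codebase rather than facts from the diff.

def is_finished(output_text, num_output_tokens, last_token_is_eos, sampling_params):
    """Illustrative stopping checks for one running sequence (sketch)."""
    # Stop strings requested by the client.
    for stop_str in getattr(sampling_params, "stop", []) or []:
        if output_text.endswith(stop_str):
            return True
    # Token budget for this request.
    if num_output_tokens >= sampling_params.max_tokens:
        return True
    # End-of-sequence token, unless the client asked to ignore it.
    if last_token_is_eos and not getattr(sampling_params, "ignore_eos", False):
        return True
    return False
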
@@ -238,6 +298,7 @@ def _run_workers(
        *args,
        **kwargs,
    ) -> Any:
+        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
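
The loop being started here is a simple fan-out: look the method up by name on each worker, invoke it, and gather the results. Whether the workers are local objects or Ray actors (in which case the call would go through `.remote` and `ray.get`) is an assumption in the sketch below.

def run_workers(workers, method, *args, use_ray=False, **kwargs):
    """Call `method` on every worker and collect the outputs (sketch)."""
    all_outputs = []
    for worker in workers:
        executor = getattr(worker, method)
        if use_ray:
            executor = executor.remote  # assumed Ray-actor path
        all_outputs.append(executor(*args, **kwargs))
    if use_ray:
        import ray  # resolve the futures on the distributed path
        all_outputs = ray.get(all_outputs)
    return all_outputs
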