from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                              detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus


class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """

    def __init__(
        self,
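To make the flow described in the class docstring concrete, here is a minimal sketch of how a caller might drive `LLMServer` directly for offline-style generation. It only uses methods that appear in this diff (`from_server_args`, `add_request`, `step`, `has_unfinished_requests`); the `ServerArgs` and `SamplingParams` keyword arguments and the module path for `LLMServer` are assumptions, not something this commit defines.

# Sketch only: ServerArgs/SamplingParams keyword names below are assumptions.
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer  # module path assumed

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("request-0", "Hello, my name is", SamplingParams(temperature=0.8))
while server.has_unfinished_requests():
    outputs = server.step()  # newly generated RequestOutputs for this iteration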
@@ -27,7 +54,7 @@ def __init__(
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
@@ -83,6 +110,7 @@ def _verify_args(self) -> None:
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
@@ -108,6 +136,7 @@ def _init_cache(self) -> None:

    @classmethod
    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ def add_request(
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
+        """Add a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompt to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
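As a usage illustration for this hunk, a request can be queued either with a raw prompt string or with pre-tokenized IDs. This continues the earlier sketch (a `server` already exists); the `SamplingParams` keyword names and the token IDs are assumptions.

# Sketch: two ways to enqueue a request (SamplingParams arguments are assumed).
params = SamplingParams(temperature=0.0, max_tokens=32)
server.add_request("req-1", "What is the capital of France?", params)
server.add_request("req-2", prompt=None, sampling_params=params,
                   prompt_token_ids=[2, 653, 1139])  # already-tokenized prompt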
@@ -148,15 +193,30 @@ def add_request(
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copied. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
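Because `step()` returns only the results produced in that iteration, it pairs naturally with `abort_request` from the same hunk for a streaming-style loop that inspects outputs as they appear and can cancel a request early. A sketch continuing the examples above; the `request_id` field on `RequestOutput` is assumed, as this diff does not show that class.

import time

# Streaming-style sketch: handle per-iteration results and cancel a slow request.
# Assumes RequestOutput carries a `request_id` field (not shown in this diff).
deadline = time.time() + 30.0
while server.has_unfinished_requests():
    for output in server.step():
        ...  # stream `output` back to whoever issued output.request_id
    if time.time() > deadline:
        server.abort_request("req-2")  # give up on a request that is taking too long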
@@ -188,7 +248,7 @@ def step(self) -> List[RequestOutput]:
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stop the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ def _run_workers(
        *args,
        **kwargs,
    ) -> Any:
+        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)