Commit 0d75dc0

Add docstrings for LLMServer and related classes and examples (vllm-project#142)
1 parent 742af99 commit 0d75dc0

10 files changed (+216 -22 lines)

cacheflow/config.py

Lines changed: 39 additions & 3 deletions
@@ -12,6 +12,20 @@


 class ModelConfig:
+    """Configuration for the model.
+
+    Args:
+        model: Name or path of the huggingface model to use.
+        download_dir: Directory to download and load the weights, default to the
+            default cache directory of huggingface.
+        use_np_weights: Save a numpy copy of model weights for faster loading.
+            This can increase the disk usage by up to 2x.
+        use_dummy_weights: Use dummy values for model weights (for profiling).
+        dtype: Data type for model weights and activations. The "auto" option
+            will use FP16 precision for FP32 and FP16 models, and BF16 precision
+            for BF16 models.
+        seed: Random seed for reproducibility.
+    """

     def __init__(
         self,
@@ -68,7 +82,14 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:


 class CacheConfig:
-
+    """Configuration for the KV cache.
+
+    Args:
+        block_size: Size of a cache block in number of tokens.
+        gpu_memory_utilization: Fraction of GPU memory to use for the
+            CacheFlow execution.
+        swap_space: Size of the CPU swap space per GPU (in GiB).
+    """
     def __init__(
         self,
         block_size: int,
@@ -111,7 +132,15 @@ def verify_with_parallel_config(


 class ParallelConfig:
-
+    """Configuration for the distributed execution.
+
+    Args:
+        pipeline_parallel_size: Number of pipeline parallel groups.
+        tensor_parallel_size: Number of tensor parallel groups.
+        worker_use_ray: Whether to use Ray for model workers. Will be set to
+            True if either pipeline_parallel_size or tensor_parallel_size is
+            greater than 1.
+    """
     def __init__(
         self,
         pipeline_parallel_size: int,
@@ -134,7 +163,14 @@ def _verify_args(self) -> None:


 class SchedulerConfig:
-
+    """Scheduler configuration.
+
+    Args:
+        max_num_batched_tokens: Maximum number of tokens to be processed in
+            a single iteration.
+        max_num_seqs: Maximum number of sequences to be processed in a single
+            iteration.
+    """
     def __init__(
         self,
         max_num_batched_tokens: int,
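
For illustration, the four config classes documented above can be instantiated directly with the arguments their new docstrings describe. The sketch below is not part of the commit: the concrete values are placeholders, and the keyword-argument usage and defaults are assumed from the docstrings rather than confirmed by this diff.

    from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                                  SchedulerConfig)

    # Model weights, dtype, and seed (see the ModelConfig docstring above).
    model_config = ModelConfig(
        model="facebook/opt-125m",  # placeholder HF model name/path
        download_dir=None,          # None -> default huggingface cache directory
        use_np_weights=False,
        use_dummy_weights=False,
        dtype="auto",               # FP16 for FP32/FP16 models, BF16 for BF16 models
        seed=0,
    )

    # KV cache sizing: block granularity, GPU memory fraction, CPU swap space.
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=4,               # GiB per GPU
    )

    # Distributed execution; worker_use_ray is forced to True when either
    # parallel size is greater than 1.
    parallel_config = ParallelConfig(
        pipeline_parallel_size=1,
        tensor_parallel_size=1,
        worker_use_ray=False,
    )

    # Iteration-level scheduling limits.
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=2560,
        max_num_seqs=256,
    )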

cacheflow/entrypoints/openai/openai_frontend.py

Lines changed: 12 additions & 0 deletions
@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],

 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/completions/create
+    for the API specification. This API mimics the OpenAI Completion API.
+
+    NOTE: Currently we do not support the following features:
+        - echo (since the cacheflow server does not currently support
+          getting the logprobs of prompt tokens)
+        - suffix (the language models we currently support do not support
+          suffix)
+        - logit_bias (to be supported in cacheflow server)
+    """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
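
As a usage illustration for the endpoint documented above, a client can send a standard OpenAI-style completion request. In this sketch the host, port, and model name are placeholder assumptions (a local deployment on port 8000), and the unsupported fields noted in the docstring (echo, suffix, logit_bias) are omitted.

    import requests

    # Placeholder address for a locally running OpenAI-compatible frontend.
    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",   # placeholder model name
            "prompt": "San Francisco is a",
            "max_tokens": 32,
            "temperature": 0.8,
            # echo, suffix, and logit_bias are not supported yet (see above).
        },
    )
    print(response.json()["choices"][0]["text"])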

cacheflow/entrypoints/simple_fastapi_frontend.py

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,12 @@

 @app.post("/generate")
 async def generate_stream(request: Request) -> StreamingResponse:
+    """ Stream the results of the generation request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
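
To exercise the JSON contract documented above (a "prompt" field plus any SamplingParams fields), a client can read the streaming response incrementally. The address below is a placeholder, and the extra sampling field name is an assumption to be checked against SamplingParams.

    import requests

    payload = {
        "prompt": "The capital of France is",
        # Remaining fields are forwarded to SamplingParams; "temperature" is
        # an assumed field name -- see SamplingParams for the accepted keys.
        "temperature": 0.0,
    }
    # Placeholder address for a locally running frontend; the endpoint returns
    # a StreamingResponse, so consume it chunk by chunk.
    with requests.post("http://localhost:8000/generate",
                       json=payload, stream=True) as response:
        for chunk in response.iter_lines():
            if chunk:
                print(chunk.decode("utf-8"))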

cacheflow/server/arg_utils.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@

 @dataclass
 class ServerArgs:
+    """Arguments for CacheFlow servers."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -117,6 +118,7 @@ def create_server_configs(

 @dataclass
 class AsyncServerArgs(ServerArgs):
+    """Arguments for asynchronous CacheFlow servers."""
     server_use_ray: bool = False

     @staticmethod

cacheflow/server/async_llm_server.py

Lines changed: 67 additions & 7 deletions
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
@@ -15,7 +15,25 @@


 class AsyncLLMServer:
-
+    """An asynchronous wrapper for LLMServer.
+
+    This class is used to wrap the LLMServer class to make it asynchronous. It
+    uses asyncio to create a background loop that keeps processing incoming
+    requests. The LLMServer is kicked by the generate method when there
+    are requests in the waiting queue. The generate method yields the outputs
+    from the LLMServer to the caller.
+
+    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+
+    Args:
+        worker_use_ray: Whether to use Ray for model workers. Required for
+            distributed execution. Should be the same as
+            `parallel_config.worker_use_ray`.
+        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+            async frontend will be executed in a separate process as the
+            model workers.
+        *args, *kwargs: Arguments for LLMServer.
+    """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
@@ -35,6 +53,7 @@ def __init__(self, worker_use_ray: bool, server_use_ray: bool,
         self.kicking_request_id: Optional[str] = None

     async def server_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the server to process the waiting requests."""
         self.is_server_running = True
         self.kicking_request_id = kicking_request_id
         if self.server_use_ray:
@@ -54,8 +73,31 @@ async def server_step(self, kicking_request_id: Optional[str] = None):
             self.request_outputs[request_id] = request_output
             self.request_events[request_id].set()

-    async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: str) -> RequestOutput:
+    async def generate(
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None
+    ) -> RequestOutput:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMServer and streams the outputs
+        from the LLMServer to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+
+        Yields:
+            The output `RequestOutput` objects from the LLMServer for the
+            request.
+        """
         # Preprocess the request.
         arrival_time = time.time()

@@ -66,20 +108,29 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,

         logger.info(f"Received request {request_id}: "
                     f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}.")
+                    f"sampling params: {sampling_params}, "
+                    f"prompt token ids: {prompt_token_ids}.")

         # Add the request into the cacheflow server's waiting queue.
         if self.server_use_ray:
             await self.server.add_request.remote(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
         else:
             self.server.add_request(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
         # the server to process the requests.
         while True:
+            if request_id not in self.request_events:
+                # The request has been aborted.
+                return
+
             # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step(request_id)
@@ -113,6 +164,14 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,
                 break

     async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.
+
+        Args:
+            request_id: The unique id of the request.
+        """
         if request_id not in self.request_events:
             # The request has already finished or been aborted.
             return
@@ -137,6 +196,7 @@ async def abort(self, request_id: str) -> None:

     @classmethod
     def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+        """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
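
For orientation, here is a minimal sketch of driving AsyncLLMServer directly from asyncio, using from_server_args and generate as documented above. Because generate yields RequestOutput objects, it is consumed with async for; the model name and the temperature field are placeholder assumptions, not values taken from this diff.

    import asyncio

    from cacheflow.sampling_params import SamplingParams
    from cacheflow.server.arg_utils import AsyncServerArgs
    from cacheflow.server.async_llm_server import AsyncLLMServer


    async def main() -> None:
        # Placeholder model; only `model` is required on the args dataclass.
        server = AsyncLLMServer.from_server_args(
            AsyncServerArgs(model="facebook/opt-125m"))

        # generate() adds the request to the LLMServer's waiting queue and
        # keeps kicking the server, yielding outputs as they are produced.
        async for request_output in server.generate(
                prompt="Hello, my name is",
                sampling_params=SamplingParams(temperature=0.8),
                request_id="request-0"):
            print(request_output)


    asyncio.run(main())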

cacheflow/server/llm_server.py

Lines changed: 65 additions & 4 deletions
@@ -8,7 +8,7 @@
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@


 class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """

     def __init__(
         self,
@@ -27,7 +54,7 @@ def __init__(
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -83,6 +110,7 @@ def _verify_args(self) -> None:
         self.cache_config.verify_with_parallel_config(self.parallel_config)

     def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
@@ -108,6 +136,7 @@ def _init_cache(self) -> None:

     @classmethod
     def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ def add_request(
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
+        """Add a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
@@ -148,15 +193,30 @@ def add_request(
         self.scheduler.add_seq_group(seq_group)

     def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
         self.scheduler.abort_seq_group(request_id)

     def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()

     def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

     def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
             # Nothing to do.
@@ -188,7 +248,7 @@ def step(self) -> List[RequestOutput]:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
                 seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stop the finished sequences."""
         for seq_group in seq_groups:
             sampling_params = seq_group.sampling_params
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ def _run_workers(
         *args,
         **kwargs,
     ) -> Any:
+        """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
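
The add_request/step interface documented above supports a simple synchronous driving loop. The sketch below assumes a single-GPU setup; the model name, prompts, and the temperature field are placeholders rather than values confirmed by this diff.

    from cacheflow.sampling_params import SamplingParams
    from cacheflow.server.arg_utils import ServerArgs
    from cacheflow.server.llm_server import LLMServer

    # from_server_args expands the flat ServerArgs dataclass into the
    # model/cache/parallel/scheduler configs described above.
    server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))

    # Queue requests; they wait in the scheduler's pool until step() runs.
    prompts = ["Hello, my name is", "The future of AI is"]
    for i, prompt in enumerate(prompts):
        # temperature is an assumed SamplingParams field.
        server.add_request(str(i), prompt, SamplingParams(temperature=0.8))

    # Each step() performs one decoding iteration and returns the newly
    # generated RequestOutput objects.
    while server.has_unfinished_requests():
        for request_output in server.step():
            print(request_output)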
