
Add docstrings for LLMServer and related classes and examples #142

Merged (5 commits) on Jun 7, 2023
42 changes: 39 additions & 3 deletions cacheflow/config.py
@@ -12,6 +12,20 @@


class ModelConfig:
"""Configuration for the model.

Args:
model: Name or path of the huggingface model to use.
download_dir: Directory to download and load the weights. Defaults to the
default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
This can increase the disk usage by up to 2x.
use_dummy_weights: Use dummy values for model weights (for profiling).
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
"""

def __init__(
self,
@@ -68,7 +82,14 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:


class CacheConfig:

"""Configuration for the KV cache.

Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
CacheFlow execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
"""
def __init__(
self,
block_size: int,
@@ -111,7 +132,15 @@ def verify_with_parallel_config(


class ParallelConfig:

"""Configuration for the distributed execution.

Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
"""
def __init__(
self,
pipeline_parallel_size: int,
@@ -134,7 +163,14 @@ def _verify_args(self) -> None:


class SchedulerConfig:

"""Scheduler configuration.

Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
"""
def __init__(
self,
max_num_batched_tokens: int,
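
To make the relationship between these four config classes concrete, here is a minimal construction sketch. The keyword names are assumed to mirror the Args documented above (the full __init__ signatures are collapsed in this diff), and the values are illustrative only:

# Illustrative only: keyword names are assumed to match the documented Args.
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)

model_config = ModelConfig(
    model="facebook/opt-125m",   # huggingface model name or path
    download_dir=None,           # fall back to the huggingface cache directory
    use_np_weights=False,
    use_dummy_weights=False,
    dtype="auto",                # FP16 for FP32/FP16 models, BF16 for BF16 models
    seed=0,
)
cache_config = CacheConfig(
    block_size=16,               # tokens per KV cache block
    gpu_memory_utilization=0.9,  # fraction of GPU memory given to CacheFlow
    swap_space=4,                # CPU swap space per GPU, in GiB
)
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=1,
    worker_use_ray=False,        # forced to True when either size is > 1
)
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560, # max tokens processed per iteration
    max_num_seqs=256,            # max sequences processed per iteration
)
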
12 changes: 12 additions & 0 deletions cacheflow/entrypoints/openai/openai_frontend.py
@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],

@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.

NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported in cacheflow server)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")

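
For reference, a client could exercise this endpoint with a plain HTTP POST. The sketch below is an assumption: the host/port (localhost:8000) and model name are not part of this PR, and the unsupported fields (echo, suffix, logit_bias) are simply omitted:

# Hypothetical client call against the OpenAI-compatible completions route.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0.0,
    },
)
print(response.json())  # OpenAI-style completion object
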
6 changes: 6 additions & 0 deletions cacheflow/entrypoints/simple_fastapi_frontend.py
@@ -18,6 +18,12 @@

@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
""" Stream the results of the generation request.

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
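
As the docstring says, the frontend pops `prompt` from the JSON body and forwards the remaining fields to `SamplingParams`. A possible streaming client, assuming the server listens on localhost:8000 and that `temperature` and `max_tokens` are valid `SamplingParams` fields:

# Hypothetical streaming client for the simple FastAPI frontend.
import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Hello, my name is",
        "temperature": 0.8,
        "max_tokens": 32,
    },
    stream=True,
)
# Print the streamed chunks as they arrive; the exact chunk format is
# whatever the StreamingResponse above emits.
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")
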
2 changes: 2 additions & 0 deletions cacheflow/server/arg_utils.py
@@ -9,6 +9,7 @@

@dataclass
class ServerArgs:
"""Arguments for CacheFlow servers."""
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
@@ -117,6 +118,7 @@ def create_server_configs(

@dataclass
class AsyncServerArgs(ServerArgs):
"""Arguments for asynchronous CacheFlow servers."""
server_use_ray: bool = False

@staticmethod
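
Since these are plain dataclasses, a caller only needs to supply `model` and every other field falls back to its default. A short sketch (the tuple layout returned by `create_server_configs()` is inferred from `from_server_args` further down in this diff):

# Hypothetical use of the ServerArgs dataclasses documented above.
from cacheflow.server.arg_utils import AsyncServerArgs, ServerArgs

server_args = ServerArgs(model="facebook/opt-125m")
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]  # same indexing as from_server_args below

# The async variant only adds server_use_ray on top of ServerArgs.
async_args = AsyncServerArgs(model="facebook/opt-125m", server_use_ray=False)
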
74 changes: 67 additions & 7 deletions cacheflow/server/async_llm_server.py
@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Dict, Optional
from typing import Dict, List, Optional

from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
@@ -15,7 +15,25 @@


class AsyncLLMServer:

"""An asynchronous wrapper for LLMServer.

This class wraps the LLMServer class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMServer is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMServer to the caller.

NOTE: For the comprehensive list of arguments, see `LLMServer`.

Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
server_use_ray: Whether to make LLMServer a Ray actor. If so, the
async frontend will be executed in a separate process from the
model workers.
*args, **kwargs: Arguments for LLMServer.
"""
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
*args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
@@ -35,6 +53,7 @@ def __init__(self, worker_use_ray: bool, server_use_ray: bool,
self.kicking_request_id: Optional[str] = None

async def server_step(self, kicking_request_id: Optional[str] = None):
"""Kick the server to process the waiting requests."""
self.is_server_running = True
self.kicking_request_id = kicking_request_id
if self.server_use_ray:
@@ -54,8 +73,31 @@ async def server_step(self, kicking_request_id: Optional[str] = None):
self.request_outputs[request_id] = request_output
self.request_events[request_id].set()

async def generate(self, prompt: str, sampling_params: SamplingParams,
request_id: str) -> RequestOutput:
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> RequestOutput:
"""Generate outputs for a request.

This method is a coroutine. It adds the request into the waiting queue of
the LLMServer and streams the outputs from the LLMServer to the caller.

Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompt to token IDs.

Yields:
The output `RequestOutput` objects from the LLMServer for the
request.
"""
# Preprocess the request.
arrival_time = time.time()

@@ -66,20 +108,29 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,

logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, "
f"sampling params: {sampling_params}.")
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")

# Add the request into the cacheflow server's waiting queue.
if self.server_use_ray:
await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.server.add_request(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)

# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return

# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step(request_id)
@@ -113,6 +164,14 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,
break

async def abort(self, request_id: str) -> None:
"""Abort a request.

Abort a submitted request. If the request is finished or not found,
this method will be a no-op.

Args:
request_id: The unique id of the request.
"""
if request_id not in self.request_events:
# The request has already finished or been aborted.
return
@@ -137,6 +196,7 @@ async def abort(self, request_id: str) -> None:

@classmethod
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
"""Creates an async LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
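
Tying the AsyncLLMServer docstrings together, an online caller might look roughly like the sketch below. The model name, request id, and default-constructed `SamplingParams` are assumptions for illustration:

# Hypothetical online-serving sketch built on the documented async API.
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer


async def main() -> None:
    server = AsyncLLMServer.from_server_args(
        AsyncServerArgs(model="facebook/opt-125m"))
    # generate() is an async generator: it keeps kicking the underlying
    # LLMServer and yields RequestOutput objects as the request progresses.
    async for request_output in server.generate(
            prompt="Hello, my name is",
            sampling_params=SamplingParams(),
            request_id="request-0"):
        print(request_output)
    # A pending request could also be cancelled with:
    #     await server.abort("request-0")


asyncio.run(main())
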
69 changes: 65 additions & 4 deletions cacheflow/server/llm_server.py
@@ -8,7 +8,7 @@
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@


class LLMServer:
"""An LLM server that receives requests and generates texts.

This is the main class for the CacheFlow LLM server. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.

The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMServer` class wraps this class for online serving.

NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.

Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list
of (rank, node_resource, device) tuples.
log_stats: Whether to log statistics.
"""

def __init__(
self,
@@ -27,7 +54,7 @@ def __init__(
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
stage_devices: List[List[DeviceID]],
log_stats: bool,
) -> None:
logger.info(
@@ -83,6 +110,7 @@ def _verify_args(self) -> None:
self.cache_config.verify_with_parallel_config(self.parallel_config)

def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
@@ -108,6 +136,7 @@ def _init_cache(self) -> None:

@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ def add_request(
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the server's request pool.

The request is added to the request pool and will be processed by the
scheduler as `server.step()` is called. The exact scheduling policy is
determined by the scheduler.

Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompt to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current time.
"""
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
@@ -148,15 +193,30 @@ def add_request(
self.scheduler.add_seq_group(seq_group)

def abort_request(self, request_id: str) -> None:
"""Aborts a request with the given ID.

Args:
request_id: The ID of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()

def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()

def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.

This function performs one decoding iteration for the server. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copied. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
@@ -188,7 +248,7 @@ def step(self) -> List[RequestOutput]:
return request_outputs

def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Decode the sequence outputs.
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
seq.output_text = new_output_text

def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ def _run_workers(
*args,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
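
Finally, the offline path that the `LLMServer` docstrings describe (and that the `LLM` wrapper mentioned above would hide) can be sketched as an add-then-step loop; the model name and default `SamplingParams` are illustrative assumptions:

# Hypothetical offline batched-inference loop using the documented methods.
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))

prompts = ["Hello, my name is", "The future of AI is"]
for i, prompt in enumerate(prompts):
    # Requests sit in the scheduler's pool until step() schedules them.
    server.add_request(request_id=str(i), prompt=prompt,
                       sampling_params=SamplingParams())

# Each step() runs one decoding iteration and returns the requests that
# produced new output in that iteration.
while server.has_unfinished_requests():
    for request_output in server.step():
        print(request_output)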