Refactor system architecture #109

Merged: 80 commits, May 20, 2023
Commits
33ef394
Use runtime profiling to replace manual memory analyzers
zhuohan123 May 7, 2023
f07bc4a
Merge branch 'main' into dynamic-memory-profiler
zhuohan123 May 8, 2023
72f5b9a
Fix merge error
zhuohan123 May 9, 2023
38771dc
Add argument for cache memory utilization
zhuohan123 May 9, 2023
ff18742
Add comments
zhuohan123 May 9, 2023
78c6f30
Resolve part of review comments
zhuohan123 May 11, 2023
3881670
Fix comments on GPU memory percentage
zhuohan123 May 13, 2023
c96ff33
Merge branch 'main' into dynamic-memory-profiler
zhuohan123 May 13, 2023
dcca6f4
Fix merging errors
zhuohan123 May 13, 2023
511cc61
fix logging
zhuohan123 May 13, 2023
51de2cb
fix a bug in sampler
zhuohan123 May 13, 2023
5d46bec
fix fastapi frontend
zhuohan123 May 13, 2023
56f14bf
Minor
WoosukKwon May 14, 2023
768cdf4
Minor
WoosukKwon May 14, 2023
7386dc8
[WIP] Remove ambiguous parts of controllers
WoosukKwon May 14, 2023
fb4aa64
Merge branch 'main' into refactor
WoosukKwon May 18, 2023
52b342f
Minor typo in sampling params
WoosukKwon May 18, 2023
03ec645
Fix random seed
zhuohan123 May 18, 2023
83c46e8
Fix the placement for get_cache_block_size
zhuohan123 May 18, 2023
99cb539
Merge branch 'main' into dynamic-memory-profiler
zhuohan123 May 18, 2023
b9a8da9
Profile memory with max_num_sequences sequences
zhuohan123 May 18, 2023
9d5e531
tmp
WoosukKwon May 19, 2023
7dcaa40
Merge branch 'dynamic-memory-profiler' into refactor
WoosukKwon May 19, 2023
026fd69
Remove controller
WoosukKwon May 19, 2023
6224546
Add config
WoosukKwon May 19, 2023
6daf8cd
Fix model loader
WoosukKwon May 19, 2023
4858e60
Fix cache engine
WoosukKwon May 19, 2023
c76c5e2
Minor fix for attention
WoosukKwon May 19, 2023
fc37406
Fix worker
WoosukKwon May 19, 2023
80b0d60
Move to tokenizer_utils
WoosukKwon May 19, 2023
8b1dfd1
Delete simple_frontend
WoosukKwon May 19, 2023
d2545f4
Add ray_utils
WoosukKwon May 19, 2023
7acde9d
Minor fix in config
WoosukKwon May 19, 2023
a9a7acf
Add LLMServe
WoosukKwon May 19, 2023
71686c9
Remove server
WoosukKwon May 19, 2023
fabcfbd
Merge branch 'main' into refactor
WoosukKwon May 20, 2023
a8accf2
Minor
WoosukKwon May 20, 2023
6ea5a60
Minor fix
WoosukKwon May 20, 2023
a9e0871
Minor fix for ray_utils
WoosukKwon May 20, 2023
2c6f252
Move get_cache_block_size to CacheEngine
WoosukKwon May 20, 2023
12e0057
Fix scheduler
WoosukKwon May 20, 2023
dc5548b
Fix worker
WoosukKwon May 20, 2023
6697d9b
Fix LLM server
WoosukKwon May 20, 2023
9e5df49
Add arg utils
WoosukKwon May 20, 2023
685af3d
Fix simple server
WoosukKwon May 20, 2023
7273734
Minor
WoosukKwon May 20, 2023
4e68a31
Move fastapi server
WoosukKwon May 20, 2023
f635c85
Minor
WoosukKwon May 20, 2023
d529769
Add sampling_params in SequenceGroup
WoosukKwon May 20, 2023
40e8edb
Add stream to sampling params
WoosukKwon May 20, 2023
3d5fa0f
Output seq groups
WoosukKwon May 20, 2023
84fdabb
Add outputs
WoosukKwon May 20, 2023
544a54e
Return outputs from step
WoosukKwon May 20, 2023
13ddb4e
Fix logprobs output
WoosukKwon May 20, 2023
64b1240
Simplify simple server
WoosukKwon May 20, 2023
d6237bf
Set seed before model initialization
WoosukKwon May 20, 2023
6f51383
Add functions and classes to cacheflow.__init__
WoosukKwon May 20, 2023
674ee75
group id -> request id
WoosukKwon May 20, 2023
0cb5fea
Minor
WoosukKwon May 20, 2023
52d59ca
Minor fix for backward compatibility
WoosukKwon May 20, 2023
776930c
Minor
WoosukKwon May 20, 2023
9532ef1
[WIP] Fix FastAPI server
WoosukKwon May 20, 2023
6d39f82
Move simple_server to examples
WoosukKwon May 20, 2023
4b87959
Remove stream in sampling params
WoosukKwon May 20, 2023
e9219af
Remove stream outputs
WoosukKwon May 20, 2023
e6fbd85
Minor fix
WoosukKwon May 20, 2023
e1f1eb4
Return every seq group at every step
WoosukKwon May 20, 2023
fdbe518
Fix simple_server
WoosukKwon May 20, 2023
2c4b0f4
Minor fix
WoosukKwon May 20, 2023
10c4776
Minor fix
WoosukKwon May 20, 2023
6ea6677
Rename output varialbes
WoosukKwon May 20, 2023
ae863bc
Minor bugfix
WoosukKwon May 20, 2023
8886fa5
Fix FastAPI server
WoosukKwon May 20, 2023
11d5f78
Minor fix
WoosukKwon May 20, 2023
7dca549
Move fastapi server
WoosukKwon May 20, 2023
8d0e44e
Fix README
WoosukKwon May 20, 2023
d4807a7
Fix README
WoosukKwon May 20, 2023
c94dfe1
[Minor] Rename entrypoint -> entrypoints
WoosukKwon May 20, 2023
6929b6e
[Minor] Rename entrypoint -> entrypoints
WoosukKwon May 20, 2023
470b812
Merge branch 'main' into refactor
WoosukKwon May 20, 2023
Files changed
13 changes: 8 additions & 5 deletions README.md
@@ -10,26 +10,30 @@ pip install -e . # This may take several minutes.
## Test simple server

```bash
+# Single-GPU inference.
+python examples/simple_server.py # --model <your_model>
+
+# Multi-GPU inference (e.g., 2 GPUs).
ray start --head
-python simple_server.py
+python examples/simple_server.py -tp 2 # --model <your_model>
```

The detailed arguments for `simple_server.py` can be found by:
```bash
-python simple_server.py --help
+python examples/simple_server.py --help
```

## FastAPI server

To start the server:
```bash
ray start --head
-python -m cacheflow.http_frontend.fastapi_frontend
+python -m cacheflow.entrypoints.fastapi_server # --model <your_model>
```

To test the server:
```bash
-python -m cacheflow.http_frontend.test_cli_client
+python test_cli_client.py
```

## Gradio web server
@@ -55,7 +59,6 @@ Since LLaMA weight is not fully public, we cannot directly download the LLaMA weights
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
```
-Please make sure that `llama` is included in the output directory name.
2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
```bash
python simple_server.py --model /output/path/llama-7b
19 changes: 19 additions & 0 deletions cacheflow/__init__.py
@@ -0,0 +1,19 @@
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
add_server_arguments,
create_server_configs_from_args,
initialize_server_from_args,
)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster

__all__ = [
"RequestOutput",
"SamplingParams",
"LLMServer",
"add_server_arguments",
"create_server_configs_from_args",
"initialize_server_from_args",
"initialize_cluster",
]
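
With these exports, the refactored engine can be driven directly from Python rather than through a frontend module. Below is a minimal sketch of how the pieces could fit together, loosely modeled on the simple_server example this PR moves into examples/; the exact signatures of add_server_arguments, initialize_server_from_args, and the LLMServer add_request/step methods are assumptions, since they are not shown in this excerpt.

```python
import argparse

from cacheflow import (SamplingParams, add_server_arguments,
                       initialize_server_from_args)

# Assumed wiring: add_server_arguments() registers the engine's CLI flags on
# an argparse parser and returns it; initialize_server_from_args() turns the
# parsed args into server configs and constructs an LLMServer (and a Ray
# cluster if more than one worker is needed).
parser = argparse.ArgumentParser(description="Minimal cacheflow driver")
parser = add_server_arguments(parser)
args = parser.parse_args()
server = initialize_server_from_args(args)

# Assumed LLMServer interface: add_request() enqueues a prompt under a string
# request id, and each step() call runs one scheduler iteration and returns
# RequestOutput objects for the sequences processed in that iteration.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
server.add_request("request-0", "Hello, my name is", sampling_params)
for _ in range(64):  # illustrative bound; a real driver would poll until done
    for request_output in server.step():
        print(request_output)
```
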
165 changes: 165 additions & 0 deletions cacheflow/config.py
@@ -0,0 +1,165 @@
from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig


class ModelConfig:

def __init__(
self,
model: str,
download_dir: Optional[str],
use_np_weights: bool,
use_dummy_weights: bool,
dtype: str,
seed: int,
) -> None:
self.model = model
self.download_dir = download_dir
self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights
self.seed = seed

self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")

total_num_hidden_layers = self.hf_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")

def get_hidden_size(self) -> int:
return self.hf_config.hidden_size

def get_head_size(self) -> int:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads

def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:

def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space = swap_space

# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None


class ParallelConfig:

def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
use_ray: bool,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.use_ray = use_ray

self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.use_ray = True
self._verify_args()

def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")


class SchedulerConfig:

def __init__(
self,
max_num_batched_tokens: int,
max_num_seqs: int,
) -> None:
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: str,
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32

dtype = dtype.lower()
if dtype == "default":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
# Casting between float16 and bfloat16 is not allowed.
raise ValueError(
f"Cannot use {torch_dtype} for {config_dtype} model.")

# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype
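
To make the relationships among these classes concrete, here is a small sketch of how a ModelConfig and ParallelConfig might be built and checked against each other. It only uses the constructor signatures shown above; the model id (facebook/opt-125m) is a hypothetical choice for illustration and requires its Hugging Face config to be available.

```python
from cacheflow.config import ModelConfig, ParallelConfig

# ModelConfig fetches the Hugging Face config and resolves dtype="default":
# float32 checkpoints are served in float16, while other checkpoints keep
# their own dtype (see _get_and_verify_dtype above).
model_config = ModelConfig(
    model="facebook/opt-125m",  # hypothetical example; any HF model id
    download_dir=None,
    use_np_weights=False,
    use_dummy_weights=False,
    dtype="default",
    seed=0,
)

# world_size = pipeline_parallel_size * tensor_parallel_size; use_ray is
# forced to True whenever more than one worker is required.
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    use_ray=False,
)

# Raises ValueError if the attention-head count is not divisible by the
# tensor parallel size, or the layer count by the pipeline parallel size.
model_config.verify_with_parallel_config(parallel_config)

print(model_config.dtype)          # resolved torch dtype
print(parallel_config.world_size)  # 2
print(parallel_config.use_ray)     # True (forced because world_size > 1)
```

Downstream components can then read per-worker sizes from the config objects (get_num_heads, get_head_size, get_num_layers) instead of re-deriving them from the raw Hugging Face config.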