Automatically configure KV cache size #6

Merged: 17 commits, Mar 12, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@
## Installation

```bash
pip install cmake torch transformers
pip install psutil numpy torch transformers
pip install flash-attn # This may take up to 10 mins.
pip install -e .
```
6 changes: 3 additions & 3 deletions cacheflow/master/scheduler.py
@@ -9,8 +9,6 @@
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048


class Scheduler:

@@ -21,12 +19,14 @@ def __init__(
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
) -> None:
self.frontend = frontend
self.controllers = controllers
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens

# Create the block space manager.
self.block_manager = BlockSpaceManager(
@@ -164,7 +164,7 @@ def step(self) -> None:
num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
if (num_batched_tokens + num_prompt_tokens
<= _MAX_NUM_BATCHED_TOKENS):
<= self.max_num_batched_tokens):
self._allocate(seq_group)
num_batched_tokens += num_prompt_tokens
continue
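
For context, a rough sketch of the admission loop in `Scheduler.step()` after this change: prompts join the batch only while the running token count stays within the instance-level `max_num_batched_tokens`. The names below (`admit_prompts`, `pending_seq_groups`) are illustrative, not the PR's exact code.

```python
from typing import List

def admit_prompts(pending_seq_groups: List, block_manager, scheduler) -> int:
    """Illustrative only: admit waiting prompts until blocks or tokens run out."""
    num_batched_tokens = 0
    for seq_group in pending_seq_groups:
        num_prompt_tokens = seq_group.seqs[0].get_len()
        if not block_manager.can_allocate(seq_group):
            break  # no free KV cache blocks for this prompt
        if num_batched_tokens + num_prompt_tokens > scheduler.max_num_batched_tokens:
            break  # --max-batch-size budget reached
        scheduler._allocate(seq_group)
        num_batched_tokens += num_prompt_tokens
    return num_batched_tokens
```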
6 changes: 4 additions & 2 deletions cacheflow/models/__init__.py
@@ -1,10 +1,12 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_memory_analyzer
from cacheflow.models.model_utils import get_model
from cacheflow.models.model_utils import set_seed
from cacheflow.models.utils import set_seed


__all__ = [
'InputMetadata',
'get_memory_analyzer',
'get_model',
'set_seed'
'set_seed',
]
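
With `get_memory_analyzer` re-exported from the package root, callers such as `server.py` later in this diff can import it in one line; a minimal sketch:

```python
# get_memory_analyzer is now part of the cacheflow.models public API.
from cacheflow.models import get_memory_analyzer, get_model, set_seed
```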
125 changes: 125 additions & 0 deletions cacheflow/models/memory_analyzer.py
@@ -0,0 +1,125 @@
import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_cpu_memory
from cacheflow.models.utils import get_dtype_size
from cacheflow.models.utils import get_gpu_memory

_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float,
    ) -> int:
        raise NotImplementedError()

    def get_max_num_cpu_blocks(
        self,
        memory_utilization: float,
    ) -> int:
        raise NotImplementedError()


class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype

        # TODO(woosuk): Support tensor parallelism.
        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def _get_param_size(self) -> int:
        # TODO(woosuk): Support tensor parallelism.
        word_embedding = self.vocab_size * self.embedding_size
        if self.embedding_size != self.vocab_size:
            # Project in/out.
            word_embedding += 2 * self.embedding_size * self.vocab_size
        position_embedding = self.max_position * self.hidden_size

        ln1 = 2 * self.hidden_size
        q = self.hidden_size * self.hidden_size + self.hidden_size
        k = self.hidden_size * self.hidden_size + self.hidden_size
        v = self.hidden_size * self.hidden_size + self.hidden_size
        out = self.hidden_size * self.hidden_size + self.hidden_size
        mha = ln1 + q + k + v + out

        ln2 = 2 * self.hidden_size
        ffn1 = self.hidden_size * self.ffn_size + self.ffn_size
        ffn2 = self.ffn_size * self.hidden_size + self.hidden_size
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding +
                 self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        # TODO(woosuk): Support tensor parallelism.
        # NOTE: We approximately calculate the maximum activation size by
        # 1) estimating the maximum activation tensor size during inference, and
        # 2) multiplying it by 4.
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        qkv = 3 * (max_num_batched_tokens * self.hidden_size)
        ffn = max_num_batched_tokens * self.ffn_size
        max_act = 4 * max(qkv, ffn)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def _get_workspace_size(self) -> int:
        return 1 * _GiB

    def _get_cache_block_size(self) -> int:
        key_cache_block = self.block_size * self.num_heads * self.head_size
        value_cache_block = self.block_size * self.num_heads * self.head_size
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        gpu_memory = get_gpu_memory()
        usable_memory = int(memory_utilization * gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self._get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        max_num_blocks = max_cache_size // self._get_cache_block_size()
        return max_num_blocks

    def get_max_num_cpu_blocks(
        self,
        memory_utilization: float = 0.25,
    ) -> int:
        cpu_memory = get_cpu_memory()
        usable_memory = int(memory_utilization * cpu_memory)
        max_num_blocks = usable_memory // self._get_cache_block_size()
        return max_num_blocks
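
To make the block accounting concrete, here is a back-of-the-envelope check of `_get_cache_block_size()`. The shape values below (12 layers, 12 heads, head size 64, i.e. OPT-125M-like) are assumptions for illustration, not numbers taken from this PR.

```python
# Hand-computed KV cache block size under assumed OPT-125M-like shapes,
# with block_size=8 and fp16 (2 bytes per element).
block_size, num_layers, num_heads, head_size, dtype_size = 8, 12, 12, 64, 2
key_cache_block = block_size * num_heads * head_size      # 6144 elements per layer
value_cache_block = block_size * num_heads * head_size    # 6144 elements per layer
cache_block_bytes = dtype_size * num_layers * (key_cache_block + value_cache_block)
print(cache_block_bytes)  # 294912 bytes, i.e. ~288 KiB of KV cache per block
```

`get_max_num_gpu_blocks()` then divides whatever GPU memory remains after parameters, activations, and the 1 GiB workspace by this per-block size; `get_max_num_cpu_blocks()` does the same against 25% of host RAM by default.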
44 changes: 23 additions & 21 deletions cacheflow/models/model_utils.py
@@ -1,42 +1,44 @@
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype

MODEL_CLASSES = {

_MODELS = {
'opt': OPTForCausalLM,
}

STR_DTYPE_TO_TORCH_DTYPE = {
'half': torch.half,
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
_MEMORY_ANALYZERS = {
'opt': OPTMemoryAnalyzer,
}


def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
) -> nn.Module:
if isinstance(dtype, str):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
else:
torch_dtype = dtype
for model_class, hf_model in MODEL_CLASSES.items():
torch_dtype = get_torch_dtype(dtype)
for model_class, hf_model in _MODELS.items():
if model_class in model_name:
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
model = hf_model.from_pretrained(
model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')
raise ValueError(f'Unsupported model name: {model_name}')


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: Union[torch.dtype, str],
) -> CacheFlowMemoryAnalyzer:
torch_dtype = get_torch_dtype(dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(
model_name, block_size, torch_dtype)
raise ValueError(f'Unsupported model name: {model_name}')
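
Both `_MODELS` and `_MEMORY_ANALYZERS` are keyed by substrings of the model name, so any name containing `'opt'` resolves to the OPT classes. A brief usage sketch under that assumption (the first call downloads weights from the Hugging Face Hub):

```python
from cacheflow.models.model_utils import get_memory_analyzer, get_model

model = get_model('facebook/opt-125m', dtype='half')  # OPTForCausalLM in eval mode
analyzer = get_memory_analyzer('facebook/opt-125m', block_size=8, dtype='half')
```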
43 changes: 43 additions & 0 deletions cacheflow/models/utils.py
@@ -0,0 +1,43 @@
from typing import Union

import random

import numpy as np
import psutil
import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
    'half': torch.half,
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
}


def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    if isinstance(dtype, str):
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
    else:
        torch_dtype = dtype
    return torch_dtype


def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    torch_dtype = get_torch_dtype(dtype)
    return torch.tensor([], dtype=torch_dtype).element_size()


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    return psutil.virtual_memory().total
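
A quick sanity check of the new dtype helpers (the sizes follow directly from torch's element sizes, so fp16 is 2 bytes and fp32 is 4):

```python
import torch
from cacheflow.models.utils import get_dtype_size, get_torch_dtype

assert get_torch_dtype('half') == torch.float16
assert get_dtype_size('half') == 2    # fp16: 2 bytes per element
assert get_dtype_size('float') == 4   # fp32: 4 bytes per element
```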
24 changes: 17 additions & 7 deletions server.py
@@ -3,24 +3,33 @@

from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller

parser = argparse.ArgumentParser(description='CacheFlow server')
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
args = parser.parse_args()


def main():
memory_analyzer = get_memory_analyzer(
model_name=args.model,
block_size=args.block_size,
dtype=args.dtype,
)
num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=args.max_batch_size)
num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

# Create a controller for each node.
controllers: List[Controller] = []
for i in range(args.num_nodes):
@@ -29,8 +38,8 @@ def main():
num_workers=args.num_workers,
model_name=args.model,
block_size=args.block_size,
num_gpu_blocks=args.num_gpu_blocks,
num_cpu_blocks=args.num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=args.dtype,
seed=args.seed,
)
@@ -47,8 +56,9 @@ def main():
frontend=frontend,
controllers=controllers,
block_size=args.block_size,
num_gpu_blocks=args.num_gpu_blocks,
num_cpu_blocks=args.num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
max_num_batched_tokens=args.max_batch_size,
)
# Connect the controllers.
for i in range(len(controllers) - 1):