Use runtime profiling to replace manual memory analyzers #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (16 commits) on May 19, 2023
52 changes: 30 additions & 22 deletions cacheflow/core/server.py
@@ -6,15 +6,14 @@
import ray
except ImportError:
ray = None
import numpy as np
import torch

from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.model_executor import get_memory_analyzer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import get_gpu_memory, get_cpu_memory
from cacheflow.worker.controller import Controller, DeviceID

logger = init_logger(__name__)
@@ -34,14 +33,13 @@ def __init__(
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
use_ray: bool,
log_stats: bool,
):
@@ -63,19 +61,6 @@ def __init__(
assert self.world_size == 1, (
"Only support single GPU without Ray.")

self.memory_analyzer = get_memory_analyzer(
model_name=model,
block_size=block_size,
dtype=dtype,
gpu_memory=gpu_memory,
cpu_memory=cpu_memory,
tensor_parallel_size=tensor_parallel_size,
)
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
swap_space_gib=swap_space)

# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
for i in range(pipeline_parallel_size):
@@ -87,19 +72,35 @@ def __init__(
tensor_parallel_size=tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=model,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
use_ray=use_ray,
)
self.controllers.append(controller)

# Initialize cache engine.
all_worker_num_available_blocks = []
for controller in self.controllers:
all_worker_num_available_blocks.extend(
controller.get_num_available_blocks(
block_size, swap_space, gpu_memory_utilization)
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
f'# CPU blocks: {self.num_cpu_blocks}')
for controller in self.controllers:
controller.init_cache_engine(block_size, self.num_gpu_blocks,
self.num_cpu_blocks)

# Create a scheduler.
self.scheduler = Scheduler(
controllers=self.controllers,
@@ -214,7 +215,11 @@ def initialize_cluster(
all_stage_devices)


_GiB = 1 << 30


def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--cache-dir', type=str, default=None,
@@ -238,15 +243,19 @@ def add_server_arguments(parser: argparse.ArgumentParser):
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the fraction of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--log-stats', action='store_true', help='log system statistics')
return parser


def process_server_arguments(args: argparse.Namespace):
"""Post process the parsed arguments."""
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
args.swap_space = args.swap_space * _GiB
args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
return args


@@ -274,14 +283,13 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
log_stats=args.log_stats,
)
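The hunk above is the heart of the change: instead of sizing the KV cache with a static, per-model memory analyzer, the server now asks every worker how many cache blocks it can actually afford at runtime and takes the minimum across workers, so the shared block tables are valid everywhere. A minimal sketch of the kind of computation a worker-side `get_num_available_blocks` could perform is shown below; the profiling step and the `cache_block_size_bytes` parameter are assumptions for illustration, not the PR's exact worker code.

```python
# Illustrative sketch only: turning runtime profiling into a
# (num_gpu_blocks, num_cpu_blocks) pair on one worker.
from typing import Tuple

import torch


def get_num_available_blocks(
    swap_space_bytes: int,          # CPU swap budget, already converted to bytes
    gpu_memory_utilization: float,  # e.g. 0.95 from --gpu-memory-utilization
    cache_block_size_bytes: int,    # bytes needed to store one KV-cache block
) -> Tuple[int, int]:
    # A forward pass on dummy inputs is assumed to have run already, so the
    # peak allocation covers the model weights plus worst-case activations.
    torch.cuda.synchronize()
    peak_memory = torch.cuda.max_memory_allocated()
    total_gpu_memory = torch.cuda.get_device_properties(0).total_memory

    # Whatever is left of the allowed budget goes to the GPU KV cache.
    cache_budget = int(total_gpu_memory * gpu_memory_utilization) - peak_memory
    num_gpu_blocks = max(cache_budget // cache_block_size_bytes, 0)
    num_cpu_blocks = swap_space_bytes // cache_block_size_bytes
    return num_gpu_blocks, num_cpu_blocks
```

The server then reduces the per-worker results with `np.min` (see the hunk above), and the new `--gpu-memory-utilization` flag (default 0.95) caps how much of the device the profiled budget may hand to the cache.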
10 changes: 7 additions & 3 deletions cacheflow/frontend/fastapi_frontend.py
@@ -15,7 +15,7 @@
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
from cacheflow.utils import Counter
from cacheflow.worker.controller import DeviceID

TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
@@ -34,13 +34,15 @@ def __init__(
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
server_use_ray: bool,
log_stats: bool,
):
self.block_size = block_size

@@ -62,15 +64,15 @@ def __init__(
dtype=dtype,
seed=seed,
swap_space=swap_space,
gpu_memory_utilization=gpu_memory_utilization,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=server_use_ray,
log_stats=log_stats,
)

self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -182,13 +184,15 @@ async def generate_stream(request: Request):
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
server_use_ray=args.use_ray,
log_stats=args.log_stats,
)

uvicorn.run(app, host=args.host, port=args.port, log_level="info")
7 changes: 4 additions & 3 deletions cacheflow/model_executor/__init__.py
@@ -1,11 +1,12 @@
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer
from cacheflow.model_executor.utils import set_random_seed
from cacheflow.model_executor.model_loader import get_model
from cacheflow.model_executor.utils import (set_random_seed,
get_cache_block_size)


__all__ = [
"InputMetadata",
"get_cache_block_size",
"get_model",
"get_memory_analyzer",
"set_random_seed",
]
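With the memory analyzer gone, the model executor only has to report how large a single KV-cache block is; the profiled memory budget is then divided by that size. Below is a hedged sketch of what such a `get_cache_block_size` helper could compute; the parameter list is an assumption, not necessarily the signature in `cacheflow/model_executor/utils.py`.

```python
# Sketch of a per-block KV-cache size computation (in bytes).  Parameter names
# are illustrative; the real helper may derive them from the model config.
import torch


def get_cache_block_size(
    block_size: int,   # tokens stored per block
    num_heads: int,
    head_size: int,
    num_layers: int,
    dtype: torch.dtype,
) -> int:
    # Every token keeps one key vector and one value vector per layer.
    key_elements = block_size * num_heads * head_size
    value_elements = key_elements
    elements_per_layer = key_elements + value_elements
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    return num_layers * elements_per_layer * dtype_size
```

For an OPT-125m-sized model (12 layers, 12 heads of size 64) in fp16 with `block_size=16`, this works out to roughly 0.56 MiB per block.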
50 changes: 29 additions & 21 deletions cacheflow/model_executor/layers/attention.py
@@ -11,6 +11,8 @@
from cacheflow.model_executor.input_metadata import InputMetadata


_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]

class GPTCacheFlowAttention(nn.Module):
"""GPT-style multi-head attention.

@@ -39,11 +41,19 @@ class GPTCacheFlowAttention(nn.Module):
5. Output a flattened 1D tensor.
"""

def __init__(self, scale: float) -> None:
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()

if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f'head_size ({self.head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{_SUPPORTED_HEAD_SIZES}.')

def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
@@ -74,14 +84,6 @@ def single_query_cached_kv_attention(
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
head_size = value_cache.shape[2]
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
if head_size not in supported_head_sizes:
raise ValueError(f'head_size ({head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{supported_head_sizes}.')

block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@@ -100,20 +102,18 @@ def forward(
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].

# Reshape the query, key, and value tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)

# Pre-allocate the output tensor.
output = torch.empty_like(query)
@@ -134,8 +134,11 @@ def forward(
cache_event.wait()

# Reshape the keys and values and store them in the cache.
# When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if num_valid_tokens > 0:
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@@ -146,6 +149,10 @@
)

if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
@@ -156,20 +163,22 @@

# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, num_heads * head_size)
return output.view(-1, self.num_heads * self.head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)
super().__init__(num_heads, head_size, scale)

# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -199,12 +208,11 @@ def forward(
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
head_size = value_cache.shape[2]
pos_encoding_ops.rotary_embedding_neox(
positions,
query,
key,
head_size,
self.head_size,
self.cos_sin_cache,
)
return super().forward(
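Two changes in this file support the profiling flow. First, the head-size check moves from the decode path into `__init__`, so an unsupported configuration fails at model construction instead of on the first decode step. Second, `key_cache` and `value_cache` become `Optional`, which lets the layer run a profiling forward pass before any cache has been allocated: the new keys and values are simply not written back, and decoding still asserts that the caches exist. A small sketch of the construction-time check, assuming the package (including its compiled kernels) is importable; the numbers are illustrative.

```python
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention

# A supported head size constructs fine ...
attn = GPTCacheFlowAttention(num_heads=12, head_size=64, scale=64 ** -0.5)

# ... while an unsupported one now raises immediately:
# ValueError: head_size (48) is not supported by the
# single_query_cached_kv_attention kernel. ...
GPTCacheFlowAttention(num_heads=12, head_size=48, scale=48 ** -0.5)
```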
2 changes: 1 addition & 1 deletion cacheflow/model_executor/layers/sampler.py
@@ -74,7 +74,7 @@ def forward(
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
probs = _apply_top_p_top_k(probs, top_ps, top_ks)

# Sample the next tokens.
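The sampler's "top-k disabled" sentinel changes from `-1` to `self.vocab_size`: keeping the `vocab_size` largest entries of a `vocab_size`-long distribution keeps all of them, so truncation can still be skipped when no request asks for it. Below is a standalone sketch of that equivalence; it is not the sampler's actual `_apply_top_p_top_k`.

```python
import torch


def apply_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Zero out everything below the k-th largest probability, then renormalize."""
    vocab_size = probs.shape[-1]
    if k >= vocab_size:
        # Equivalent to the old sentinel k == -1: nothing to truncate.
        return probs
    kth_largest = torch.topk(probs, k, dim=-1).values[..., -1, None]
    probs = torch.where(probs < kth_largest, torch.zeros_like(probs), probs)
    return probs / probs.sum(dim=-1, keepdim=True)
```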