Print warnings/errors for large swap space #123


Merged · 7 commits · May 24, 2023
31 changes: 31 additions & 0 deletions cacheflow/config.py
@@ -3,6 +3,11 @@
import torch
from transformers import AutoConfig, PretrainedConfig

from cacheflow.logger import init_logger
from cacheflow.utils import get_cpu_memory

logger = init_logger(__name__)

_GiB = 1 << 30


@@ -73,11 +78,37 @@ def __init__(
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GiB
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must not exceed 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (
            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
            "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)


class ParallelConfig:

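To make the new thresholds concrete: the check charges swap_space GiB of CPU memory per GPU in the tensor-parallel group and compares the total against 40% (warn) and 70% (error) of the node's CPU memory. Below is a small self-contained sketch that mirrors that logic; classify_swap_usage is a hypothetical helper written for illustration, not part of the diff.

_GiB = 1 << 30

def classify_swap_usage(swap_space_gib: float,
                        tensor_parallel_size: int,
                        total_cpu_memory: int) -> str:
    # Total CPU memory pinned for swap across all GPUs assumed to share a node.
    cpu_memory_usage = swap_space_gib * _GiB * tensor_parallel_size
    if cpu_memory_usage > 0.7 * total_cpu_memory:
        return "error"  # verify_with_parallel_config raises ValueError
    if cpu_memory_usage > 0.4 * total_cpu_memory:
        return "warn"   # verify_with_parallel_config logs a warning
    return "ok"

# Example: 16 GiB of swap per GPU with tensor_parallel_size=4 on a 128 GiB
# node pins 64 GiB (50% of CPU memory), which lands in the warning band.
assert classify_swap_usage(16, 4, 128 * _GiB) == "warn"
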
1 change: 1 addition & 0 deletions cacheflow/server/llm_server.py
@@ -84,6 +84,7 @@ def __init__(

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
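The pattern here: each config object validates itself against the parallel config, and the server fans out to every config during construction, so a bad swap-space setting now fails before any profiling or allocation happens. A minimal stand-in sketch of that wiring (class names are illustrative, not the cacheflow API):

class _Config:
    def verify_with_parallel_config(self, parallel_config) -> None:
        raise NotImplementedError

class _Server:
    def __init__(self, model_config, cache_config, parallel_config):
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        # Fail fast: reject inconsistent configs before touching GPU/CPU memory.
        self._verify_args()

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
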
2 changes: 2 additions & 0 deletions cacheflow/utils.py
@@ -24,8 +24,10 @@ def reset(self) -> None:


def get_gpu_memory(gpu: int = 0) -> int:
    """Returns the total memory of the GPU in bytes."""
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total
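
A quick usage sketch of these helpers, assuming psutil is installed and torch can see at least one CUDA device:

from cacheflow.utils import get_cpu_memory, get_gpu_memory

_GiB = 1 << 30
print(f"Total CPU memory: {get_cpu_memory() / _GiB:.2f} GiB")
print(f"GPU 0 total memory: {get_gpu_memory(0) / _GiB:.2f} GiB")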