
Commit 53ae506

Allow users to specify kv cache memory size (vllm-project#21489)

BoyuanFeng authored and hmellor committed

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 0d40089 commit 53ae506

File tree: 10 files changed, +236 −47 lines

vllm/config/cache.py

Lines changed: 9 additions & 0 deletions
@@ -113,6 +113,15 @@ class CacheConfig:
     necessary for implementing this optimization in some models (e.g. Gemma3n)
     """

+    kv_cache_memory_bytes: Optional[int] = None
+    """Size of the KV cache per GPU in bytes. By default, this is set to None
+    and vLLM can automatically infer the KV cache size based on
+    gpu_memory_utilization. However, users may want to manually specify the
+    KV cache memory size. kv_cache_memory_bytes allows more fine-grained
+    control of how much memory gets used compared with
+    gpu_memory_utilization. Note that kv_cache_memory_bytes (when not None)
+    ignores gpu_memory_utilization."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
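
For context, a minimal sketch of how the new field composes with the existing CacheConfig keyword arguments that appear later in this diff. The block size, utilization, and 8 GiB figures are illustrative only, and CacheConfig is assumed to be importable from vllm.config (it is defined in vllm/config/cache.py):

from vllm.config import CacheConfig

# Illustrative values only; kv_cache_memory_bytes takes precedence over
# gpu_memory_utilization when it is not None.
cache_config = CacheConfig(
    block_size=16,
    gpu_memory_utilization=0.9,           # ignored once kv_cache_memory_bytes is set
    kv_cache_memory_bytes=8 * (1 << 30),  # reserve exactly 8 GiB per GPU for KV cache
    swap_space=4,
)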

vllm/engine/arg_utils.py

Lines changed: 11 additions & 1 deletion
@@ -227,8 +227,14 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
-            if name in {"max_model_len", "max_num_batched_tokens"}:
+            human_readable_ints = {
+                "max_model_len",
+                "max_num_batched_tokens",
+                "kv_cache_memory_bytes",
+            }
+            if name in human_readable_ints:
                kwargs[name]["type"] = human_readable_int
+                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif (contains_type(type_hints, dict)
@@ -335,6 +341,7 @@ class EngineArgs:
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
+    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@@ -734,6 +741,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
                                 **cache_kwargs["gpu_memory_utilization"])
+        cache_group.add_argument("--kv-cache-memory-bytes",
+                                 **cache_kwargs["kv_cache_memory_bytes"])
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
                                 **cache_kwargs["cache_dtype"])
@@ -1174,6 +1183,7 @@ def create_engine_config(
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
+            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
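
Because kv_cache_memory_bytes joins the same human_readable_int set as max_model_len, the new --kv-cache-memory-bytes flag accepts the same kind of suffixed values on the CLI, and its help text now appends human_readable_int's docstring. The parser below is only an illustrative stand-in for that behaviour, not vLLM's actual human_readable_int implementation:

import re

# Hypothetical simplified parser; unit handling here is an assumption for
# illustration and may differ from vLLM's real suffix rules.
_UNITS = {"k": 10**3, "m": 10**6, "g": 10**9, "t": 10**12,
          "ki": 2**10, "mi": 2**20, "gi": 2**30, "ti": 2**40}

def parse_human_readable_int(value: str) -> int:
    """Parse strings such as '4096', '100M', or '8GiB' into an integer."""
    match = re.fullmatch(r"(\d+(?:\.\d+)?)\s*([kmgt]i?)?b?", value.strip().lower())
    if match is None:
        raise ValueError(f"cannot parse an integer from {value!r}")
    number, unit = match.groups()
    return int(float(number) * _UNITS.get(unit, 1))

assert parse_human_readable_int("8GiB") == 8 * 2**30
assert parse_human_readable_int("4096") == 4096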

vllm/engine/llm_engine.py

Lines changed: 2 additions & 1 deletion
@@ -278,7 +278,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,
-
+                    "kv_cache_memory_bytes":
+                    self.cache_config.kv_cache_memory_bytes,
                    # Quantization
                    "quantization":
                    self.model_config.quantization,

vllm/entrypoints/llm.py

Lines changed: 10 additions & 0 deletions
@@ -110,6 +110,14 @@ class LLM:
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
+        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
+            default, this is set to None and vLLM can automatically infer the
+            KV cache size based on gpu_memory_utilization. However, users may
+            want to manually specify the KV cache memory size.
+            kv_cache_memory_bytes allows more fine-grained control of how
+            much memory gets used compared with gpu_memory_utilization. Note
+            that kv_cache_memory_bytes (when not None) ignores
+            gpu_memory_utilization.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
@@ -184,6 +192,7 @@ def __init__(
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
+        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
@@ -251,6 +260,7 @@ def __init__(
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
+            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
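
A short usage sketch of the new keyword argument on the Python API. The model name and the 4 GiB figure are placeholders, not taken from the diff:

from vllm import LLM

# Reserve a fixed 4 GiB per GPU for the KV cache instead of deriving the
# budget from gpu_memory_utilization.
llm = LLM(
    model="facebook/opt-125m",
    kv_cache_memory_bytes=4 * (1 << 30),
)
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)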

vllm/utils/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -2791,7 +2791,10 @@ def memory_profiling(
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp
-    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa
+
+    non_torch_memory = result.non_torch_increase
+    peak_activation_memory = result.torch_peak_increase
+    result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory  # noqa


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630  # noqa: E501
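
The renamed locals make the accounting identity explicit: everything that is not KV cache is weights plus peak activations plus non-torch allocations. A sketch with made-up numbers (not from a real profile), assuming requested_memory is the gpu_memory_utilization share of device memory:

GiB = 1 << 30

weights_memory = 14 * GiB           # model weights resident on the GPU
peak_activation_memory = 3 * GiB    # torch peak increase during the dummy run
non_torch_memory = 1 * GiB          # CUDA context, NCCL buffers, etc.

non_kv_cache_memory = weights_memory + peak_activation_memory + non_torch_memory

# The GPU worker later subtracts this from the memory requested via
# gpu_memory_utilization to get the KV cache budget (see gpu_worker.py below).
requested_memory = int(0.9 * 40 * GiB)          # e.g. 90% of a 40 GiB device
available_kv_cache_memory = requested_memory - non_kv_cache_memory
print(f"KV cache budget: {available_kv_cache_memory / GiB:.2f} GiB")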

vllm/v1/utils.py

Lines changed: 2 additions & 1 deletion
@@ -355,7 +355,8 @@ def report_usage_stats(
                vllm_config.cache_config.block_size,
                "gpu_memory_utilization":
                vllm_config.cache_config.gpu_memory_utilization,
-
+                "kv_cache_memory_bytes":
+                vllm_config.cache_config.kv_cache_memory_bytes,
                # Quantization
                "quantization":
                vllm_config.model_config.quantization,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 2 deletions
@@ -3041,12 +3041,12 @@ def profile_run(self) -> None:
        self.encoder_cache.clear()
        gc.collect()

-    def capture_model(self) -> None:
+    def capture_model(self) -> int:
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`")
-            return
+            return 0
        else:
            self.initialize_cudagraph_capture()

@@ -3117,6 +3117,7 @@ def freeze_gc():
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
+        return cuda_graph_size

    def _capture_cudagraphs(self, compilation_cases: list[int],
                            cudagraph_runtime_mode: CUDAGraphMode,
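
These hunks change capture_model to report how much memory CUDA graph capture consumed; the measurement itself sits outside the shown context. One plausible way to obtain such a number, shown purely as an illustration (measure_capture_memory is a hypothetical helper, not part of the diff):

import torch

def measure_capture_memory(capture_fn) -> int:
    """Return the bytes of device memory consumed by running capture_fn."""
    torch.cuda.synchronize()
    free_before, _ = torch.cuda.mem_get_info()
    capture_fn()                      # e.g. run all CUDA graph captures
    torch.cuda.synchronize()
    free_after, _ = torch.cuda.mem_get_info()
    return free_before - free_after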

vllm/v1/worker/gpu_worker.py

Lines changed: 77 additions & 7 deletions
@@ -231,18 +231,40 @@ def determine_available_memory(self) -> int:
        You may limit the usage of GPU memory
        by adjusting the `gpu_memory_utilization` parameter.
        """
+        GiB = lambda b: b / GiB_bytes
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # still need a profile run which compiles the model for
+            # max_num_batched_tokens
+            self.model_runner.profile_run()
+
+            msg = (
+                f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
+                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
+                "KV Cache as specified by kv_cache_memory_bytes config and "
+                "skipped memory profiling. This does not respect the "
+                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
+                "config when you want manual control of KV cache memory "
+                "size. If OOM'ed, check the difference of initial free "
+                "memory between the current run and the previous run "
+                "where kv_cache_memory_bytes is suggested and update it "
+                "correspondingly.")
+            logger.info(msg)
+            return kv_cache_memory_bytes
+
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
-        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
-                weights_memory=int(
-                    self.model_runner.model_memory_usage)) as profile_result:
+                weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
            self.model_runner.profile_run()

+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
@@ -254,7 +276,7 @@ def determine_available_memory(self) -> int:
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
-        available_kv_cache_memory = self.requested_memory \
+        self.available_kv_cache_memory_bytes = self.requested_memory \
            - profile_result.non_kv_cache_memory

        unrequested_memory = self.init_snapshot.free_memory \
@@ -274,10 +296,10 @@ def determine_available_memory(self) -> int:
        )
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
-                    GiB(available_kv_cache_memory))
+                    GiB(self.available_kv_cache_memory_bytes))
        gc.collect()

-        return int(available_kv_cache_memory)
+        return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()
@@ -317,8 +339,56 @@ def compile_or_warm_up_model(self) -> None:
        # cuda graph capture.
        kernel_warmup(self)

+        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
-            self.model_runner.capture_model()
+            cuda_graph_memory_bytes = self.model_runner.capture_model()
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggest an optimal kv cache memory size when we rely on
+            # memory_profiling to guess the kv cache memory size, which
+            # provides peak_activation_memory and a few other memory
+            # consumption numbers. `memory_profiling` does not consider
+            # CUDAGraph memory size and may not utilize all gpu memory.
+            # Users may want fine-grained control to specify kv cache
+            # memory size.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+
+            # empirically observed that the memory profiling may
+            # slightly underestimate the memory consumption,
+            # so leave a small buffer (=150 MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.init_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
+                f"config with `--kv-cache-memory="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
+                f"requested memory, or `--kv-cache-memory="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize gpu memory. Current kv cache memory in use is "
+                f"{int(self.available_kv_cache_memory_bytes)} bytes.")
+
+            logger.info(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
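
The arithmetic behind the two suggested flag values in that log message, restated as a sketch with made-up device numbers (and assuming requested_memory is the gpu_memory_utilization share of total device memory):

GiB = 1 << 30
MiB = 1 << 20

free_memory_on_startup = 79 * GiB
total_memory = 80 * GiB
gpu_memory_utilization = 0.9
requested_memory = int(total_memory * gpu_memory_utilization)

weights = 14 * GiB
peak_activation = 3 * GiB
non_torch = 1 * GiB
cuda_graph = 512 * MiB
redundancy_buffer = 150 * MiB       # same safety margin the worker applies

non_kv_cache_memory = weights + peak_activation + non_torch + cuda_graph

# Fits within the gpu_memory_utilization budget:
to_requested_limit = requested_memory - non_kv_cache_memory - redundancy_buffer
# Uses (nearly) all free memory on the device:
to_gpu_limit = free_memory_on_startup - non_kv_cache_memory - redundancy_buffer

print(f"--kv-cache-memory-bytes {to_requested_limit}  # fit requested memory")
print(f"--kv-cache-memory-bytes {to_gpu_limit}  # fully utilize the GPU")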

vllm/worker/model_runner.py

Lines changed: 4 additions & 2 deletions
@@ -1337,8 +1337,9 @@ def list_loras(self) -> Set[int]:
        return self.lora_manager.list_adapters()

    @torch.inference_mode()
-    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
-        """Cuda graph capture a model.
+    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
+        """Cuda graph capture a model and return cudagraph memory
+        consumption in bytes.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
@@ -1505,6 +1506,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)
+        return cuda_graph_size

    def _update_inputs_to_capture_for_enc_dec_model(self,
                                                    capture_inputs: Dict[str,
