
[Bug]: V1 engine peak memory usage calculations incorrect #16141

Open
@vqvu

Description


Your current environment

The collect_env.py script doesn't work because vLLM isn't installed in my environment. The bug is reproducible with the Docker image, so I don't think this matters.

Affected vLLM version: v0.8.3.

🐛 Describe the bug

The peak memory usage calculation in vLLM's V1 engine is buggy: it appears to count memory used by other processes on the GPU toward the minimum this instance requires. This happens with v0.8.3.

This is a problem when running multiple vLLM instances on the same GPU.
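A simplified model of the arithmetic makes the suspected failure mode concrete. This is a sketch under an assumption, not vLLM's code: it assumes the V1 profiler treats everything the driver reports as used on the device (total minus free), minus this process's own PyTorch allocations, as this instance's non-torch memory. Both function names and the baseline-subtracting variant are hypothetical, and the GiB figures are rough values taken from the logs and nvidia-smi output in this report.

```python
GiB = 1024 ** 3


def kv_budget_device_wide(total, util, torch_peak, device_used, torch_current):
    # Hypothetical flawed version: everything the driver reports as used on the
    # device, minus our own torch allocations, is charged to this instance.
    non_torch = max(device_used - torch_current, 0)
    return total * util - (torch_peak + non_torch)


def kv_budget_with_baseline(total, util, torch_peak, device_used, torch_current,
                            baseline_used):
    # Hypothetical fix: subtract memory already in use before this instance
    # started, so other processes' allocations are not charged to it.
    non_torch = max(device_used - baseline_used - torch_current, 0)
    return total * util - (torch_peak + non_torch)


# Rough figures for the second instance (assumed, taken from this report):
total = 44.55 * GiB                  # total GPU memory per the V0 log
util = 0.1                           # --gpu-memory-utilization
other = 4910 / 1024 * GiB            # first instance, per nvidia-smi
weights = 0.43 * GiB                 # model weights
torch_peak = weights + 0.09 * GiB    # weights + activation peak
own_non_torch = 0.06 * GiB           # this instance's non-torch memory (V0 log)
device_used = other + weights + own_non_torch

flawed = kv_budget_device_wide(total, util, torch_peak, device_used, weights)
fixed = kv_budget_with_baseline(total, util, torch_peak, device_used, weights, other)
print(f"device-wide: {flawed / GiB:.2f} GiB, with baseline: {fixed / GiB:.2f} GiB")
```

Under these assumptions the device-wide version goes negative for the second instance, while the same instance alone on the GPU still gets a positive budget, which matches the behavior in this report. The baseline variant lands near the 3.88 GiB that the V0 engine reserves for the KV cache.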

Repro steps

This is easy to reproduce with the Docker image. Here is the nvidia-smi output before running vLLM; no GPU memory is in use.

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A6000               Off |   00000000:81:00.0 Off |                    0 |
| 30%   39C    P2             N/A /  300W |       1MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

Then I run one instance:

docker run --rm --gpus=all --runtime nvidia --ipc=host \
    -v ~/.cache:/root/.cache \
    vllm/vllm-openai:v0.8.3 \
    --model Qwen/Qwen2.5-0.5B-Instruct-AWQ \
    --gpu-memory-utilization 0.1 \
    --max-num-seqs 1 \
    --max-model-len 512
Complete logs
INFO 04-06 18:42:42 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:42:45 [api_server.py:1034] vLLM API server version 0.8.3
INFO 04-06 18:42:45 [api_server.py:1035] args: Namespace(host=None, port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=512, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.1, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=1, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, 
tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False)
INFO 04-06 18:42:53 [config.py:600] This model supports multiple tasks: {'embed', 'score', 'classify', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 04-06 18:42:54 [awq_marlin.py:114] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-06 18:42:54 [config.py:1780] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-06 18:42:59 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:43:01 [core.py:61] Initializing a V1 LLM engine (v0.8.3) with config: model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', speculative_config=None, tokenizer='Qwen/Qwen2.5-0.5B-Instruct-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=512, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-0.5B-Instruct-AWQ, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 04-06 18:43:02 [utils.py:2413] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7f6d3b5030e0>
INFO 04-06 18:43:02 [parallel_state.py:957] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-06 18:43:02 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 04-06 18:43:02 [gpu_model_runner.py:1258] Starting to load model Qwen/Qwen2.5-0.5B-Instruct-AWQ...
INFO 04-06 18:43:03 [topk_topp_sampler.py:59] Using FlashInfer for top-p & top-k sampling.
INFO 04-06 18:43:03 [weight_utils.py:265] Using model weights format ['*.safetensors']
INFO 04-06 18:43:03 [weight_utils.py:315] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.51it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.50it/s]

INFO 04-06 18:43:03 [loader.py:447] Loading weights took 0.18 seconds
INFO 04-06 18:43:04 [gpu_model_runner.py:1273] Model loading took 0.4315 GiB and 0.935828 seconds
INFO 04-06 18:43:10 [backends.py:416] Using cache directory: /root/.cache/vllm/torch_compile_cache/e05df7ec0c/rank_0_0 for vLLM's torch.compile
INFO 04-06 18:43:10 [backends.py:426] Dynamo bytecode transform time: 6.97 s
INFO 04-06 18:43:11 [backends.py:115] Directly load the compiled graph for shape None from the cache
INFO 04-06 18:43:17 [monitor.py:33] torch.compile takes 6.97 s in total
INFO 04-06 18:43:17 [kv_cache_utils.py:578] GPU KV cache size: 314,272 tokens
INFO 04-06 18:43:17 [kv_cache_utils.py:581] Maximum concurrency for 512 tokens per request: 613.81x
INFO 04-06 18:43:35 [gpu_model_runner.py:1608] Graph capturing finished in 18 secs, took 0.41 GiB
INFO 04-06 18:43:35 [core.py:162] init engine (profile, create kv cache, warmup model) took 31.70 seconds
WARNING 04-06 18:43:35 [config.py:1088] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
INFO 04-06 18:43:35 [serving_chat.py:117] Using default chat sampling params from model: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
INFO 04-06 18:43:35 [serving_completion.py:61] Using default completion sampling params from model: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
INFO 04-06 18:43:35 [api_server.py:1081] Starting vLLM API server on http://0.0.0.0:8000
INFO 04-06 18:43:35 [launcher.py:26] Available routes are:
INFO 04-06 18:43:35 [launcher.py:34] Route: /openapi.json, Methods: GET, HEAD
INFO 04-06 18:43:35 [launcher.py:34] Route: /docs, Methods: GET, HEAD
INFO 04-06 18:43:35 [launcher.py:34] Route: /docs/oauth2-redirect, Methods: GET, HEAD
INFO 04-06 18:43:35 [launcher.py:34] Route: /redoc, Methods: GET, HEAD
INFO 04-06 18:43:35 [launcher.py:34] Route: /health, Methods: GET
INFO 04-06 18:43:35 [launcher.py:34] Route: /load, Methods: GET
INFO 04-06 18:43:35 [launcher.py:34] Route: /ping, Methods: GET, POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /tokenize, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /detokenize, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/models, Methods: GET
INFO 04-06 18:43:35 [launcher.py:34] Route: /version, Methods: GET
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/chat/completions, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/completions, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/embeddings, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /pooling, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /score, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/score, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/audio/transcriptions, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /rerank, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v1/rerank, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /v2/rerank, Methods: POST
INFO 04-06 18:43:35 [launcher.py:34] Route: /invocations, Methods: POST
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.

It starts successfully and uses about 10% of GPU memory, as expected.

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A6000               Off |   00000000:81:00.0 Off |                    0 |
| 30%   52C    P2             84W /  300W |    4910MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
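As a quick check of the "about 10%" figure and of how much memory is still free, the Memory-Usage column can be turned into numbers. This sketch only parses the string format shown in the table above (used MiB / total MiB); the helper name parse_memory_usage is hypothetical and it does not call nvidia-smi itself.

```python
import re


def parse_memory_usage(field):
    """Parse an nvidia-smi Memory-Usage field such as '4910MiB /  46068MiB'.

    Returns (used_mib, total_mib, free_mib, used_fraction).
    """
    m = re.fullmatch(r"\s*(\d+)MiB\s*/\s*(\d+)MiB\s*", field)
    if m is None:
        raise ValueError(f"unrecognized Memory-Usage field: {field!r}")
    used, total = int(m.group(1)), int(m.group(2))
    return used, total, total - used, used / total


used, total, free, frac = parse_memory_usage("4910MiB /  46068MiB")
print(f"used {frac:.1%}, free {free} MiB ({free / 1024:.1f} GiB)")
# used 10.7%, free 41158 MiB (40.2 GiB)
```

For scripting, nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits prints the same two numbers without the table layout.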

If I start a second instance with the same docker command while the first is still running, it fails with the error "No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine." even though nvidia-smi shows about 40 GiB still free.

Complete logs
INFO 04-06 18:43:47 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:43:50 [api_server.py:1034] vLLM API server version 0.8.3
INFO 04-06 18:43:50 [api_server.py:1035] args: Namespace(host=None, port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=512, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.1, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=1, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, 
tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False)
INFO 04-06 18:43:58 [config.py:600] This model supports multiple tasks: {'classify', 'reward', 'generate', 'score', 'embed'}. Defaulting to 'generate'.
INFO 04-06 18:43:59 [awq_marlin.py:114] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-06 18:43:59 [config.py:1780] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-06 18:44:04 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:44:06 [core.py:61] Initializing a V1 LLM engine (v0.8.3) with config: model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', speculative_config=None, tokenizer='Qwen/Qwen2.5-0.5B-Instruct-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=512, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-0.5B-Instruct-AWQ, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 04-06 18:44:07 [utils.py:2413] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7faf5a91ff20>
INFO 04-06 18:44:07 [parallel_state.py:957] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-06 18:44:07 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 04-06 18:44:07 [gpu_model_runner.py:1258] Starting to load model Qwen/Qwen2.5-0.5B-Instruct-AWQ...
INFO 04-06 18:44:07 [topk_topp_sampler.py:59] Using FlashInfer for top-p & top-k sampling.
INFO 04-06 18:44:08 [weight_utils.py:265] Using model weights format ['*.safetensors']
INFO 04-06 18:44:08 [weight_utils.py:315] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.43it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.42it/s]

INFO 04-06 18:44:08 [loader.py:447] Loading weights took 0.18 seconds
INFO 04-06 18:44:08 [gpu_model_runner.py:1273] Model loading took 0.4315 GiB and 0.825929 seconds
INFO 04-06 18:44:15 [backends.py:416] Using cache directory: /root/.cache/vllm/torch_compile_cache/e05df7ec0c/rank_0_0 for vLLM's torch.compile
INFO 04-06 18:44:15 [backends.py:426] Dynamo bytecode transform time: 6.93 s
INFO 04-06 18:44:16 [backends.py:115] Directly load the compiled graph for shape None from the cache
INFO 04-06 18:44:21 [monitor.py:33] torch.compile takes 6.93 s in total
ERROR 04-06 18:44:22 [core.py:390] EngineCore hit an exception: Traceback (most recent call last):
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 378, in run_engine_core
ERROR 04-06 18:44:22 [core.py:390]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 04-06 18:44:22 [core.py:390]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 319, in __init__
ERROR 04-06 18:44:22 [core.py:390]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 71, in __init__
ERROR 04-06 18:44:22 [core.py:390]     self._initialize_kv_caches(vllm_config)
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 137, in _initialize_kv_caches
ERROR 04-06 18:44:22 [core.py:390]     get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 643, in get_kv_cache_config
ERROR 04-06 18:44:22 [core.py:390]     check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
ERROR 04-06 18:44:22 [core.py:390]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 480, in check_enough_kv_cache_memory
ERROR 04-06 18:44:22 [core.py:390]     raise ValueError("No available memory for the cache blocks. "
ERROR 04-06 18:44:22 [core.py:390] ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
ERROR 04-06 18:44:22 [core.py:390]
CRITICAL 04-06 18:44:22 [core_client.py:361] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 1121, in <module>
    uvloop.run(run_server(args))
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 1069, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 146, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 178, in build_async_engine_client_from_engine_args
    async_llm = AsyncLLM.from_vllm_config(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 136, in from_vllm_config
    return cls(
           ^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 102, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 69, in make_client
    return AsyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 570, in __init__
    super().__init__(
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 401, in __init__
    engine.proc_handle.wait_for_startup()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/utils.py", line 127, in wait_for_startup
    if self.reader.recv()["status"] != "READY":
       ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 399, in _recv
    raise EOFError
EOFError
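The traceback ends in check_enough_kv_cache_memory. A minimal stand-in, assuming the check first rejects a non-positive budget and then verifies that one full-length request fits, could look like this. The name check_budget is hypothetical, and the per-token KV size of 12,288 bytes is an assumption, roughly what the V0 log implies (3.88 GiB for 21,186 blocks of an assumed 16 tokens each).

```python
GiB = 1024 ** 3


def check_budget(available_bytes, bytes_per_token, max_model_len):
    """Hypothetical stand-in for the KV cache memory check in the traceback."""
    if available_bytes <= 0:
        # Same message as the ValueError in the log above.
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
    needed = bytes_per_token * max_model_len
    if needed > available_bytes:
        raise ValueError(f"{needed} bytes needed for one {max_model_len}-token "
                         f"request, but only {available_bytes} available")
    return available_bytes // bytes_per_token  # number of tokens that fit


# Two budgets mirroring this report: negative when the first instance's memory
# is charged to the second, and about 3.88 GiB (the V0 figure) when it is not.
for budget_gib in (-0.92, 3.88):
    try:
        tokens = check_budget(int(budget_gib * GiB), bytes_per_token=12_288,
                              max_model_len=512)
        print(f"{budget_gib:+.2f} GiB -> {tokens} tokens of KV cache")
    except ValueError as exc:
        print(f"{budget_gib:+.2f} GiB -> ValueError: {exc}")
```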

This error doesn't happen with the V0 engine. Starting the second instance with VLLM_USE_V1=0 works:

docker run --rm --gpus=all --runtime nvidia --ipc=host \
    -e VLLM_USE_V1=0 \
    -v ~/.cache:/root/.cache \
    vllm/vllm-openai:v0.8.3 \
    --model Qwen/Qwen2.5-0.5B-Instruct-AWQ \
    --gpu-memory-utilization 0.1 \
    --max-num-seqs 1 \
    --max-model-len 512
Complete logs
INFO 04-06 18:44:54 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:44:57 [api_server.py:1034] vLLM API server version 0.8.3
INFO 04-06 18:44:57 [api_server.py:1035] args: Namespace(host=None, port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=512, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.1, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=1, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, 
tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False)
INFO 04-06 18:45:05 [config.py:600] This model supports multiple tasks: {'classify', 'embed', 'score', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 04-06 18:45:07 [awq_marlin.py:114] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-06 18:45:07 [api_server.py:246] Started engine process with PID 93
INFO 04-06 18:45:10 [__init__.py:239] Automatically detected platform cuda.
INFO 04-06 18:45:12 [llm_engine.py:242] Initializing a V0 LLM engine (v0.8.3) with config: model='Qwen/Qwen2.5-0.5B-Instruct-AWQ', speculative_config=None, tokenizer='Qwen/Qwen2.5-0.5B-Instruct-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=512, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-0.5B-Instruct-AWQ, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[1],"max_capture_size":1}, use_cached_outputs=True,
INFO 04-06 18:45:13 [cuda.py:292] Using Flash Attention backend.
INFO 04-06 18:45:13 [parallel_state.py:957] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-06 18:45:13 [model_runner.py:1110] Starting to load model Qwen/Qwen2.5-0.5B-Instruct-AWQ...
INFO 04-06 18:45:14 [weight_utils.py:265] Using model weights format ['*.safetensors']
INFO 04-06 18:45:14 [weight_utils.py:315] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.44it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.43it/s]

INFO 04-06 18:45:14 [loader.py:447] Loading weights took 0.18 seconds
INFO 04-06 18:45:14 [model_runner.py:1146] Model loading took 0.4315 GiB and 0.745171 seconds
INFO 04-06 18:45:15 [worker.py:267] Memory profiling takes 0.47 seconds
INFO 04-06 18:45:15 [worker.py:267] the current vLLM instance can use total_gpu_memory (44.55GiB) x gpu_memory_utilization (0.10) = 4.45GiB
INFO 04-06 18:45:15 [worker.py:267] model weights take 0.43GiB; non_torch_memory takes 0.06GiB; PyTorch activation peak memory takes 0.09GiB; the rest of the memory reserved for KV Cache is 3.88GiB.
INFO 04-06 18:45:15 [executor_base.py:112] # cuda blocks: 21186, # CPU blocks: 21845
INFO 04-06 18:45:15 [executor_base.py:117] Maximum concurrency for 512 tokens per request: 662.06x
INFO 04-06 18:45:20 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|██████████| 1/1 [00:00<00:00,  2.47it/s]
INFO 04-06 18:45:20 [model_runner.py:1598] Graph capturing finished in 0 secs, took 0.03 GiB
INFO 04-06 18:45:20 [llm_engine.py:448] init engine (profile, create kv cache, warmup model) took 5.76 seconds
WARNING 04-06 18:45:21 [config.py:1088] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
INFO 04-06 18:45:21 [serving_chat.py:117] Using default chat sampling params from model: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
INFO 04-06 18:45:21 [serving_completion.py:61] Using default completion sampling params from model: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
INFO 04-06 18:45:21 [api_server.py:1081] Starting vLLM API server on http://0.0.0.0:8000
INFO 04-06 18:45:21 [launcher.py:26] Available routes are:
INFO 04-06 18:45:21 [launcher.py:34] Route: /openapi.json, Methods: HEAD, GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /docs, Methods: HEAD, GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /redoc, Methods: HEAD, GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /health, Methods: GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /load, Methods: GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /ping, Methods: POST, GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /tokenize, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /detokenize, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/models, Methods: GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /version, Methods: GET
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/chat/completions, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/completions, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/embeddings, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /pooling, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /score, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/score, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/audio/transcriptions, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /rerank, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v1/rerank, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /v2/rerank, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /invocations, Methods: POST
INFO 04-06 18:45:21 [launcher.py:34] Route: /metrics, Methods: GET
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
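
The profiling arithmetic in the log above can be reproduced from the logged values. Note that this log comes from the V0 engine, per the "Initializing a V0 LLM engine" line. This is a sketch only. block_size=16, swap_space=4 GiB, and the Qwen2.5-0.5B shape (24 layers, 2 KV heads, head_dim 64, fp16 KV cache) are assumptions that the log does not state. All inputs are rounded as printed, so the results only approximately match.

```python
GiB = 1024 ** 3

# Values copied from the log above (rounded as printed).
total_gpu_memory = 44.55        # GiB
gpu_memory_utilization = 0.10
weights = 0.4315                # GiB, "Model loading took 0.4315 GiB"
non_torch = 0.06                # GiB
activation_peak = 0.09          # GiB
logged_gpu_blocks = 21186       # "# cuda blocks: 21186"
max_seq_len = 512

# Budget for this instance, and what is left over for the KV cache.
budget = total_gpu_memory * gpu_memory_utilization
kv_cache = budget - weights - non_torch - activation_peak

# ASSUMED (not in the log): block_size=16, swap_space=4 GiB, and a model with
# 24 layers, 2 KV heads, head_dim 64, fp16 (2 bytes). The leading 2 is K and V.
block_size = 16
bytes_per_block = 2 * 24 * 2 * 64 * 2 * block_size

# Cross-check the logged block counts and concurrency against these sizes.
kv_from_logged_blocks = logged_gpu_blocks * bytes_per_block / GiB
cpu_blocks = (4 * GiB) // bytes_per_block
max_concurrency = logged_gpu_blocks * block_size / max_seq_len

print(f"budget={budget:.3f} GiB, kv_cache~{kv_cache:.3f} GiB")
print(f"KV size implied by {logged_gpu_blocks} blocks: {kv_from_logged_blocks:.3f} GiB")
print(f"CPU blocks={cpu_blocks}, max concurrency={max_concurrency:.2f}x")
```

Under these assumptions, the CPU block count (21845) and the 662.06x concurrency match the log exactly. The 21186 GPU blocks correspond to about 3.879 GiB, which is consistent with the logged 3.88 GiB. Direct subtraction gives about 3.87 GiB; the small gap is most likely due to the rounded inputs.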

Looking at the determine_available_memory function in the V1 gpu_worker, my guess is that the culprit is the non_torch_allocations calculation. It does not account for the fact that some of the GPU memory not allocated by torch is used by other processes and has nothing to do with this vLLM instance.
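
The suspected miscount can be illustrated with a simplified model. This is a hypothetical sketch and not vLLM's actual code. GpuSnapshot, non_torch_naive and non_torch_with_baseline are made-up names, and the memory numbers are illustrative rather than measured. The point is that device-wide free memory (as reported by cudaMemGetInfo) also reflects other processes. Any "used minus torch-reserved" figure therefore absorbs their usage unless a baseline taken before this instance started is subtracted.

```python
from dataclasses import dataclass


@dataclass
class GpuSnapshot:
    """Simplified memory readings in GiB (hypothetical model, not vLLM code)."""
    total: float           # total device memory
    free: float            # device-wide free memory, as cudaMemGetInfo reports it
    torch_reserved: float  # memory held by this process's torch caching allocator


def non_torch_naive(snap: GpuSnapshot) -> float:
    # Treat all device memory in use that torch did not reserve as "non-torch".
    # Memory held by *other* processes on the same GPU is counted here too.
    used = snap.total - snap.free
    return used - snap.torch_reserved


def non_torch_with_baseline(before: GpuSnapshot, after: GpuSnapshot) -> float:
    # Count only the growth in non-torch usage since this process started,
    # assuming other processes' usage stays constant in between.
    return non_torch_naive(after) - non_torch_naive(before)


# Illustrative numbers (made up): another instance already holds 5.0 GiB;
# this instance reserves 0.6 GiB via torch plus 0.06 GiB of its own
# non-torch memory (for example, the CUDA context).
TOTAL = 44.55
OTHER = 5.0
before = GpuSnapshot(total=TOTAL, free=TOTAL - OTHER, torch_reserved=0.0)
after = GpuSnapshot(total=TOTAL, free=TOTAL - OTHER - 0.6 - 0.06, torch_reserved=0.6)

budget = TOTAL * 0.10  # gpu_memory_utilization=0.1
print(f"naive non-torch:    {non_torch_naive(after):.2f} GiB")
print(f"baseline non-torch: {non_torch_with_baseline(before, after):.2f} GiB")
print(f"budget:             {budget:.2f} GiB")
```

With the naive count, about 5.06 GiB of "non-torch" memory is attributed to this instance. That alone exceeds its budget of about 4.45 GiB, even though the instance itself uses well under the budget. The baseline variant has its own weakness: if another instance allocates memory while this one is profiling, that growth is still misattributed.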
