@@ -231,18 +231,40 @@ def determine_available_memory(self) -> int:
             You may limit the usage of GPU memory
             by adjusting the `gpu_memory_utilization` parameter.
         """
+        GiB = lambda b: b / GiB_bytes
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # still need a profile run which compiles the model for
+            # max_num_batched_tokens
+            self.model_runner.profile_run()
+
+            msg = (
+                f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
+                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB "
+                "KV Cache as specified by kv_cache_memory_bytes config and "
+                "skipped memory profiling. This does not respect the "
+                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
+                "config when you want manual control of KV cache memory "
+                "size. If OOM'ed, check the difference of initial free "
+                "memory between the current run and the previous run "
+                "where kv_cache_memory_bytes is suggested and update it "
+                "correspondingly.")
+            logger.info(msg)
+            return kv_cache_memory_bytes
+
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        GiB = lambda b: b / GiB_bytes

         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         with memory_profiling(
                 self.init_snapshot,
-                weights_memory=int(
-                    self.model_runner.model_memory_usage)) as profile_result:
+                weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
             self.model_runner.profile_run()

+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+
         free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
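
With this change the worker has two ways to size the KV cache: when `kv_cache_memory_bytes` is set in the cache config, the profile run is only used to compile the model and the configured byte count is returned directly, ignoring `gpu_memory_utilization`; otherwise the existing profiling path below is taken. A condensed sketch of that decision, using hypothetical stand-in objects rather than the real vLLM config and profiling result:

    # Hypothetical stand-ins; not the real vLLM config or profiling objects.
    class CacheConfigSketch:
        kv_cache_memory_bytes: int | None = None

    def available_kv_cache_bytes(cfg: CacheConfigSketch,
                                 requested_memory: int,
                                 non_kv_cache_memory: int) -> int:
        # Manual path: the configured size wins; gpu_memory_utilization is
        # not consulted and memory profiling is skipped.
        if cfg.kv_cache_memory_bytes:
            return cfg.kv_cache_memory_bytes
        # Profiled path: what remains of the requested budget after the
        # non-KV-cache allocations measured during the profile run.
        return requested_memory - non_kv_cache_memory
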
@@ -254,7 +276,7 @@ def determine_available_memory(self) -> int:
             "release GPU memory while vLLM is profiling during initialization. "
             "To fix this, ensure consistent GPU memory allocation or "
             "isolate vLLM in its own container.")
-        available_kv_cache_memory = self.requested_memory \
+        self.available_kv_cache_memory_bytes = self.requested_memory \
             - profile_result.non_kv_cache_memory

         unrequested_memory = self.init_snapshot.free_memory \
@@ -274,10 +296,10 @@ def determine_available_memory(self) -> int:
         )
         logger.debug(profile_result)
         logger.info("Available KV cache memory: %.2f GiB",
-                    GiB(available_kv_cache_memory))
+                    GiB(self.available_kv_cache_memory_bytes))
         gc.collect()

-        return int(available_kv_cache_memory)
+        return int(self.available_kv_cache_memory_bytes)

     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
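
On the profiled path, the returned value is simply the requested budget minus everything the profile run attributed to non-KV-cache use. A worked example with made-up numbers; the assumptions that `requested_memory` equals `gpu_memory_utilization` times total device memory and that `non_kv_cache_memory` is weights plus peak activation plus non-torch overhead are for illustration only:

    GiB = 1 << 30

    total_memory = 80 * GiB                     # hypothetical 80 GiB device
    gpu_memory_utilization = 0.9
    requested_memory = int(total_memory * gpu_memory_utilization)   # 72 GiB budget

    # Hypothetical numbers a profile run might report.
    weights_memory = 14 * GiB
    peak_activation_memory = 4 * GiB
    non_torch_memory = 1 * GiB
    non_kv_cache_memory = weights_memory + peak_activation_memory + non_torch_memory

    available_kv_cache_memory_bytes = requested_memory - non_kv_cache_memory
    print(available_kv_cache_memory_bytes / GiB)  # 53.0 GiB left for the KV cache
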
@@ -317,8 +339,56 @@ def compile_or_warm_up_model(self) -> None:
         # cuda graph capture.
         kernel_warmup(self)

+        cuda_graph_memory_bytes = 0
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model()
+            cuda_graph_memory_bytes = self.model_runner.capture_model()
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggests optimal kv cache memory size if we rely on
+            # memory_profiling to guess the kv cache memory size, which
+            # provides peak_activation_memory and a few other memory
+            # consumption numbers. `memory_profiling` does not consider
+            # CUDAGraph memory size and may not utilize all gpu memory.
+            # Users may want fine-grained control to specify kv cache
+            # memory size.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+
+            # empirically observed that the memory profiling may
+            # slightly underestimate the memory consumption,
+            # so leave a small buffer (=150MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.init_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
+                f"config with `--kv-cache-memory="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
+                f"requested memory, or `--kv-cache-memory="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize gpu memory. Current kv cache memory in use is "
+                f"{int(self.available_kv_cache_memory_bytes)} bytes.")
+
+            logger.info(msg)

         # Warm up sampler and preallocate memory buffer for logits and other
         # sampling related tensors of max possible shape to avoid memory
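
The new suggestion block folds the CUDA graph memory measured by `capture_model()` and a fixed 150 MiB safety buffer into the same accounting, then logs two candidate values for `--kv-cache-memory`: one that stays within the requested budget and one that fills nearly all free device memory. A self-contained sketch of that arithmetic with hypothetical numbers:

    GiB, MiB = 1 << 30, 1 << 20

    # Hypothetical profile numbers, continuing the earlier sketch.
    requested_memory = 72 * GiB            # gpu_memory_utilization * total memory
    free_memory_at_startup = 79 * GiB
    weights_memory, peak_activation_memory = 14 * GiB, 4 * GiB
    non_torch_memory, cuda_graph_memory_bytes = 1 * GiB, 2 * GiB
    redundancy_buffer_memory = 150 * MiB

    non_kv_cache_memory = (weights_memory + peak_activation_memory +
                           non_torch_memory + cuda_graph_memory_bytes)

    # Candidate --kv-cache-memory value that respects the requested budget ...
    to_requested_limit = (requested_memory - non_kv_cache_memory -
                          redundancy_buffer_memory)
    # ... and one that fills nearly all free memory on the device.
    to_gpu_limit = (free_memory_at_startup - non_kv_cache_memory -
                    redundancy_buffer_memory)
    print(to_requested_limit / GiB, to_gpu_limit / GiB)  # ~50.85 and ~57.85 GiB
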