
Commit 3439c5a

[Bugfix][TPU] Fix KV cache size calculation (#5860)
1 parent 6806998 commit 3439c5a

1 file changed: +7 −6 lines

vllm/worker/tpu_worker.py

Lines changed: 7 additions & 6 deletions
@@ -118,14 +118,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         xm.wait_device_ops()
 
         m = xm.get_memory_info(self.device)
-        program_size = 1024 * 1024 * 1024  # 1GB
-        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
-        kv_cache_bytes = int(free_bytes *
-                             self.cache_config.gpu_memory_utilization)
-        kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
+        total_memory_size = m["bytes_limit"]
+        usable_memory_size = int(total_memory_size *
+                                 self.cache_config.gpu_memory_utilization)
+        profiled = m["bytes_used"]  # Weights + intermediate activations.
+        kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_btyes = get_dtype_size(self.cache_dtype)
         block_size = self.cache_config.block_size
         num_tpu_blocks = (kv_cache_bytes //
-                          (kv_cache_dtype_btyes * block_size * num_layers * 2 *
+                          (dtype_btyes * block_size * num_layers * 2 *
                            head_size * num_kv_heads))
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
         return num_tpu_blocks, 0
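
Why the fix matters: the old code reserved a fixed 1 GiB for the compiled program and then applied gpu_memory_utilization to whatever memory was left, so the flag no longer meant "fraction of total device memory". The new code applies the utilization fraction to total HBM first and subtracts the profiled usage (weights plus intermediate activations) from that budget. Below is a minimal sketch of the new arithmetic; all numbers (32 GiB HBM, 18 GiB profiled, bfloat16 cache, and the model-shape constants) are illustrative assumptions standing in for real xm.get_memory_info() output, not values from the commit.

    # Minimal sketch of the new KV cache sizing, with hypothetical values
    # in place of xm.get_memory_info() output and the real model config.
    bytes_limit = 32 * 1024**3   # total TPU HBM (assumed)
    bytes_used = 18 * 1024**3    # profiled: weights + activations (assumed)
    gpu_memory_utilization = 0.90

    # New logic: budget a fraction of *total* memory, then subtract the
    # profiled usage (instead of subtracting a fixed 1 GiB program size
    # and applying the fraction to the remainder, as before the fix).
    usable_memory_size = int(bytes_limit * gpu_memory_utilization)
    kv_cache_bytes = max(usable_memory_size - bytes_used, 0)

    # Bytes per cache block: dtype size * block size * layers * 2 (K and V)
    # * head size * KV heads. Model-shape constants below are assumptions.
    dtype_bytes = 2       # e.g. bfloat16
    block_size = 16
    num_layers = 32
    head_size = 128
    num_kv_heads = 8
    bytes_per_block = (dtype_bytes * block_size * num_layers * 2 *
                       head_size * num_kv_heads)  # 2 MiB with these values

    num_tpu_blocks = kv_cache_bytes // bytes_per_block
    num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to a multiple of 8.
    print(num_tpu_blocks)  # 5528 with these numbers

For comparison, with the same hypothetical numbers the old formula would have budgeted max(32 − 18 − 1, 0) GiB × 0.90 ≈ 11.7 GiB for the cache, pushing total usage to about 29.7 GiB and overshooting the 90% cap (28.8 GiB); the new formula stays within it by construction.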
