1 file changed: +7 -6 lines changed
Original file line number | Diff line number | Diff line change
@@ -118,14 +118,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
118
118
xm .wait_device_ops ()
119
119
120
120
m = xm .get_memory_info (self .device )
121
- program_size = 1024 * 1024 * 1024 # 1GB
122
- free_bytes = max (m ["bytes_limit" ] - m ["bytes_used" ] - program_size , 0 )
123
- kv_cache_bytes = int (free_bytes *
124
- self .cache_config .gpu_memory_utilization )
125
- kv_cache_dtype_btyes = get_dtype_size (self .cache_dtype )
121
+ total_memory_size = m ["bytes_limit" ]
122
+ usable_memory_size = int (total_memory_size *
123
+ self .cache_config .gpu_memory_utilization )
124
+ profiled = m ["bytes_used" ] # Weights + intermediate activations.
125
+ kv_cache_bytes = max (usable_memory_size - profiled , 0 )
126
+ dtype_btyes = get_dtype_size (self .cache_dtype )
126
127
block_size = self .cache_config .block_size
127
128
num_tpu_blocks = (kv_cache_bytes //
128
- (kv_cache_dtype_btyes * block_size * num_layers * 2 *
129
+ (dtype_btyes * block_size * num_layers * 2 *
129
130
head_size * num_kv_heads ))
130
131
num_tpu_blocks = (num_tpu_blocks // 8 ) * 8 # Round down to 8.
131
132
return num_tpu_blocks , 0
You can’t perform that action at this time.
0 commit comments