@@ -44,7 +44,7 @@ class KVCacheManager:
44
44
config(InferenceConfig): The All-in-one inference configuration.
45
45
"""
46
46
47
- def __init__ (self , config : InferenceConfig ) -> None :
47
+ def __init__ (self , config : InferenceConfig , verbose : bool = False ) -> None :
48
48
self .logger = get_dist_logger (__name__ )
49
49
self .device = get_current_device ()
50
50
@@ -73,9 +73,9 @@ def __init__(self, config: InferenceConfig) -> None:
73
73
self .num_blocks = self .max_blocks_per_sequence * self .max_batch_size * self .beam_width
74
74
75
75
# Physical cache allocation
76
- alloc_shape = ( self . num_blocks , self . block_size , self . head_num , self . head_size )
77
- self . logger . info ( f"Allocating KV cache with shape: { alloc_shape } consisting of { self .num_blocks } blocks." )
78
- # self._kv_caches = self._init_device_caches( alloc_shape)
76
+ if verbose :
77
+ alloc_shape = ( self .num_blocks , self . head_num , self . head_size , self . block_size )
78
+ self .logger . info ( f"Allocating KV cache with shape: { alloc_shape } consisting of { self . num_blocks } blocks." )
79
79
self ._kv_caches = self ._init_device_caches ()
80
80
self .total_physical_cache_size_in_bytes = (
81
81
self .elem_size_in_bytes
@@ -177,7 +177,7 @@ def free_cache_blocks(self, block_table: torch.Tensor) -> None:
177
177
for i in range (block_table .numel ()):
178
178
global_block_id = block_table [i ].item ()
179
179
block : CacheBlock = self ._cache_blocks [global_block_id ]
180
- block .remove_ref () # not going to clear the block thoroughly
180
+ block .remove_ref ()
181
181
if not block .has_ref ():
182
182
block .allocated_size = 0
183
183
self ._free_blocks .append (block )
@@ -236,11 +236,11 @@ def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
236
236
"""Initialize the physical cache on the device.
237
237
238
238
For each layer of the model, we allocate two tensors for key and value respectively,
239
- with shape of [num_blocks, block_size, num_head, head_size ]
239
+ with shape of [num_blocks, num_kv_heads, head_size, block_size ]
240
240
"""
241
- alloc_shape = (self .num_blocks , self .block_size , self .head_num , self .head_size )
241
+ alloc_shape = (self .num_blocks , self .head_num , self .head_size , self .block_size )
242
242
# TODO: Explore the performance when using difference shapes with kernel-related optimizations
243
- # e.g. [num_blocks, block_size, num_head // x, head_size, x]
243
+ # e.g. [num_blocks, num_kv_heads // x, head_size, block_size , x]
244
244
k_cache : List [torch .Tensor ] = []
245
245
v_cache : List [torch .Tensor ] = []
246
246
for _ in range (self .num_layers ):
0 commit comments