
Commit c1474e6

optimize shapes of tensors used as physical cache

1 parent 4158c3d

File tree

2 files changed: +12 -9 lines

colossalai/inference/kv_cache/kvcache_manager.py (+8 -8)

@@ -44,7 +44,7 @@ class KVCacheManager:
         config(InferenceConfig): The All-in-one inference configuration.
     """
 
-    def __init__(self, config: InferenceConfig) -> None:
+    def __init__(self, config: InferenceConfig, verbose: bool = False) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -73,9 +73,9 @@ def __init__(self, config: InferenceConfig) -> None:
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
 
         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.block_size, self.head_num, self.head_size)
-        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
-        # self._kv_caches = self._init_device_caches(alloc_shape)
+        if verbose:
+            alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches()
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
@@ -177,7 +177,7 @@ def free_cache_blocks(self, block_table: torch.Tensor) -> None:
         for i in range(block_table.numel()):
             global_block_id = block_table[i].item()
             block: CacheBlock = self._cache_blocks[global_block_id]
-            block.remove_ref()  # not going to clear the block thoroughly
+            block.remove_ref()
             if not block.has_ref():
                 block.allocated_size = 0
                 self._free_blocks.append(block)
@@ -236,11 +236,11 @@ def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """Initialize the physical cache on the device.
 
         For each layer of the model, we allocate two tensors for key and value respectively,
-        with shape of [num_blocks, block_size, num_head, head_size]
+        with shape of [num_blocks, num_kv_heads, head_size, block_size]
         """
-        alloc_shape = (self.num_blocks, self.block_size, self.head_num, self.head_size)
+        alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
         # TODO: Explore the performance when using difference shapes with kernel-related optimizations
-        # e.g. [num_blocks, block_size, num_head // x, head_size, x]
+        # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
        k_cache: List[torch.Tensor] = []
        v_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
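
The net effect of the hunks above is a layout change: block_size moves from the second dimension to the innermost one. A minimal sketch of per-layer K/V allocation under the new layout (illustrative only; the toy sizes, torch.zeros call, and dtype are assumptions, not the repository's actual code):

import torch

# Assumed toy sizes for illustration
num_layers, num_blocks = 2, 64
num_kv_heads, head_size, block_size = 4, 32, 16

k_cache, v_cache = [], []
for _ in range(num_layers):
    # New layout: [num_blocks, num_kv_heads, head_size, block_size],
    # placing block_size in the innermost (contiguous) dimension.
    k_cache.append(torch.zeros(num_blocks, num_kv_heads, head_size, block_size, dtype=torch.float16))
    v_cache.append(torch.zeros(num_blocks, num_kv_heads, head_size, block_size, dtype=torch.float16))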

tests/test_infer/test_kvcache_manager.py (+4 -1)

@@ -19,6 +19,7 @@ class SampleConfig:
     max_output_length: int
     beam_width: int
     dtype: torch.dtype
+    tp_size: int
 
 
 @parameterize(
@@ -60,6 +61,7 @@ def test_logical_blocks(test_config):
             "max_output_length": 32,
             "dtype": torch.float32,
             "beam_width": 1,
+            "tp_size": 1,
         },
         {
             "num_attention_heads": 4,
@@ -71,6 +73,7 @@ def test_logical_blocks(test_config):
             "max_output_length": 32,
             "dtype": torch.float16,
             "beam_width": 3,
+            "tp_size": 1,
         },
     ],
 )
@@ -92,7 +95,7 @@ def test_cache_manager(test_config):
     assert len(cache_manager._allocated_blocks) == 0
     key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
     assert len(key_caches) == test_config["num_layers"]
-    expected_kv_shape = (num_blocks, block_size, num_heads, head_size)
+    expected_kv_shape = (num_blocks, num_heads, head_size, block_size)
     assert key_caches[0].shape == expected_kv_shape
     k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
     expected_kv_block_shape = expected_kv_shape[1:]
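
The updated assertion follows directly from the new allocation shape: indexing a single block drops the leading num_blocks dimension, so the per-block cache for one layer has shape (num_heads, head_size, block_size). A quick self-contained check with assumed toy numbers (not the repository's test):

import torch

num_blocks, num_heads, head_size, block_size = 8, 4, 32, 16
key_cache = torch.zeros(num_blocks, num_heads, head_size, block_size)

expected_kv_shape = (num_blocks, num_heads, head_size, block_size)
assert key_cache.shape == expected_kv_shape
# One block of the physical cache keeps only the trailing dimensions
assert key_cache[0].shape == expected_kv_shape[1:]  # (num_heads, head_size, block_size)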
