[Infer] Optimize Blocked KVCache And Kernels Using It #5325

Merged
11 changes: 4 additions & 7 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -79,10 +79,10 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Physical cache allocation
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
if verbose:
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches()
self._kv_caches = self._init_device_caches(alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
@@ -297,15 +297,12 @@ def _init_logical_caches(self):
blocks.append(cache_block)
return blocks

def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.

For each layer of the model, we allocate two tensors for key and value respectively,
with shape of [num_blocks, num_kv_heads, head_size, block_size]
with shape of [num_blocks, num_kv_heads, block_size, head_size]
"""
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
# TODO: Explore the performance when using difference shapes with kernel-related optimizations
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
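The hunk above switches the physical cache layout from [num_blocks, num_kv_heads, head_size, block_size] to [num_blocks, num_kv_heads, block_size, head_size] and passes the allocation shape into `_init_device_caches`. A minimal sketch of a per-layer allocation under the new layout (function name, dtype, and device defaults are illustrative assumptions, not the exact `KVCacheManager` implementation):

```python
# Sketch only: per-layer K/V allocation under the new blocked layout.
from typing import List, Tuple

import torch


def init_device_caches(
    num_layers: int,
    alloc_shape: Tuple[int, int, int, int],  # (num_blocks, num_kv_heads, block_size, head_size)
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate one key tensor and one value tensor per layer, all with the blocked shape."""
    k_caches = [torch.zeros(alloc_shape, dtype=dtype, device=device) for _ in range(num_layers)]
    v_caches = [torch.zeros(alloc_shape, dtype=dtype, device=device) for _ in range(num_layers)]
    return k_caches, v_caches


# e.g. 512 blocks, 8 KV heads, block size 16, head size 128 for a 32-layer model
k_caches, v_caches = init_device_caches(num_layers=32, alloc_shape=(512, 8, 16, 128))
```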
28 changes: 14 additions & 14 deletions colossalai/inference/modeling/layers/attention.py
@@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
lengths: key/value lengths
block_tables
"""
num_blocks, num_heads, head_size, block_size = cache.shape
num_blocks, num_heads, block_size, head_size = cache.shape
bsz, max_blocks_per_seq = block_tables.shape
needed_blocks = (lengths + block_size - 1) // block_size

@@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
block_num = needed_blocks[i]
token_id = 0
for block_idx in range(block_num - 1):
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
token_id += block_size
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
1, 2, 0
cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
1, 0, 2
)
elif type == "decoding":
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
source = source.squeeze(1)
slot_idx = (lengths + block_size - 1) % block_size
for i in range(bsz):
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]

return cache

@@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation

Args: cache: shape [num_blocks, num_heads, head_size, block_size]
Args: cache: shape [num_blocks, num_heads, block_size, head_size]
lengths: key/value length
block_tables
pad_id: padded_id
"""
num_blocks, num_heads, head_size, block_size = cache.shape
num_blocks, num_heads, block_size, head_size = cache.shape

needed_blocks = (lengths + block_size - 1) // block_size
num_remaing_tokens = lengths % block_size
@@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
for i in range(bsz):
_cache = torch.cat(
(
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
),
dim=0,
)
@@ -127,7 +127,7 @@ def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -142,7 +142,7 @@ def nopad_context_forward(
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads

block_size = k_cache.shape[-1]
block_size = k_cache.size(-2)
bsz, max_blocks_per_sequence = block_tables.shape
max_seq_len = max_blocks_per_sequence * block_size
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
@@ -196,7 +196,7 @@ def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -207,7 +207,7 @@ def pad_context_forward(
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
block_size = k_cache.size(-2)
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size

@@ -254,7 +254,7 @@ def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
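With the transposed layout, a decoding-stage write lands in one slot (row) of the sequence's last block, `cache[block_id, :, slot, :]`, instead of a column. A simplified PyTorch sketch of that branch (a stand-in for `copy_to_cache(..., type="decoding")`, not the function above):

```python
# Simplified sketch of the decoding-stage copy for the
# [num_blocks, num_heads, block_size, head_size] cache layout.
import torch


def copy_decoding_token_to_cache(source, cache, lengths, block_tables):
    """source: [bsz, 1, num_heads, head_size]; cache: [num_blocks, num_heads, block_size, head_size]."""
    block_size = cache.size(-2)
    bsz = source.size(0)
    source = source.squeeze(1)                            # -> [bsz, num_heads, head_size]
    needed_blocks = (lengths + block_size - 1) // block_size
    slot_idx = (lengths + block_size - 1) % block_size    # slot of the newest token inside its block
    for i in range(bsz):
        block_id = block_tables[i, needed_blocks[i] - 1]
        cache[block_id, :, slot_idx[i], :] = source[i]    # one [num_heads, head_size] slice per sequence
    return cache


bsz, num_heads, head_size, block_size, num_blocks = 2, 8, 64, 16, 32
cache = torch.zeros(num_blocks, num_heads, block_size, head_size)
source = torch.randn(bsz, 1, num_heads, head_size)
lengths = torch.tensor([5, 20])                           # kv lengths including the new token
block_tables = torch.arange(num_blocks).reshape(bsz, -1)  # [bsz, max_blocks_per_seq]
copy_decoding_token_to_cache(source, cache, lengths, block_tables)
```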
2 changes: 1 addition & 1 deletion colossalai/inference/modeling/models/nopadding_llama.py
@@ -171,7 +171,7 @@ def llama_attn_forward(

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

_, _, _, block_size = k_cache.shape
block_size = k_cache.size(-2)

if is_prompts:
attn_output = context_attention_unpadded(
2 changes: 1 addition & 1 deletion colossalai/inference/modeling/models/padding_llama.py
@@ -226,7 +226,7 @@ def llama_attn_forward(

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

_, _, _, block_size = k_cache.shape
block_size = k_cache.size(-2)

if is_prompts:
attn_output = context_attention_unpadded(
22 changes: 11 additions & 11 deletions colossalai/kernel/triton/context_attn_unpad.py
@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
stride_od,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
context_lengths,
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
# Copy k to corresponding cache block
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
offsets_kcache = (
KCache
+ offset_kvcache
+ offsets_dmodel[:, None] * stride_cached
+ offsets_kcachebs[None, :] * stride_cachebs
+ offsets_dmodel[None, :] * stride_cached
+ offsets_kcachebs[:, None] * stride_cachebs
)
tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
# Copy v to corresponding cache block
offsets_vd = offsets_dmodel
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
offsets_vcachebs = offsets_kcachebs # same block-size range; aliased here only for readability
offsets_vcache = (
VCache
+ offset_kvcache
+ offsets_vcachebs[:, None] * stride_cachebs
+ offsets_dmodel[None, :] * stride_cached
+ offsets_vcachebs[None, :] * stride_cachebs
+ offsets_dmodel[:, None] * stride_cached
)
tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)

return

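In the Triton kernel above, only the order of the cache stride arguments changes (`stride_cachebs` now comes before `stride_cached`), and K/V tiles are loaded and stored as [tokens, head_dim] instead of [head_dim, tokens]. A quick check of how the strides of a contiguous cache line up with the new layout (variable names mirror the kernel arguments; this is not the actual launch code):

```python
# Strides of a contiguous [num_blocks, num_kv_heads, block_size, head_dim] cache,
# named after the kernel arguments above for illustration.
import torch

num_blocks, num_kv_heads, block_size, head_dim = 128, 8, 16, 64
k_cache = torch.empty(num_blocks, num_kv_heads, block_size, head_dim)

stride_cacheb, stride_cacheh, stride_cachebs, stride_cached = k_cache.stride()
# Head-dim elements of one cached token are now contiguous (stride 1), and moving to
# the next slot of a block steps by head_dim elements.
assert stride_cached == 1
assert stride_cachebs == head_dim
assert stride_cacheh == block_size * head_dim
```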
34 changes: 17 additions & 17 deletions colossalai/kernel/triton/flash_decoding.py
@@ -10,8 +10,8 @@
@triton.jit
def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, q_len(1), head_dim]
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
@@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
stride_qd,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
stride_mid_ot,
@@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(

K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(HEAD_DIM, cur_occupied_size),
strides=(stride_cached, stride_cachebs),
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_SIZE),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
shape=(HEAD_DIM, cur_occupied_size),
strides=(stride_cached, stride_cachebs),
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_SIZE),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
@@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
# NOTE a trick to work around triton's requirement that values in both the first and second input shapes must be >= 16:
# multiplying two tensors with shapes [1, d] and [d, block_size] would fail.
# Refer to https://github.com/openai/triton/discussions/895
S_ij += tl.sum(q[:, None] * k_cur_block, 0)
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))

@@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l

offsets_mid_o = (
@@ -206,8 +206,8 @@ def flash_decoding_attention(

Args:
q (torch.Tensor): [bsz, num_heads, head_dim]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
kv_seq_len (torch.Tensor): [batch_size]
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
@@ -230,13 +230,13 @@ def flash_decoding_attention(
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
f"batch size {bsz}"
)
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
f"Got incompatible block size on kv caches:\n"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
f"v_cache block_size {v_cache.size(-1)}"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
f"v_cache block_size {v_cache.size(-2)}"
)

# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
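With K now loaded as [BLOCK_SIZE, HEAD_DIM], the single-token score q·K is computed by broadcasting q over the block and summing over the head dimension (axis 1 instead of 0), keeping the tl.dot workaround mentioned in the NOTE. A small PyTorch sketch showing the reduction is equivalent to an ordinary matrix-vector product:

```python
# Equivalence check for the broadcast-and-reduce trick in _flash_decoding_fwd_kernel,
# written in PyTorch with the new [BLOCK_SIZE, HEAD_DIM] block orientation.
import torch

head_dim, block_size = 64, 16
q = torch.randn(head_dim)                        # one query token, one head
k_block = torch.randn(block_size, head_dim)      # one occupied cache block

s_broadcast = (q[None, :] * k_block).sum(dim=1)  # mirrors tl.sum(q[None, :] * k_cur_block, 1)
s_matmul = k_block @ q                           # reference matrix-vector product

assert torch.allclose(s_broadcast, s_matmul, atol=1e-5)  # both give [BLOCK_SIZE] scores
```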
33 changes: 14 additions & 19 deletions colossalai/kernel/triton/kvcache_copy.py
@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
block_size,
@@ -29,15 +29,15 @@
last_bt_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
+ offsets_in_last_block
)
tl.store(KVCache + offsets_kvcache, kv)
return
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
"""
Copy keys or values to the blocked key/value cache during decoding stage.

Parameters:
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
if k.dim() == 4:
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
bsz, _, num_kv_heads, head_dim = k.shape
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
k = k.squeeze(dim=1)
elif k.dim() == 3:
bsz, num_kv_heads, head_dim = k.shape
else:
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")

k = k.squeeze(1) if k.dim() == 4 else k
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
bsz, num_kv_heads, head_dim = k.shape

assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
@@ -77,7 +72,7 @@
)

# Modify if the shape of the kv cache is changed.
block_size = k_cache.size(-1)
block_size = k_cache.size(-2)

num_warps = 8 if head_dim > 128 else 4

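For reference, a hedged usage sketch of `copy_kv_to_blocked_cache` with the new cache layout. The positional argument order is assumed from the docstring above, and a CUDA device is required since this launches a Triton kernel:

```python
# Usage sketch (assumed call signature, taken from the docstring above); requires a GPU.
import torch

from colossalai.kernel.triton.kvcache_copy import copy_kv_to_blocked_cache

bsz, num_kv_heads, head_dim = 4, 8, 128
num_blocks, block_size, max_blocks_per_seq = 64, 16, 4

k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16, device="cuda")
kv_lengths = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), dtype=torch.int32, device="cuda")
block_tables = torch.arange(bsz * max_blocks_per_seq, dtype=torch.int32, device="cuda").reshape(bsz, -1)

# Copies each sequence's single new key (or value) into its last allocated cache block.
copy_kv_to_blocked_cache(k, k_cache, kv_lengths, block_tables)
```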
2 changes: 1 addition & 1 deletion tests/test_infer/test_kvcache_manager.py
@@ -93,7 +93,7 @@ def check_cache_manager(test_config):
assert len(cache_manager._cache_blocks) == num_blocks
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
assert len(key_caches) == num_layers
expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
assert key_caches[0].shape == expected_kv_shape
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
expected_kv_block_shape = expected_kv_shape[1:]