Skip to content

Commit 31b905e

Browse files
committed
revise block size retrieval
1 parent 6bb1809 commit 31b905e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

colossalai/inference/modeling/layers/attention.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def pad_context_forward(
207207
num_kv_heads = k.shape[-2]
208208
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
209209
num_kv_groups = num_heads // num_kv_heads
210-
block_size = k_cache.shape[-1]
210+
block_size = k_cache.size(-2)
211211
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
212212
block_tables.shape[-1] * block_size
213213

0 commit comments

Comments
 (0)