Commit 12f10d5

[Fix/Inference]Fix CUDA Rotary Embedding GQA (#5623)
* fix rotary embedding GQA
* change test_rotary_embdding_unpad.py KH
1 parent 5d4c1fe commit 12f10d5

File tree

2 files changed: +9 −8 lines

extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Lines changed: 2 additions & 2 deletions

@@ -115,7 +115,7 @@ __device__ void apply_k_rotary_emb_compute(
         (head_offset % shard_block_size) / VecSize;
     const int64_t addr_offset =
         token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
-    const int64_t target_id = block_id * head_num * head_dim * block_size +
+    const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
                               (i / half_head_dim) * block_size * head_dim +
                               block_offset * head_dim + head_offset;

@@ -137,7 +137,7 @@ __device__ void apply_k_rotary_emb_compute(

   // apply value memcopy
   apply_kv_memcopy<scalar_t, VecSize>(
-      value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
+      value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
       block_size, block_offset, head_dim, half_head_dim);
 }

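The fix targets grouped-query attention (GQA), where the key/value tensors carry fewer heads (kv_head_num) than the query tensor (head_num), so indexing the paged key/value cache with head_num over-strided each cache block. As a reading aid (not part of the commit), here is a minimal Python sketch of the flat offset the corrected line computes, assuming the paged key-cache layout exercised by the test below, (num_blocks, kv_head_num, block_size, head_dim); the helper name is hypothetical.

```python
# Minimal sketch (not the kernel itself) of the flat offset the corrected line
# computes, assuming a paged key cache of shape
# (num_blocks, kv_head_num, block_size, head_dim) stored row-major.
import torch


def key_cache_flat_index(block_id, head_idx, block_offset, head_offset,
                         kv_head_num, block_size, head_dim):
    """Hypothetical helper mirroring the kernel's target_id arithmetic."""
    return (block_id * kv_head_num * head_dim * block_size
            + head_idx * block_size * head_dim
            + block_offset * head_dim
            + head_offset)


num_blocks, kv_head_num, head_num, block_size, head_dim = 3, 16, 32, 8, 64
k_cache = torch.arange(num_blocks * kv_head_num * block_size * head_dim).reshape(
    num_blocks, kv_head_num, block_size, head_dim
)

idx = key_cache_flat_index(block_id=2, head_idx=5, block_offset=3, head_offset=7,
                           kv_head_num=kv_head_num, block_size=block_size,
                           head_dim=head_dim)
# Indexing with kv_head_num lands on the intended element ...
assert k_cache.reshape(-1)[idx] == k_cache[2, 5, 3, 7]
# ... whereas scaling the per-block stride by the query head count (head_num = 32)
# steps past this block's region of the cache; here it is out of bounds entirely.
bad_idx = key_cache_flat_index(block_id=2, head_idx=5, block_offset=3, head_offset=7,
                               kv_head_num=head_num, block_size=block_size,
                               head_dim=head_dim)
assert bad_idx >= k_cache.numel()
```

The same reasoning applies to the value memcopy in the second hunk, where the per-block stride passed to apply_kv_memcopy is now kv_head_num * head_dim rather than head_num * head_dim.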
tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
Lines changed: 7 additions & 6 deletions

@@ -21,9 +21,10 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("SEQ_LEN", [64])
 @pytest.mark.parametrize("H", [32])
+@pytest.mark.parametrize("K_H", [16, 32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers

@@ -43,21 +44,21 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
-    k_shape = (TOTAL_TOKENS, H, D)
+    k_shape = (TOTAL_TOKENS, K_H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
+    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
     v_cache = torch.zeros_like(k_cache)
     past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
     )
-    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
-    new_q = torch.randn_like(new_k)
+    new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_v = torch.randn_like(new_k)

     kv_seq_lengths = past_kv_seq_lengths + 1

@@ -123,4 +124,4 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):


 if __name__ == "__main__":
-    test_rotary_emb(16, 64, 4, 128, torch.float16)
+    test_rotary_emb(16, 64, 32, 16, 128, torch.float16)

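The new K_H parameter lets the same parametrized test cover both a genuine GQA configuration (K_H = 16 < H = 32) and the ordinary multi-head case (K_H = H = 32). A small illustrative sketch of the shape relationship being exercised, assuming H is a multiple of K_H (the expansion step and variable names below are not from the commit):

```python
# Illustrative only: under GQA the queries keep H heads while the keys/values
# keep K_H heads, with H a multiple of K_H.
import torch

BATCH_SIZE, SEQ_LEN, H, K_H, D = 4, 64, 32, 16, 64
assert H % K_H == 0  # each KV head serves H // K_H query heads

q = torch.randn(BATCH_SIZE * SEQ_LEN, H, D)
k = torch.randn(BATCH_SIZE * SEQ_LEN, K_H, D)
v = torch.randn_like(k)

# A plain PyTorch reference can expand the KV heads to the query head count
# before attention, e.g. when checking a fused kernel's output against it.
k_expanded = k.repeat_interleave(H // K_H, dim=1)
v_expanded = v.repeat_interleave(H // K_H, dim=1)
assert k_expanded.shape == q.shape == v_expanded.shape
```

In the GQA run the key/value tensors and caches shrink by a factor of H / K_H along the head dimension, which is exactly why the kernel must index the paged cache with kv_head_num rather than head_num.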