@@ -21,9 +21,10 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("SEQ_LEN", [64])
 @pytest.mark.parametrize("H", [32])
+@pytest.mark.parametrize("K_H", [16, 32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op should match the Transformers implementation
@@ -43,21 +44,21 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
-    k_shape = (TOTAL_TOKENS, H, D)
+    k_shape = (TOTAL_TOKENS, K_H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
+    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
     v_cache = torch.zeros_like(k_cache)
     past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
     )
-    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
-    new_q = torch.randn_like(new_k)
+    new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_v = torch.randn_like(new_k)

     kv_seq_lengths = past_kv_seq_lengths + 1
@@ -123,4 +124,4 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):


 if __name__ == "__main__":
-    test_rotary_emb(16, 64, 4, 128, torch.float16)
+    test_rotary_emb(16, 64, 32, 16, 128, torch.float16)
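The new K_H parameter decouples the number of key/value heads from the number of query heads H, so the test now covers grouped-query attention (K_H < H) alongside standard multi-head attention (K_H == H). A minimal sketch of the head-count relationship being exercised, assuming H is a multiple of K_H; the repeat_kv helper is illustrative only and is not part of the patched test:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Broadcast K_H key/value heads so each of the H query heads has a match:
    # (tokens, K_H, D) -> (tokens, K_H * n_rep, D)
    tokens, k_h, d = x.shape
    return x[:, :, None, :].expand(tokens, k_h, n_rep, d).reshape(tokens, k_h * n_rep, d)

H, K_H, D, TOTAL_TOKENS = 32, 16, 64, 256
q = torch.randn(TOTAL_TOKENS, H, D)
k = torch.randn(TOTAL_TOKENS, K_H, D)  # fewer heads than q under GQA
assert repeat_kv(k, H // K_H).shape == q.shape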