@@ -3,7 +3,7 @@
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
-from colossalai.kernel.triton import rotary_embedding
+from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
 from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
 
 try:
@@ -94,8 +94,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
-        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
-        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
+        line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
@@ -110,23 +110,43 @@ def benchmark_rotary_emb(
     num_tokens: int,
     num_kv_heads: int,
 ):
+    BATCH_SIZE = 4
+    SEQ_LEN = num_tokens // BATCH_SIZE
+    max_num_blocks_per_seq = 8
+    block_size = 64
     warmup = 10
     rep = 100
 
-    head_dim = 128
+    head_dim = 256
     dtype = torch.float16
+
     q_shape = (num_tokens, num_kv_heads, head_dim)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (num_tokens, num_kv_heads, head_dim)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (num_tokens, head_dim // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")
 
-    if provider == "torch_rotary_emb_func":
-        fn = lambda: torch_rotary_emb(q, cos, sin)
-    elif provider == "triton_rotary_emb_func":
-        fn = lambda: rotary_embedding(q, k, cos, sin)
+    if provider == "no_fused_rotary_emb_func":
+        fn = lambda: [
+            rotary_embedding(new_q, new_k, cos, sin),
+            copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
+        ]
+    elif provider == "fused_triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
     else:
         raise ValueError("Undefined provider")
 
@@ -135,5 +155,5 @@ def benchmark_rotary_emb(
 
 
 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
-    # benchmark_rotary_emb.run(save_path=".",print_data=True)
+    # test_rotary_emb(4, 64, 32, 64, torch.float32)
+    benchmark_rotary_emb.run(save_path=".", print_data=True)
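
The updated benchmark compares an unfused decode path (rotary_embedding followed by a separate copy_kv_to_blocked_cache launch) against a fused variant that takes k_cache, block_tables and kv_seq_lengths directly. Below is a minimal sketch of the paged-cache addressing this setup implies, assuming block_tables[seq_id, logical_block] holds a physical block index into a cache shaped (num_blocks, num_kv_heads, block_size, head_dim) as in cache_shape above; read_k_from_blocked_cache is a hypothetical helper for illustration, not part of the ColossalAI API.

import torch

# Hypothetical helper, illustration only: fetch one token's key vector from a blocked
# (paged) KV cache of shape (num_blocks, num_kv_heads, block_size, head_dim), assuming
# block_tables[seq_id, logical_block] stores the physical block index for that sequence.
def read_k_from_blocked_cache(k_cache, block_tables, seq_id, token_pos, block_size=64):
    logical_block = token_pos // block_size    # which logical block of the sequence
    slot = token_pos % block_size              # offset of the token inside that block
    physical_block = block_tables[seq_id, logical_block]
    return k_cache[physical_block, :, slot, :]  # (num_kv_heads, head_dim)

Under this layout assumption, after the fused call the rotated new_k for sequence i should be readable back via read_k_from_blocked_cache(k_cache, block_tables, i, int(kv_seq_lengths[i]) - 1).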