Commit 8e606ec

[Inference] Benchmarking rotary embedding and add a fetch function (#5277)
* fix bugs and add a cos/sin cache fetch func
* add docstring
* fix bug
* fix
1 parent b785319 commit 8e606ec
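
Note: the commit body mentions "add a cos/sin cache fetch func", but only the test-file diff appears on this page. As a rough sketch of the general idea (the class and method names below are hypothetical, not the ColossalAI API added by this commit): precompute the rotary cos/sin tables once up to a maximum sequence length, then fetch the rows for the requested token positions.

    import torch

    # Hypothetical sketch of a cos/sin cache with a fetch helper -- illustrative only,
    # not the function introduced by this commit.
    class CosSinCache:
        def __init__(self, head_dim, max_seq_len=4096, base=10000.0,
                     dtype=torch.float16, device="cuda"):
            inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
            t = torch.arange(max_seq_len, device=device).float()
            freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
            self.cos = freqs.cos().to(dtype)
            self.sin = freqs.sin().to(dtype)

        def fetch(self, positions):
            # positions: LongTensor of shape (num_tokens,) holding absolute token positions
            return self.cos[positions], self.sin[positions]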

File tree

1 file changed (+58 lines, −0 lines)

tests/test_infer_ops/triton/test_rotary_embdding_unpad.py

Lines changed: 58 additions & 0 deletions
@@ -1,9 +1,20 @@
 import pytest
 import torch
+from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.triton import rotary_embedding
 
+try:
+    import triton  # noqa
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
 
 def torch_rotary_emb(x, cos, sin):
     seq_len, h, dim = x.shape
@@ -52,5 +63,52 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
 
 
+BATCH = 16
+configs = [
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[2**i for i in range(4, 11)],
+        line_arg="provider",
+        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        styles=[("red", "-"), ("blue", "-")],
+        ylabel="ms",
+        plot_name=f"rotary_emb-batch-{BATCH}",
+        args={"num_kv_heads": 16},
+    )
+]
+
+
+@triton.testing.perf_report(configs)
+def benchmark_rotary_emb(
+    provider: str,
+    num_tokens: int,
+    num_kv_heads: int,
+):
+    warmup = 10
+    rep = 100
+
+    head_dim = 128
+    dtype = torch.float16
+    q_shape = (num_tokens, num_kv_heads, head_dim)
+    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
+    k_shape = (num_tokens, num_kv_heads, head_dim)
+    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
+    cos_shape = (num_tokens, head_dim // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    if provider == "torch_rotary_emb_func":
+        fn = lambda: torch_rotary_emb(q, cos, sin)
+    elif provider == "triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(q, k, cos, sin)
+    else:
+        raise ValueError("Undefined provider")
+
+    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    return ms
+
+
 if __name__ == "__main__":
     test_rotary_emb(4, 64, 32, 64, torch.float32)
+    # benchmark_rotary_emb.run(save_path=".",print_data=True)
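
For reference, torch_rotary_emb is the eager PyTorch baseline timed by the "torch_rotary_emb_func" provider; only its signature appears as diff context above. A minimal sketch of what such a baseline typically looks like for cos/sin tensors of shape (num_tokens, head_dim // 2), using the rotate-half formulation (an illustration, not necessarily the exact code in the file):

    def torch_rotary_emb_reference(x, cos, sin):
        # x: (num_tokens, num_heads, head_dim); cos/sin: (num_tokens, head_dim // 2)
        seq_len, h, dim = x.shape
        x0, x1 = x[:, :, : dim // 2], x[:, :, dim // 2 :]
        cos = cos.view(seq_len, 1, dim // 2)
        sin = sin.view(seq_len, 1, dim // 2)
        out0 = x0 * cos - x1 * sin
        out1 = x0 * sin + x1 * cos
        return torch.cat((out0, out1), dim=-1)

To actually collect the benchmark numbers, uncomment the run call in the __main__ block; this assumes a CUDA GPU, an installed triton package, and a Triton version whose triton.testing.do_bench still accepts the warmup/rep keywords:

    if __name__ == "__main__":
        test_rotary_emb(4, 64, 32, 64, torch.float32)
        # Sweeps num_tokens over 16, 32, ..., 1024 and writes the plot/data to the current directory.
        benchmark_rotary_emb.run(save_path=".", print_data=True)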

0 commit comments