import pytest
import torch
+from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

from colossalai.kernel.triton import rotary_embedding

+try:
+    import triton  # noqa
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+# Triton kernels need a sufficiently recent CUDA toolkit; gate on the CUDA
+# version PyTorch was built against.
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+

def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
@@ -52,5 +63,52 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
    assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)


+BATCH = 16
+# Sweep num_tokens from 16 to 1024 and compare the torch reference against the
+# Triton kernel.
+configs = [
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[2**i for i in range(4, 11)],
+        line_arg="provider",
+        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        styles=[("red", "-"), ("blue", "-")],
+        ylabel="ms",
+        plot_name=f"rotary_emb-batch-{BATCH}",
+        args={"num_kv_heads": 16},
+    )
+]
+
+
+@triton.testing.perf_report(configs)
+def benchmark_rotary_emb(
+    provider: str,
+    num_tokens: int,
+    num_kv_heads: int,
+):
+    warmup = 10
+    rep = 100
+
+    head_dim = 128
+    dtype = torch.float16
+    q_shape = (num_tokens, num_kv_heads, head_dim)
+    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
+    k_shape = (num_tokens, num_kv_heads, head_dim)
+    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
+    cos_shape = (num_tokens, head_dim // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    # Note: the torch reference rotates only q, while the Triton kernel rotates
+    # q and k in one call, so the two providers do slightly different work.
+    if provider == "torch_rotary_emb_func":
+        fn = lambda: torch_rotary_emb(q, cos, sin)
+    elif provider == "triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(q, k, cos, sin)
+    else:
+        raise ValueError("Undefined provider")
+
+    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    return ms
+
+
if __name__ == "__main__":
    test_rotary_emb(4, 64, 32, 64, torch.float32)
+    # benchmark_rotary_emb.run(save_path=".", print_data=True)
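
For reference, a minimal usage sketch of the new benchmark, not part of the diff itself. It assumes a CUDA device with a working Triton install, reuses the HAS_TRITON and TRITON_CUDA_SUPPORT flags defined above, and relies on the standard save_path/print_data options of the runner that triton.testing.perf_report returns:

# Usage sketch (assumes CUDA + Triton; mirrors the commented-out line above).
if HAS_TRITON and TRITON_CUDA_SUPPORT:
    # Runs the num_tokens sweep, prints the timing table, and saves the
    # "rotary_emb-batch-16" plot/data into the current directory.
    benchmark_rotary_emb.run(save_path=".", print_data=True)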