# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
import random
from datetime import datetime

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
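# With kv_layout="HND", each block stores (num_kv_heads, page_size, head_dim),
# and the dim of size 2 holds the K and V planes of the block.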


def to_float8(x, dtype=torch.float8_e4m3fn):
    """Quantize `x` to fp8 and return (quantized tensor, dequantization scale).

    The 0.1 factor leaves headroom below the fp8 max to reduce saturation.
    """
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    torch.set_default_device("cuda")
    device = "cuda"
    # Seed both RNGs so kv_lens and tensor contents are reproducible.
    random.seed(0)
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = 256000 // page_size

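    # 1 GiB workspace, shared by the TRT-LLM kernel and the FlashInfer wrapper.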
    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

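    # Random block tables; sequences may share blocks, which is harmless here
    # since only kernel time is measured.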
    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

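    # The dequantization scale from to_float8 is discarded and k_scale/v_scale
    # stay at 1.0: only kernel timing matters here, not numerical accuracy.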
    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn, warmup=10, trials=20):
        """Time `fn` with CUDA events and return (mean_ms, std_ms)."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode
    trt_mean, trt_std = time_fn(trt_decode, warmup=warmup, trials=trials)

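    # Build the CSR-style page-table inputs (indptr / indices / last-page-len)
    # that the FlashInfer wrapper expects, from the same block tables as above.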
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

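    # Tensor-core decode is typically faster once the GQA group size exceeds 4.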
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

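    # "NONE" is the pos_encoding_mode: no positional encoding is applied inside
    # the kernel (RoPE, if any, is assumed to be applied to q/k beforehand).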
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(baseline_decode, warmup=warmup, trials=trials)

    # Fractional speedup relative to the baseline (positive means TRT is
    # faster); despite the name, the value is a fraction: 0.25 means 25%.
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    print("Running benchmark for kv_cache_dtype: bfloat16")
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
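    # kv_cache_dtype="auto" keeps the KV cache in the query dtype (bfloat16).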
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
            )
            all_results.append(result)

    print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
            )
            all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)