|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Callable |
| 8 | + |
| 9 | +import fire |
| 10 | +import pandas as pd |
| 11 | +import torch |
| 12 | +import triton |
| 13 | +from torch._inductor.utils import do_bench_using_profiling |
| 14 | + |
| 15 | +from torchao.prototype.mx_formats.custom_cast import ( |
| 16 | + to_mxfp8_across_dim0_and_dim1, |
| 17 | + to_mxfp8_across_dim0_and_dim1_reference, |
| 18 | + to_mxfp8_dim0_reference, |
| 19 | + to_mxfp8_dim1, |
| 20 | + to_mxfp8_dim1_reference, |
| 21 | +) |
| 22 | +from torchao.quantization.utils import compute_error |
| 23 | +from torchao.testing.float8.roofline_utils import get_specs |
| 24 | + |
| 25 | +torch.manual_seed(0) |
| 26 | + |
| 27 | +bytes_per_el_bf16 = 2 |
| 28 | +bytes_per_el_fp8 = 1 |
| 29 | + |
| 30 | + |
| 31 | +def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: |
| 32 | + """Thin wrapper around do_bench_using_profiling""" |
| 33 | + no_args = lambda: func(*args, **kwargs) |
| 34 | + time = do_bench_using_profiling(no_args) |
| 35 | + return time * 1e3 |
| 36 | + |
| 37 | + |
| 38 | +def run( |
| 39 | + M: int = 4096, |
| 40 | + K: int = 2048, |
| 41 | + BLOCK_SIZE: int = 32, |
| 42 | + check_accuracy: bool = True, |
| 43 | +): |
| 44 | + print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}") |
| 45 | + print(f"GPU: {torch.cuda.get_device_name(0)}") |
| 46 | + print(f"torch version: {torch.__version__}") |
| 47 | + print(f"triton version: {triton.__version__}") |
| 48 | + |
| 49 | + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 |
| 50 | + |
| 51 | + to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile( |
| 52 | + to_mxfp8_across_dim0_and_dim1_reference |
| 53 | + ) |
| 54 | + to_mxfp8_dim0_reference_c = torch.compile(to_mxfp8_dim0_reference) |
| 55 | + to_mxfp8_dim1_reference_c = torch.compile(to_mxfp8_dim1_reference) |
| 56 | + |
| 57 | + # reference implementation (plain PyTorch + torch.compile) |
| 58 | + x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = ( |
| 59 | + to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE) |
| 60 | + ) |
| 61 | + |
| 62 | + # verify reference dim0_dim1 matches dim0 and dim1 separately |
| 63 | + x_d0_separate, scale_e8m0_d0_separate = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE) |
| 64 | + x_d1_separate, scale_e8m0_d1_separate = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE) |
| 65 | + torch.testing.assert_close(x_d0, x_d0_separate, atol=0, rtol=0) |
| 66 | + torch.testing.assert_close(x_d1, x_d1_separate, atol=0, rtol=0) |
| 67 | + torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_separate, atol=0, rtol=0) |
| 68 | + torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_separate, atol=0, rtol=0) |
| 69 | + |
| 70 | + x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16() |
| 71 | + scale_fp_d0 = scale_e8m0_d0.float() |
| 72 | + scale_fp_d1 = scale_e8m0_d1.float() |
| 73 | + x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape) |
| 74 | + x_d1_and_back = ( |
| 75 | + (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t() |
| 76 | + ) |
| 77 | + |
| 78 | + sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back) |
| 79 | + sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back) |
| 80 | + print( |
| 81 | + f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}" |
| 82 | + ) |
| 83 | + assert ( |
| 84 | + sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28 |
| 85 | + ), "reference mx numerics are incorrect" |
| 86 | + |
| 87 | + # triton kernel for dim1 only |
| 88 | + x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE) |
| 89 | + |
| 90 | + # triton kernel for dim0 and dim1 |
| 91 | + x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1( |
| 92 | + x, tile_size=BLOCK_SIZE |
| 93 | + ) |
| 94 | + x_d0_t, x_d1_t, x_d1_only_t = ( |
| 95 | + x_d0_t.bfloat16(), |
| 96 | + x_d1_t.bfloat16(), |
| 97 | + x_d1_only_t.bfloat16(), |
| 98 | + ) |
| 99 | + |
| 100 | + # ensure bitwise equivalency of outputs with reference |
| 101 | + if check_accuracy: |
| 102 | + torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0) |
| 103 | + torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0) |
| 104 | + torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0) |
| 105 | + torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0) |
| 106 | + print("reference vs triton dim0_dim1 are bitwise equivalent") |
| 107 | + torch.testing.assert_close(x_d1, x_d1_only_t, atol=0, rtol=0) |
| 108 | + torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_only_t, atol=0, rtol=0) |
| 109 | + print("reference vs triton dim1 are bitwise equivalent") |
| 110 | + else: |
| 111 | + print("accuracy checking skipped") |
| 112 | + |
| 113 | + # now, measure performance |
| 114 | + |
| 115 | + # define a speed-of-light torch.compile kernel to get a sense of |
| 116 | + # achievable mem bandwidth |
| 117 | + def add_one(x): |
| 118 | + x = x + 1 |
| 119 | + return x |
| 120 | + |
| 121 | + add_one_c = torch.compile(add_one) |
| 122 | + |
| 123 | + for _ in range(2): |
| 124 | + __ = add_one_c(x) |
| 125 | + time_add_one_compile_us = benchmark_cuda_function_in_microseconds( |
| 126 | + lambda x: add_one(x), x |
| 127 | + ) |
| 128 | + |
| 129 | + for _ in range(2): |
| 130 | + __ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE) |
| 131 | + time_ref_dim0_dim1_compile_us = benchmark_cuda_function_in_microseconds( |
| 132 | + lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE |
| 133 | + ) |
| 134 | + |
| 135 | + for _ in range(2): |
| 136 | + __ = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE) |
| 137 | + time_ref_dim0_compile_us = benchmark_cuda_function_in_microseconds( |
| 138 | + lambda x, b: to_mxfp8_dim0_reference_c(x, b), x, BLOCK_SIZE |
| 139 | + ) |
| 140 | + |
| 141 | + for _ in range(2): |
| 142 | + __ = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE) |
| 143 | + time_ref_dim1_compile_us = benchmark_cuda_function_in_microseconds( |
| 144 | + lambda x, b: to_mxfp8_dim1_reference_c(x, b), x, BLOCK_SIZE |
| 145 | + ) |
| 146 | + |
| 147 | + for _ in range(2): |
| 148 | + __ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE) |
| 149 | + time_triton_dim1_us = benchmark_cuda_function_in_microseconds( |
| 150 | + lambda x, b: to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE), |
| 151 | + x, |
| 152 | + BLOCK_SIZE, |
| 153 | + ) |
| 154 | + |
| 155 | + # warm up |
| 156 | + for _ in range(2): |
| 157 | + __ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE) |
| 158 | + time_triton_dim0_dim1_us = benchmark_cuda_function_in_microseconds( |
| 159 | + lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE), |
| 160 | + x, |
| 161 | + BLOCK_SIZE, |
| 162 | + ) |
| 163 | + |
| 164 | + # calculate memory bandwidth |
| 165 | + peak_mem_bw = get_specs()["peak_mem_bw_bytes_sec"] |
| 166 | + |
| 167 | + # add_one kernel |
| 168 | + add_one_bps = x.numel() * bytes_per_el_bf16 * 2 / (time_add_one_compile_us / 1e6) |
| 169 | + |
| 170 | + # dim0 or dim1 kernel |
| 171 | + dim0_bytes_read = x.numel() * bytes_per_el_bf16 |
| 172 | + dim0_bytes_written = (x_d0_t.numel() + scale_e8m0_d0_t.numel()) * bytes_per_el_fp8 |
| 173 | + dim0_bytes_rw = dim0_bytes_read + dim0_bytes_written |
| 174 | + ref_dim0_bps = dim0_bytes_rw / (time_ref_dim0_compile_us / 1e6) |
| 175 | + ref_dim1_bps = dim0_bytes_rw / (time_ref_dim1_compile_us / 1e6) |
| 176 | + triton_dim1_bps = dim0_bytes_rw / (time_triton_dim1_us / 1e6) |
| 177 | + |
| 178 | + # triton dim0_dim1 kernel |
| 179 | + triton_dim0_dim1_bytes_read = x.numel() * bytes_per_el_bf16 |
| 180 | + triton_dim0_dim1_bytes_written = ( |
| 181 | + sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t)) |
| 182 | + * bytes_per_el_fp8 |
| 183 | + ) |
| 184 | + triton_dim0_dim1_bps = ( |
| 185 | + triton_dim0_dim1_bytes_read + triton_dim0_dim1_bytes_written |
| 186 | + ) / (time_triton_dim0_dim1_us / 1e6) |
| 187 | + |
| 188 | + results = [ |
| 189 | + ["add_one", time_add_one_compile_us, add_one_bps / 1e9], |
| 190 | + ["compile_dim0", time_ref_dim0_compile_us, ref_dim0_bps / 1e9], |
| 191 | + ["compile_dim1", time_ref_dim1_compile_us, ref_dim1_bps / 1e9], |
| 192 | + [ |
| 193 | + "compile_dim0_dim1", |
| 194 | + time_ref_dim0_dim1_compile_us, |
| 195 | + triton_dim0_dim1_bps / 1e9, |
| 196 | + ], |
| 197 | + ["triton_dim1", time_triton_dim1_us, triton_dim1_bps / 1e9], |
| 198 | + ["triton_dim0_dim1", time_triton_dim0_dim1_us, triton_dim0_dim1_bps / 1e9], |
| 199 | + ] |
| 200 | + df = pd.DataFrame(results, columns=["experiment", "time_us", "mem_bw_gbps"]) |
| 201 | + df["mem_bw_pct_peak"] = df["mem_bw_gbps"] * 1e9 / peak_mem_bw |
| 202 | + print("\n", df) |
| 203 | + |
| 204 | + |
| 205 | +if __name__ == "__main__": |
| 206 | + fire.Fire(run) |
0 commit comments