|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py |
| 9 | +and making it nice. |
| 10 | +""" |
| 11 | + |
| 12 | +from typing import Callable |
| 13 | + |
| 14 | +import fire |
| 15 | +import torch |
| 16 | +import triton |
| 17 | +from torch._inductor.utils import do_bench_using_profiling |
| 18 | + |
| 19 | +from torchao.prototype.mx_formats.custom_cast import ( |
| 20 | + to_mxfp8_across_dim0_and_dim1, |
| 21 | + to_mxfp8_across_dim0_and_dim1_reference, |
| 22 | + to_mxfp8_dim0_reference, |
| 23 | + to_mxfp8_dim1, |
| 24 | + to_mxfp8_dim1_reference, |
| 25 | +) |
| 26 | +from torchao.testing.float8.roofline_utils import get_specs |
| 27 | + |
| 28 | +torch.manual_seed(0) |
| 29 | + |
| 30 | +bytes_per_el_bf16 = 2 |
| 31 | +bytes_per_el_fp8 = 1 |
| 32 | + |
| 33 | + |
| 34 | +def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: |
| 35 | + """Thin wrapper around do_bench_using_profiling""" |
| 36 | + no_args = lambda: func(*args, **kwargs) |
| 37 | + time = do_bench_using_profiling(no_args) |
| 38 | + return time * 1e3 |
| 39 | + |
| 40 | + |
| 41 | +def compute_error(x, y): |
| 42 | + Ps = torch.linalg.norm(x) |
| 43 | + Pn = torch.linalg.norm(x - y) |
| 44 | + return 20 * torch.log10(Ps / Pn) |
| 45 | + |
| 46 | + |
| 47 | +def run( |
| 48 | + M: int = 4096, |
| 49 | + K: int = 2048, |
| 50 | + BLOCK_SIZE: int = 32, |
| 51 | + check_accuracy: bool = True, |
| 52 | + mode: str = "dim0_and_dim1", |
| 53 | +): |
| 54 | + print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}") |
| 55 | + print(f"GPU: {torch.cuda.get_device_name(0)}") |
| 56 | + print(f"torch version: {torch.__version__}") |
| 57 | + print(f"triton version: {triton.__version__}") |
| 58 | + print(f"mode: {mode}") |
| 59 | + assert mode in "dim0_and_dim1", "dim1" |
| 60 | + |
| 61 | + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 |
| 62 | + |
| 63 | + to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile( |
| 64 | + to_mxfp8_across_dim0_and_dim1_reference |
| 65 | + ) |
| 66 | + to_mxfp8_dim0_reference_c = torch.compile(to_mxfp8_dim0_reference) |
| 67 | + to_mxfp8_dim1_reference_c = torch.compile(to_mxfp8_dim1_reference) |
| 68 | + |
| 69 | + # reference implementation (plain PyTorch + torch.compile) |
| 70 | + if mode == "dim0_and_dim1": |
| 71 | + # TODO remove the mode here? |
| 72 | + x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = ( |
| 73 | + to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE) |
| 74 | + ) |
| 75 | + else: # dim1 |
| 76 | + x_d0, scale_e8m0_d0 = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE) |
| 77 | + x_d1, scale_e8m0_d1 = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE) |
| 78 | + |
| 79 | + x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16() |
| 80 | + scale_fp_d0 = scale_e8m0_d0.float() |
| 81 | + scale_fp_d1 = scale_e8m0_d1.float() |
| 82 | + x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape) |
| 83 | + x_d1_and_back = ( |
| 84 | + (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t() |
| 85 | + ) |
| 86 | + |
| 87 | + sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back) |
| 88 | + sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back) |
| 89 | + print( |
| 90 | + f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}" |
| 91 | + ) |
| 92 | + assert ( |
| 93 | + sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28 |
| 94 | + ), "reference mx numerics are incorrect" |
| 95 | + |
| 96 | + # triton kernel for dim1 only |
| 97 | + x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE) |
| 98 | + |
| 99 | + # triton kernel for dim0 and dim1 |
| 100 | + x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1( |
| 101 | + x, tile_size=BLOCK_SIZE |
| 102 | + ) |
| 103 | + x_d0_t, x_d1_t, x_d1_only_t = ( |
| 104 | + x_d0_t.bfloat16(), |
| 105 | + x_d1_t.bfloat16(), |
| 106 | + x_d1_only_t.bfloat16(), |
| 107 | + ) |
| 108 | + |
| 109 | + # ensure bitwise equivalency of outputs with reference |
| 110 | + if check_accuracy: |
| 111 | + torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0) |
| 112 | + torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0) |
| 113 | + torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0) |
| 114 | + torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0) |
| 115 | + torch.testing.assert_close(x_d1, x_d1_only_t, atol=0, rtol=0) |
| 116 | + # print('reference', scale_e8m0_d1) |
| 117 | + # print('triton', scale_e8m0_d1_only_t) |
| 118 | + torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_only_t, atol=0, rtol=0) |
| 119 | + print("normalized reference vs normalized triton are bitwise equivalent") |
| 120 | + # return |
| 121 | + else: |
| 122 | + print("accuracy checking skipped") |
| 123 | + |
| 124 | + # now, measure performance |
| 125 | + |
| 126 | + for _ in range(2): |
| 127 | + __ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE) |
| 128 | + time_ref_dim0_dim1_compile_us = benchmark_cuda_function_in_microseconds( |
| 129 | + lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE |
| 130 | + ) |
| 131 | + |
| 132 | + for _ in range(2): |
| 133 | + __ = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE) |
| 134 | + time_ref_dim0_compile_us = benchmark_cuda_function_in_microseconds( |
| 135 | + lambda x, b: to_mxfp8_dim0_reference_c(x, b), x, BLOCK_SIZE |
| 136 | + ) |
| 137 | + |
| 138 | + for _ in range(2): |
| 139 | + __ = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE) |
| 140 | + time_ref_dim1_compile_us = benchmark_cuda_function_in_microseconds( |
| 141 | + lambda x, b: to_mxfp8_dim1_reference_c(x, b), x, BLOCK_SIZE |
| 142 | + ) |
| 143 | + |
| 144 | + for _ in range(2): |
| 145 | + __ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE) |
| 146 | + time_triton_dim1_us = benchmark_cuda_function_in_microseconds( |
| 147 | + lambda x, b: to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE), |
| 148 | + x, |
| 149 | + BLOCK_SIZE, |
| 150 | + ) |
| 151 | + |
| 152 | + # warm up |
| 153 | + for _ in range(2): |
| 154 | + __ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE) |
| 155 | + time_triton_dim0_dim1_us = benchmark_cuda_function_in_microseconds( |
| 156 | + lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE), |
| 157 | + x, |
| 158 | + BLOCK_SIZE, |
| 159 | + ) |
| 160 | + |
| 161 | + # calculate memory bandwidth |
| 162 | + peak_mem_bw = get_specs()["peak_mem_bw_bytes_sec"] |
| 163 | + |
| 164 | + # dim0 or dim1 kernel |
| 165 | + dim0_bytes_read = x.numel() * bytes_per_el_bf16 |
| 166 | + dim0_bytes_written = (x_d0_t.numel() + scale_e8m0_d0_t.numel()) * bytes_per_el_fp8 |
| 167 | + dim0_bytes_rw = dim0_bytes_read + dim0_bytes_written |
| 168 | + ref_dim0_bps = dim0_bytes_rw / (time_ref_dim0_compile_us / 1e6) |
| 169 | + ref_dim1_bps = dim0_bytes_rw / (time_ref_dim1_compile_us / 1e6) |
| 170 | + triton_dim1_bps = dim0_bytes_rw / (time_triton_dim1_us / 1e6) |
| 171 | + |
| 172 | + # triton dim0_dim1 kernel |
| 173 | + triton_dim0_dim1_bytes_read = x.numel() * bytes_per_el_bf16 |
| 174 | + triton_dim0_dim1_bytes_written = ( |
| 175 | + sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t)) |
| 176 | + * bytes_per_el_fp8 |
| 177 | + ) |
| 178 | + triton_dim0_dim1_bps = ( |
| 179 | + triton_dim0_dim1_bytes_read + triton_dim0_dim1_bytes_written |
| 180 | + ) / (time_triton_dim0_dim1_us / 1e6) |
| 181 | + triton_dim0_dim1_pct_peak_mem = triton_dim0_dim1_bps / peak_mem_bw |
| 182 | + |
| 183 | + print("time_ref_dim0_dim1_compile_us", time_ref_dim0_dim1_compile_us) |
| 184 | + print("time_ref_dim0_compile_us", time_ref_dim0_compile_us) |
| 185 | + print("time_ref_dim1_compile_us", time_ref_dim1_compile_us) |
| 186 | + print("time_triton_dim1_us", time_triton_dim1_us) |
| 187 | + print("time_triton_dim0_dim1_us", time_triton_dim0_dim1_us) |
| 188 | + print("ref_dim0_mem_bw_gbps", ref_dim0_bps / 1e9) |
| 189 | + print("ref_dim1_mem_bw_gbps", ref_dim1_bps / 1e9) |
| 190 | + print("triton_dim1_mem_bw_gbps", triton_dim1_bps / 1e9) |
| 191 | + print("triton_dim0_dim1_mem_bw_gbps", triton_dim0_dim1_bps / 1e9) |
| 192 | + # Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor |
| 193 | + # can achieve around 50-70% of B200 peak mem bw |
| 194 | + print("triton_dim0_dim1_pct_peak_mem", triton_dim0_dim1_pct_peak_mem) |
| 195 | + print("dim0_dim1 speedup", time_ref_dim0_dim1_compile_us / time_triton_dim0_dim1_us) |
| 196 | + |
| 197 | + |
| 198 | +if __name__ == "__main__": |
| 199 | + fire.Fire(run) |
0 commit comments