Skip to content

triton kernel to cast to mx across dim0 and dim1 #1869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion benchmarks/mx_formats/cast_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)
Expand Down Expand Up @@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
return data_d0, scale_d0


def to_mx_dim1_reference(x_hp, block_size):
x_hp = x_hp.t().contiguous()
scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
return data_d1.t(), scale_d1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
Expand All @@ -67,7 +76,7 @@ def run(
print(f"torch version: {torch.__version__}")
print(f"triton version: {triton.__version__}")
print(f"mode: {mode}")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

Expand Down Expand Up @@ -144,6 +153,41 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx":
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

for _ in range(2):
__ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
x,
BLOCK_SIZE,
)

assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.uint8
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_triton":
y_d1, s_d1 = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)

for _ in range(2):
__ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE),
x,
BLOCK_SIZE,
)

assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

else:
raise AssertionError(f"unknown mode {mode}")

Expand Down
206 changes: 206 additions & 0 deletions benchmarks/mx_formats/mx_dim0_dim1_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import fire
import pandas as pd
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
to_mxfp8_across_dim0_and_dim1,
to_mxfp8_across_dim0_and_dim1_reference,
to_mxfp8_dim0_reference,
to_mxfp8_dim1,
to_mxfp8_dim1_reference,
)
from torchao.quantization.utils import compute_error
from torchao.testing.float8.roofline_utils import get_specs

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3


def run(
M: int = 4096,
K: int = 2048,
BLOCK_SIZE: int = 32,
check_accuracy: bool = True,
):
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"torch version: {torch.__version__}")
print(f"triton version: {triton.__version__}")

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile(
to_mxfp8_across_dim0_and_dim1_reference
)
to_mxfp8_dim0_reference_c = torch.compile(to_mxfp8_dim0_reference)
to_mxfp8_dim1_reference_c = torch.compile(to_mxfp8_dim1_reference)

# reference implementation (plain PyTorch + torch.compile)
x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = (
to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
)

# verify reference dim0_dim1 matches dim0 and dim1 separately
x_d0_separate, scale_e8m0_d0_separate = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
x_d1_separate, scale_e8m0_d1_separate = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)
torch.testing.assert_close(x_d0, x_d0_separate, atol=0, rtol=0)
torch.testing.assert_close(x_d1, x_d1_separate, atol=0, rtol=0)
torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_separate, atol=0, rtol=0)
torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_separate, atol=0, rtol=0)

x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16()
scale_fp_d0 = scale_e8m0_d0.float()
scale_fp_d1 = scale_e8m0_d1.float()
x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
x_d1_and_back = (
(x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
)

sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
print(
f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
)
assert (
sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28
), "reference mx numerics are incorrect"

# triton kernel for dim1 only
x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

# triton kernel for dim0 and dim1
x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1(
x, tile_size=BLOCK_SIZE
)
x_d0_t, x_d1_t, x_d1_only_t = (
x_d0_t.bfloat16(),
x_d1_t.bfloat16(),
x_d1_only_t.bfloat16(),
)

# ensure bitwise equivalency of outputs with reference
if check_accuracy:
torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
print("reference vs triton dim0_dim1 are bitwise equivalent")
torch.testing.assert_close(x_d1, x_d1_only_t, atol=0, rtol=0)
torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_only_t, atol=0, rtol=0)
print("reference vs triton dim1 are bitwise equivalent")
else:
print("accuracy checking skipped")

# now, measure performance

# define a speed-of-light torch.compile kernel to get a sense of
# achievable mem bandwidth
def add_one(x):
x = x + 1
return x

add_one_c = torch.compile(add_one)

for _ in range(2):
__ = add_one_c(x)
time_add_one_compile_us = benchmark_cuda_function_in_microseconds(
lambda x: add_one(x), x
)

for _ in range(2):
__ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
time_ref_dim0_dim1_compile_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE
)

for _ in range(2):
__ = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
time_ref_dim0_compile_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_dim0_reference_c(x, b), x, BLOCK_SIZE
)

for _ in range(2):
__ = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)
time_ref_dim1_compile_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_dim1_reference_c(x, b), x, BLOCK_SIZE
)

for _ in range(2):
__ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
time_triton_dim1_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
x,
BLOCK_SIZE,
)

# warm up
for _ in range(2):
__ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE)
time_triton_dim0_dim1_us = benchmark_cuda_function_in_microseconds(
lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE),
x,
BLOCK_SIZE,
)

# calculate memory bandwidth
peak_mem_bw = get_specs()["peak_mem_bw_bytes_sec"]

# add_one kernel
add_one_bps = x.numel() * bytes_per_el_bf16 * 2 / (time_add_one_compile_us / 1e6)

# dim0 or dim1 kernel
dim0_bytes_read = x.numel() * bytes_per_el_bf16
dim0_bytes_written = (x_d0_t.numel() + scale_e8m0_d0_t.numel()) * bytes_per_el_fp8
dim0_bytes_rw = dim0_bytes_read + dim0_bytes_written
ref_dim0_bps = dim0_bytes_rw / (time_ref_dim0_compile_us / 1e6)
ref_dim1_bps = dim0_bytes_rw / (time_ref_dim1_compile_us / 1e6)
triton_dim1_bps = dim0_bytes_rw / (time_triton_dim1_us / 1e6)

# triton dim0_dim1 kernel
triton_dim0_dim1_bytes_read = x.numel() * bytes_per_el_bf16
triton_dim0_dim1_bytes_written = (
sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
* bytes_per_el_fp8
)
triton_dim0_dim1_bps = (
triton_dim0_dim1_bytes_read + triton_dim0_dim1_bytes_written
) / (time_triton_dim0_dim1_us / 1e6)

results = [
["add_one", time_add_one_compile_us, add_one_bps / 1e9],
["compile_dim0", time_ref_dim0_compile_us, ref_dim0_bps / 1e9],
["compile_dim1", time_ref_dim1_compile_us, ref_dim1_bps / 1e9],
[
"compile_dim0_dim1",
time_ref_dim0_dim1_compile_us,
triton_dim0_dim1_bps / 1e9,
],
["triton_dim1", time_triton_dim1_us, triton_dim1_bps / 1e9],
["triton_dim0_dim1", time_triton_dim0_dim1_us, triton_dim0_dim1_bps / 1e9],
]
df = pd.DataFrame(results, columns=["experiment", "time_us", "mem_bw_gbps"])
df["mem_bw_pct_peak"] = df["mem_bw_gbps"] * 1e9 / peak_mem_bw
print("\n", df)


if __name__ == "__main__":
fire.Fire(run)
Loading
Loading