FP8 splitgemm user defined triton kernel #263

Merged (8 commits) on May 23, 2024
Changes from 3 commits
42 changes: 42 additions & 0 deletions test/dtypes/test_fp8.py
@@ -0,0 +1,42 @@
import unittest
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

try:
    from torchao.prototype.fp8 import gemm_split_k
    triton_available = True
except ImportError:
    triton_available = False


@unittest.skipIf(not triton_available, "Triton is required but not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8Gemm(TestCase):
    # @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
    def test_gemm_split_k(self):
        m, n, k = 256, 256, 512

        a = torch.randn((m, k), dtype=torch.float16, device="cuda")
        b = torch.randn((k, n), dtype=torch.float16, device="cuda")
        c = gemm_split_k(a, b)
        c_expected = torch.matmul(a, b)
        # a tighter tolerance than this makes the accuracy check fail
        assert torch.allclose(c, c_expected, atol=0.07)

    # https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "User defined triton functions are only supported in PyTorch 2.3 and above")
    def test_user_defined_triton_function(self):
        m, n, k = 256, 256, 512

        a = torch.randn((m, k), dtype=torch.float16, device="cuda")
        b = torch.randn((k, n), dtype=torch.float16, device="cuda")
        # smoke test: only checks that torch.compile can trace through the user-defined Triton kernel
        c = torch.compile(gemm_split_k, fullgraph=True)(a, b)


if __name__ == "__main__":
    run_tests()
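
The `parametrize` and `instantiate_parametrized_tests` helpers are imported above but not yet used, since the dtype decorator is still commented out. A minimal sketch of how the same machinery could parametrize the test, here over the reduction dimension `k` (the `TestFP8GemmShapes` class and the chosen sizes are illustrative, not part of this PR):

import unittest
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype.fp8 import gemm_split_k


@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8GemmShapes(TestCase):
    # illustrative only: much larger k may need a looser tolerance, because the
    # split-k partial sums are atomically accumulated in float16
    @parametrize("k", [512, 1024])
    def test_gemm_split_k_shapes(self, k):
        m, n = 256, 256
        a = torch.randn((m, k), dtype=torch.float16, device="cuda")
        b = torch.randn((k, n), dtype=torch.float16, device="cuda")
        c = gemm_split_k(a, b)
        assert torch.allclose(c, torch.matmul(a, b), atol=0.07)


# expands the parametrized method into one concrete test per value of k
instantiate_parametrized_tests(TestFP8GemmShapes)

if __name__ == "__main__":
    run_tests()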
1 change: 1 addition & 0 deletions torchao/prototype/fp8/__init__.py
@@ -0,0 +1 @@
from .splitk_gemm import gemm_split_k
119 changes: 119 additions & 0 deletions torchao/prototype/fp8/splitk_gemm.py
@@ -0,0 +1,119 @@
# Code from https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/fp8/splitk_gemm_fp8.py
import torch
import triton
import triton.language as tl

@triton.jit
def grouped_launch(pid,
                   m, n,
                   block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
    # Swizzle the flat program id into (pid_m, pid_n) so that groups of
    # `group_m` rows of output tiles are scheduled together, improving L2 reuse.
    grid_m = tl.cdiv(m, block_m)
    grid_n = tl.cdiv(n, block_n)

    width = group_m * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size

    return pid_m, pid_n


@triton.jit
def col_major(pid,
              m, n,
              block_m: tl.constexpr, block_n: tl.constexpr):
    # Alternative column-major mapping of the flat program id (unused by the kernel below).
    grid_m = tl.cdiv(m, block_m)

    pid_m = pid % grid_m
    pid_n = pid // grid_m

    return pid_m, pid_n


@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
                        stride_am, stride_ak,
                        stride_bk, stride_bn,
                        stride_cm, stride_cn,
                        m, n, k,
                        block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
                        split_k: tl.constexpr, group_m: tl.constexpr):
    # axis 0 indexes the (m, n) output tile; axis 1 indexes the split-k slice
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_k = tl.cdiv(k, block_k * split_k)

    pid_m, pid_n = grouped_launch(pid,
                                  m, n,
                                  block_m, block_n, group_m)

    offs_m = pid_m * block_m + tl.arange(0, block_m)
    offs_n = pid_n * block_n + tl.arange(0, block_n)
    offs_k = pid_k * block_k + tl.arange(0, block_k)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for k_ in range(0, grid_k):
        k_remaining = k - k_ * (block_k * split_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        a_ptrs += block_k * split_k * stride_ak
        b_ptrs += block_k * split_k * stride_bk

    acc = acc.to(tl.float16)

    offs_m = pid_m * block_m + tl.arange(0, block_m)
    offs_n = pid_n * block_n + tl.arange(0, block_n)

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]

    # each split-k slice holds a partial sum, so slices are combined by
    # atomically adding into the zero-initialized output tensor
    tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b):
    m, k = a.shape
    _, n = b.shape

    # These block sizes were lowered; with larger defaults we hit
    # triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 393216, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
    # TODO: Should we tune this differently for different hardware?
    block_m = 32
    block_n = 32
    block_k = 256
    num_stages = 2
    num_warps = 4
    split_k = 4
    group_m = 8

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    total_programs_mn = total_blocks_m * total_blocks_n
    total_programs_k = split_k

    grid = (total_programs_mn, total_programs_k)

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    gemm_split_k_kernel[grid](a, b, c,
                              a.stride(0), a.stride(1),
                              b.stride(0), b.stride(1),
                              c.stride(0), c.stride(1),
                              m, n, k,
                              block_m, block_n, block_k,
                              split_k, group_m,
                              num_stages=num_stages, num_warps=num_warps)

    return c
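
As a reading aid, here is how the launch configuration above works out for the shapes used in the tests (m = n = 256, k = 512): an 8 by 8 grid of 32 by 32 output tiles gives 64 programs on axis 0, times split_k = 4 programs on axis 1, and since block_k * split_k = 1024 >= k each program runs its K loop exactly once. A minimal usage sketch under those assumptions (the computed values in the comments are illustrative):

import torch
import triton
from torchao.prototype.fp8 import gemm_split_k

m, n, k = 256, 256, 512
block_m, block_n, block_k, split_k = 32, 32, 256, 4   # values hard-coded inside gemm_split_k

total_programs_mn = triton.cdiv(m, block_m) * triton.cdiv(n, block_n)  # 8 * 8 = 64
grid = (total_programs_mn, split_k)                                    # (64, 4)
k_loop_iters = triton.cdiv(k, block_k * split_k)                       # cdiv(512, 1024) = 1

a = torch.randn((m, k), dtype=torch.float16, device="cuda")
b = torch.randn((k, n), dtype=torch.float16, device="cuda")
c = gemm_split_k(a, b)

print(grid, k_loop_iters)                   # (64, 4) 1
print(torch.allclose(c, a @ b, atol=0.07))  # within the tolerance used by the tests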
1 change: 1 addition & 0 deletions torchao/quantization/utils.py
@@ -99,6 +99,7 @@ def get_model_size_in_bytes(model):
        s += b.nelement() * b.element_size()
    return s

# TODO: quantization namespace is not the right place to have this
if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
    TORCH_VERSION_AFTER_2_4 = True
else:
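
The remainder of the hunk is collapsed, so the `else` branch is not shown. Based on the visible lines and on the `TORCH_VERSION_AFTER_2_3` flag imported by the new test, the version-gating pattern presumably looks like the following sketch (an assumption, not the actual file contents):

# Presumed pattern only; the real torchao/quantization/utils.py is truncated in this diff.
import torch
from packaging import version  # assumed source of `version.parse`

if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
    TORCH_VERSION_AFTER_2_4 = True
else:
    TORCH_VERSION_AFTER_2_4 = False

if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
    TORCH_VERSION_AFTER_2_3 = True
else:
    TORCH_VERSION_AFTER_2_3 = False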