forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FP8 splitgemm user defined triton kernel (pytorch#263)
* FP8 splitgemm user defined triton kernel * yolo * Trigger CI * yolo * yolo * yolo * yolo * Update test_fp8.py
- Loading branch information
Showing
4 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import unittest | ||
import torch | ||
from torch.testing._internal.common_utils import ( | ||
TestCase, | ||
instantiate_parametrized_tests, | ||
parametrize, | ||
run_tests, | ||
) | ||
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 | ||
try: | ||
from torchao.prototype.fp8 import gemm_split_k | ||
triton_available = True | ||
except ImportError: | ||
triton_available = False | ||
|
||
@unittest.skipIf(not triton_available, "Triton is required but not available") | ||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") | ||
class TestFP8Gemm(TestCase): | ||
# @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) | ||
def test_gemm_split_k(self): | ||
m, n, k = 256, 256, 512 | ||
|
||
a = torch.randn((m, k), dtype=torch.float16, device="cuda") | ||
b = torch.randn((k, n), dtype=torch.float16, device="cuda") | ||
c = gemm_split_k(a, b) | ||
c_expected = torch.matmul(a, b) | ||
assert torch.allclose(c, c_expected, atol=0.07) # less than this and the accuracy check fails | ||
|
||
# https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html | ||
@unittest.skip("fp8 kernel compilation does not work on a10g") | ||
def test_user_defined_triton_function(self): | ||
import torch._inductor.config | ||
torch._inductor.config.force_disable_caches = True | ||
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' | ||
m, n, k = 256, 256, 512 | ||
|
||
a = torch.randn((m, k), dtype=torch.float16, device="cuda") | ||
b = torch.randn((k, n), dtype=torch.float16, device="cuda") | ||
compiled_function = torch.compile(gemm_split_k, fullgraph=True)(a,b) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .splitk_gemm import gemm_split_k |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Code from https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/fp8/splitk_gemm_fp8.py | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
@triton.jit | ||
def grouped_launch(pid, | ||
m, n, | ||
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): | ||
|
||
grid_m = tl.cdiv(m, block_m) | ||
grid_n = tl.cdiv(n, block_n) | ||
|
||
width = group_m * grid_n | ||
group_id = pid // width | ||
group_size = tl.minimum(grid_m - group_id * group_m, group_m) | ||
|
||
pid_m = group_id * group_m + (pid % group_size) | ||
pid_n = (pid % width) // group_size | ||
|
||
return pid_m, pid_n | ||
|
||
|
||
@triton.jit() | ||
def col_major(pid, | ||
m, n, | ||
block_m: tl.constexpr, block_n: tl.constexpr): | ||
|
||
grid_m = tl.cdiv(m, block_m) | ||
|
||
pid_m = pid % grid_m | ||
pid_n = pid // grid_m | ||
|
||
return pid_m, pid_n | ||
|
||
|
||
@triton.jit | ||
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr, | ||
stride_am, stride_ak, | ||
stride_bk, stride_bn, | ||
stride_cm, stride_cn, | ||
m, n, k, | ||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, | ||
split_k: tl.constexpr, group_m: tl.constexpr): | ||
|
||
pid = tl.program_id(0) | ||
pid_k = tl.program_id(1) | ||
grid_k = tl.cdiv(k, block_k*split_k) | ||
|
||
pid_m, pid_n = grouped_launch(pid, | ||
m, n, | ||
block_m, block_n, group_m) | ||
|
||
offs_m = pid_m*block_m + tl.arange(0, block_m) | ||
offs_n = pid_n*block_n + tl.arange(0, block_n) | ||
offs_k = pid_k*block_k + tl.arange(0, block_k) | ||
|
||
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) | ||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) | ||
|
||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) | ||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) | ||
|
||
acc = tl.zeros((block_m, block_n), dtype=tl.float32) | ||
for k_ in range(0, grid_k): | ||
|
||
k_remaining = k - k_ * (block_k * split_k) | ||
|
||
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) | ||
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) | ||
|
||
acc = tl.dot(a, b, acc, out_dtype=tl.float32) | ||
|
||
a_ptrs += block_k * split_k * stride_ak | ||
b_ptrs += block_k * split_k * stride_bk | ||
|
||
acc = acc.to(tl.float16) | ||
|
||
offs_m = pid_m*block_m + tl.arange(0, block_m) | ||
offs_n = pid_n*block_n + tl.arange(0, block_n) | ||
|
||
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) | ||
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :] | ||
|
||
tl.atomic_add(c_ptrs, acc, mask=mask) | ||
|
||
def gemm_split_k(a, b): | ||
|
||
m, k = a.shape | ||
_, n = b.shape | ||
|
||
# Need to change these otherwise was getting | ||
# triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 393216, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. | ||
# TODO: Should we tune this differently for different hardware? | ||
block_m = 32 | ||
block_n = 32 | ||
block_k = 256 | ||
num_stages = 2 | ||
num_warps = 4 | ||
split_k = 4 | ||
group_m = 8 | ||
|
||
total_blocks_m = triton.cdiv(m, block_m) | ||
total_blocks_n = triton.cdiv(n, block_n) | ||
total_programs_mn = total_blocks_m * total_blocks_n | ||
total_programs_k = split_k | ||
|
||
grid = (total_programs_mn, total_programs_k) | ||
|
||
c = torch.zeros((m, n), device=a.device, dtype=torch.float16) | ||
gemm_split_k_kernel[grid](a, b, c, | ||
a.stride(0), a.stride(1), | ||
b.stride(0), b.stride(1), | ||
c.stride(0), c.stride(1), | ||
m, n, k, | ||
block_m, block_n, block_k, | ||
split_k, group_m, num_stages=num_stages, num_warps=num_warps) | ||
|
||
return c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters