### Describe the bug
The following kernel produces different output with `num_stages=1` than with `num_stages=2`:
```python
import torch
import triton
import triton.language as tl
import pytest


@triton.jit
def my_kernel(a_ptr, b_ptr,
              out_a_ptr,
              out_ab_ptr,
              N,
              D: tl.constexpr,
              BLOCK_SIZE_N_INNER: tl.constexpr = 64,
              ):
    # Accumulate the column-wise sum of A and the product A^T @ B
    # over row blocks of size BLOCK_SIZE_N_INNER.
    out_a = tl.zeros((D,), dtype=tl.float32)
    out_ab = tl.zeros((D, D), dtype=tl.float32)
    n_iters = tl.cdiv(N, BLOCK_SIZE_N_INNER)
    for k in range(0, n_iters):
        rows = k * BLOCK_SIZE_N_INNER + tl.arange(0, BLOCK_SIZE_N_INNER)
        mask_rows = rows < N
        cols = tl.arange(0, D)
        offs_a = rows[:, None] * D + cols[None, :]
        offs_b = rows[:, None] * D + cols[None, :]
        ptrs_a = a_ptr + offs_a
        ptrs_b = b_ptr + offs_b
        mask = mask_rows[:, None]
        block_a = tl.load(ptrs_a, mask=mask)
        block_b = tl.load(ptrs_b, mask=mask)
        block_at = tl.trans(block_a)
        out_a += tl.sum(block_a.to(tl.float32), axis=0)
        out_ab += tl.dot(block_at, block_b)
    out_a_ptrs = out_a_ptr + tl.arange(0, D)
    tl.store(out_a_ptrs, out_a)
    out_ab_ptrs = out_ab_ptr + tl.arange(0, D)[:, None] * D + tl.arange(0, D)[None, :]
    tl.store(out_ab_ptrs, out_ab)
```
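For reference, the kernel is meant to compute the column-wise sum of `a` and the product `aᵀb`. A quick sketch of the same semantics in plain PyTorch (this just restates the checks used in the test below):

```python
import torch

def reference(a: torch.Tensor, b: torch.Tensor):
    # Same quantities the kernel accumulates, computed in float32.
    a32, b32 = a.to(torch.float32), b.to(torch.float32)
    out_a = a32.sum(dim=0)  # shape (D,)
    out_ab = a32.T @ b32    # shape (D, D)
    return out_a, out_ab
```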
```python
@pytest.mark.parametrize("N", [5000])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("num_stages", [1, 2])
def test(N, D, dtype, num_stages):
    torch.manual_seed(0)
    a = torch.rand(N, D, dtype=dtype, device='cuda')
    b = torch.rand(N, D, dtype=dtype, device='cuda')
    out_a = torch.zeros(D, dtype=torch.float32, device='cuda')
    out_ab = torch.zeros(D, D, dtype=torch.float32, device='cuda')
    my_kernel[(1,)](
        a, b,
        out_a, out_ab,
        N,
        D=D, num_stages=num_stages,
    )
    a_tensor = a.to(dtype=torch.float32)
    b_tensor = b.to(dtype=torch.float32)
    # Check the column-wise sum of A.
    ref = torch.sum(a_tensor, dim=0)
    max_err = (out_a - ref).abs().max().item()
    allowed_err = 1e-5 * N
    assert max_err < allowed_err, f"Max A error too high: {max_err}"
    # Check A^T @ B.
    ref = a_tensor.T @ b_tensor
    max_err = (out_ab - ref).abs().max().item()
    allowed_err = 1e-5 * N
    assert max_err < allowed_err, f"Max AB error too high: {max_err}"
```
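The test can be run with pytest (e.g. `pytest repro.py`, where `repro.py` is an assumed filename for the snippet above). A minimal standalone driver without pytest, sweeping the same parameter grid, would look like this sketch:

```python
if __name__ == "__main__":
    # Same configurations as the pytest parametrization; a failing
    # configuration raises an AssertionError from the test body.
    for dtype in (torch.float16, torch.bfloat16):
        for num_stages in (1, 2):
            test(N=5000, D=128, dtype=dtype, num_stages=num_stages)
            print(f"OK: dtype={dtype}, num_stages={num_stages}")
```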
This kernel passes the test with `num_stages=1`, but reliably fails it with `num_stages=2`. Interestingly, the kernel passes the test in both configurations if the following two lines are swapped:
```python
out_a += tl.sum(block_a.to(tl.float32), axis=0)
out_ab += tl.dot(block_at, block_b)
```
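For clarity, the variant that passes in both configurations orders the accumulations like this (everything else unchanged):

```python
# Inside the for-k loop, accumulating the dot product before the sum
# makes the test pass with both num_stages=1 and num_stages=2:
out_ab += tl.dot(block_at, block_b)
out_a += tl.sum(block_a.to(tl.float32), axis=0)
```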
### Environment details

- Triton: 3.3.0
- GPU: GH200