Skip to content

Ordering of lines matters with num_stages > 1. #6691

Open
@nhukc

Description

@nhukc

Describe the bug

The following kernel has different output with num_stages=1 vs num_stages=2.

import torch
import triton
import triton.language as tl
import pytest
import time
import math
import logging

@triton.jit
def my_kernel(a_ptr, b_ptr,
        out_a_ptr,
        out_ab_ptr,
        N,
        D: tl.constexpr,
        BLOCK_SIZE_N_INNER: tl.constexpr = 64,
):
    out_a = tl.zeros((D,), dtype=tl.float32)
    out_ab = tl.zeros((D, D), dtype=tl.float32)

    n_iters = tl.cdiv(N, BLOCK_SIZE_N_INNER)
    for k in range(0, n_iters):
        rows = k * BLOCK_SIZE_N_INNER + tl.arange(0, BLOCK_SIZE_N_INNER)
        mask_rows = rows < N
        cols = tl.arange(0, D)

        offs_a = rows[:, None] * D + cols[None, :]
        offs_b = rows[:, None] * D + cols[None, :]
        ptrs_a = a_ptr + offs_a
        ptrs_b = b_ptr + offs_b
        mask = mask_rows[:, None]

        block_a = tl.load(ptrs_a, mask=mask)
        block_b = tl.load(ptrs_b, mask=mask)
        block_at = tl.trans(block_a)

        out_a += tl.sum(block_a.to(tl.float32), axis=0)
        out_ab += tl.dot(block_at, block_b)

    out_a_ptrs = out_a_ptr + tl.arange(0, D)
    tl.store(out_a_ptrs, out_a)
    
    out_ab_ptrs = out_ab_ptr + tl.arange(0, D)[:, None] * D + tl.arange(0, D)[None, :]
    tl.store(out_ab_ptrs, out_ab)

@pytest.mark.parametrize("N", [5000])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("num_stages", [1, 2])
def test(N, D, dtype, num_stages):
    torch.manual_seed(0)
    elements = list(range(N))

    a = torch.rand(N, D, dtype=dtype, device='cuda')
    b = torch.rand(N, D, dtype=dtype, device='cuda')

    out_a = torch.zeros(D, dtype=torch.float32, device='cuda')
    out_ab = torch.zeros(D, D, dtype=torch.float32, device='cuda')

    compiled = my_kernel[(1,)](
        a, b,
        out_a, out_ab,
        N,
        D=D, num_stages=num_stages
    )

    a_tensor = a.to(dtype=torch.float32)
    b_tensor = b.to(dtype=torch.float32)

    ref = torch.sum(a_tensor, dim=0)
    max_err = (out_a - ref).abs().max().item()
    allowed_err = 1e-5 * N
    assert max_err < allowed_err, f"Max A error too high: {max_err}"
    ref = a_tensor.T @ b_tensor
    max_err = (out_ab - ref).abs().max().item()
    allowed_err = 1e-5 * N
    assert max_err < allowed_err, f"Max AB error too high: {max_err}"

This kernel passes the test with num_stages=1, but reliably fails the test with num_stages=2. Interestingly, the kernel passes the test in both configurations if the following lines are swapped.

        out_a += tl.sum(block_a.to(tl.float32), axis=0)
        out_ab += tl.dot(block_at, block_b)

Environment details

Triton: 3.3.0
GPU: GH200

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions