Skip to content

The position of an if-else significantly affects performance, which is unexpected. #6491

@wenhaoli-xmu

Description

@wenhaoli-xmu

Describe the issue

Hi, I'm here again. This time, I bring a very counterintuitive problem.

That is, why is putting the if-else condition inside a loop much faster (~10x) than putting it outside the loop?

Below are my two versions of code:

if-else outside loop
@triton.jit
def _masked_ffn_infer(
        a_ptr, 
        w1_ptr,
        w3_ptr, 
        u1_ptr, 
        u3_ptr,
        c_ptr, 
        m_ptr,
        M,
        stride_at, stride_am,
        stride_wt, stride_wn,
        stride_cm, stride_cn,
        T: tl.constexpr,
        N: tl.constexpr,
        TILE_M: tl.constexpr, 
        TILE_N: tl.constexpr, 
        TILE_K: tl.constexpr,
        GROUP_SIZE: tl.constexpr
):
    """Masked gated-FFN tile with the weight-selection branch hoisted out of the loop.

    Computes one (TILE_M, TILE_N) tile of C = silu(A @ X1^T) * (A @ X3^T).
    The per-row flag loaded from ``m_ptr`` picks the weight pair:
    rows whose flag is 1 use (w1, w3); rows whose flag is 0 use (u1, u3).

    The tile takes one of three paths:
    - all flags 0: only u1/u3 are read;
    - all flags 1: only w1/w3 are read;
    - mixed: all four weights are read and combined per row with tl.where.

    The loop runs T steps. Step t reads TILE_K unit-stride columns of A and of
    each weight, then advances A by ``stride_at`` and the weights by
    ``stride_wt``.
    NOTE(review): assumes each TILE_K chunk is fully in bounds, because the
    loads are unmasked. TODO confirm with the caller.

    The result is stored as bfloat16.
    """
    # Grouped program ordering over (pid_m, pid_n) tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped tile indices, kept so the final store can be bounds-masked.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    # Wrapped indices keep every load in bounds without masks; lanes past the
    # edge simply recompute an existing row/column.
    m_range = offs_m % M
    n_range = offs_n % N
    k_range = tl.arange(0, TILE_K)

    # Per-row selector: True -> (w1, w3), False -> (u1, u3).
    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    # Honor the column stride of C instead of assuming it is 1.
    c_offs = m_range[:, None] * stride_cm + n_range[None, :] * stride_cn

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    if tl.max(m_data) == 0:
        # Every row in the tile selects (u1, u3).
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)
            a_ptr += stride_at
            u1_ptr += stride_wt
            u3_ptr += stride_wt
    
    elif tl.min(m_data) == 1:
        # Every row in the tile selects (w1, w3).
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt
    
    else:
        # Mixed tile: compute both products and pick per row.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt
            u1_ptr += stride_wt
            u3_ptr += stride_wt

    # silu(acc1) * acc3
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3
    # Mask on the unwrapped indices: m_range/n_range are reduced mod M/N, so a
    # bounds check on them would always pass.
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(
        c_ptr + c_offs, 
        acc1.to(tl.bfloat16), 
        mask=c_mask)
if-else inside loop
@triton.jit
def _masked_ffn_infer(
        a_ptr, 
        w1_ptr,
        w3_ptr, 
        u1_ptr, 
        u3_ptr,
        c_ptr, 
        m_ptr,
        M,
        stride_at, stride_am,
        stride_wt, stride_wn,
        stride_cm, stride_cn,
        T: tl.constexpr,
        N: tl.constexpr,
        TILE_M: tl.constexpr, 
        TILE_N: tl.constexpr, 
        TILE_K: tl.constexpr,
        GROUP_SIZE: tl.constexpr
):
    """Masked gated-FFN tile with the weight-selection branch inside the loop.

    Computes one (TILE_M, TILE_N) tile of C = silu(A @ X1^T) * (A @ X3^T).
    The per-row flag loaded from ``m_ptr`` picks the weight pair:
    rows whose flag is 1 use (w1, w3); rows whose flag is 0 use (u1, u3).

    Each loop step takes one of three paths:
    - all flags 0: reads u1/u3;
    - all flags 1: reads w1/w3;
    - mixed: reads all four and selects per row with tl.where.

    Every step advances all five pointers, whichever branch ran.
    Step t reads TILE_K unit-stride columns; A advances by ``stride_at`` and
    the weights by ``stride_wt``.
    NOTE(review): assumes each TILE_K chunk is fully in bounds, because the
    loads are unmasked. TODO confirm with the caller.

    The result is stored as bfloat16.
    """
    # Grouped program ordering over (pid_m, pid_n) tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped tile indices, kept so the final store can be bounds-masked.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    # Wrapped indices keep every load in bounds without masks; lanes past the
    # edge simply recompute an existing row/column.
    m_range = offs_m % M
    n_range = offs_n % N
    k_range = tl.arange(0, TILE_K)

    # Per-row selector: True -> (w1, w3), False -> (u1, u3).
    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    # Honor the column stride of C instead of assuming it is 1.
    c_offs = m_range[:, None] * stride_cm + n_range[None, :] * stride_cn

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    for t in range(0, T):

        # The A chunk is shared by all branches, so it is loaded exactly once.
        a = tl.load(a_ptr + a_offs)

        # The branch conditions do not depend on t (m_data is fixed before
        # the loop).
        if tl.max(m_data) == 0:
            # Every row in the tile selects (u1, u3).
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)
    
        elif tl.min(m_data) == 1:
            # Every row in the tile selects (w1, w3).
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)
    
        else:
            # Mixed tile: compute both products and pick per row.
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))

        a_ptr += stride_at
        w1_ptr += stride_wt
        w3_ptr += stride_wt
        u1_ptr += stride_wt
        u3_ptr += stride_wt

    # silu(acc1) * acc3
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3
    # Mask on the unwrapped indices: m_range/n_range are reduced mod M/N, so a
    # bounds check on them would always pass.
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(
        c_ptr + c_offs, 
        acc1.to(tl.bfloat16), 
        mask=c_mask)

Given the redundant computation in the inside version, it would be slower in almost any other programming language, but Triton apparently does not behave that way.

Do you have some ideas? Appreciate it! ☺️☺️

Environment details

Triton: 3.1.0
GPU: A100

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions