-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Open
Labels
Description
Describe the issue
Hi, I'm here again. This time, I bring a very counterintuitive problem.
That is, why is putting the if-else condition inside a loop much faster (~10x) than putting it outside the loop?
Below are my two versions of code:
if-else outside loop
@triton.jit
def _masked_ffn_infer(
    a_ptr,
    w1_ptr,
    w3_ptr,
    u1_ptr,
    u3_ptr,
    c_ptr,
    m_ptr,
    M,
    stride_at, stride_am,
    stride_wt, stride_wn,
    stride_cm, stride_cn,
    T: tl.constexpr,
    N: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr
):
    """Masked gated-FFN tile kernel; the branch is chosen once, outside the loop.

    Each program produces one (TILE_M, TILE_N) tile of the output. Over ``T``
    steps it accumulates ``a @ w.T`` into two float32 accumulators. A row
    uses the ``w1``/``w3`` weights when its mask value (loaded from ``m_ptr``
    and cast to ``tl.int1``) is set, and the ``u1``/``u3`` weights otherwise.
    The value written to ``c_ptr`` is ``acc1 * sigmoid(acc1) * acc3`` cast to
    bfloat16.

    The loop is specialised up front on the tile's mask:
      * no row set   -> only ``u1``/``u3`` are loaded;
      * every row set -> only ``w1``/``w3`` are loaded;
      * mixed        -> both sets are loaded and merged per row via
        ``tl.where``.

    Args:
        a_ptr: base pointer of the activations; advanced by ``stride_at``
            after every step.
        w1_ptr, w3_ptr: base pointers of the weights used for rows whose
            mask is set; advanced by ``stride_wt`` after every step.
        u1_ptr, u3_ptr: base pointers of the weights used for rows whose
            mask is not set; advanced by ``stride_wt`` after every step.
        c_ptr: base pointer of the output.
        m_ptr: base pointer of the per-row mask, indexed by row.
        M: number of rows; used to wrap load indices and bound the store.
        stride_at, stride_am: per-step and per-row strides of ``a``. The
            innermost (TILE_K) axis is addressed with unit stride.
        stride_wt, stride_wn: per-step and per-output-column strides of the
            weights. The innermost (TILE_K) axis is addressed with unit
            stride.
        stride_cm: row stride of the output.
        stride_cn: accepted for interface compatibility but not used; output
            columns are addressed with unit stride.
        T: number of accumulation steps.
        N: number of output columns.
        TILE_M, TILE_N, TILE_K: tile extents.
        GROUP_SIZE: number of row tiles per group in the program ordering.
    """
    # Grouped ordering of program ids over the (row-tile, col-tile) grid.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped indices identify the elements this program owns; they are
    # used for the bounds-checked store at the end.
    row_ids = pid_m * TILE_M + tl.arange(0, TILE_M)
    col_ids = pid_n * TILE_N + tl.arange(0, TILE_N)
    # Wrapped indices keep every load in bounds without needing a load mask.
    m_range = row_ids % M
    n_range = col_ids % N
    k_range = tl.arange(0, TILE_K)

    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    c_offs = row_ids[:, None] * stride_cm + col_ids[None, :]

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    if tl.max(m_data) == 0:
        # No row in this tile has its mask set: only the u-weights matter.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)
            a_ptr += stride_at
            u1_ptr += stride_wt
            u3_ptr += stride_wt
    elif tl.min(m_data) == 1:
        # Every row in this tile has its mask set: only the w-weights matter.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt
    else:
        # Mixed tile: compute both products and pick per row.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt
            u1_ptr += stride_wt
            u3_ptr += stride_wt

    # Gate: acc1 * sigmoid(acc1), then multiply by the second projection.
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3

    # Mask on the unwrapped indices. A mask built from the wrapped indices is
    # always true, which makes tail tiles rewrite rows/columns that belong to
    # other programs.
    c_mask = (row_ids[:, None] < M) & (col_ids[None, :] < N)
    tl.store(
        c_ptr + c_offs,
        acc1.to(tl.bfloat16),
        mask=c_mask)
if-else inside loop
@triton.jit
def _masked_ffn_infer(
    a_ptr,
    w1_ptr,
    w3_ptr,
    u1_ptr,
    u3_ptr,
    c_ptr,
    m_ptr,
    M,
    stride_at, stride_am,
    stride_wt, stride_wn,
    stride_cm, stride_cn,
    T: tl.constexpr,
    N: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr
):
    """Masked gated-FFN tile kernel; the branch is taken inside the loop.

    Each program produces one (TILE_M, TILE_N) tile of the output. Over ``T``
    steps it accumulates ``a @ w.T`` into two float32 accumulators. A row
    uses the ``w1``/``w3`` weights when its mask value (loaded from ``m_ptr``
    and cast to ``tl.int1``) is set, and the ``u1``/``u3`` weights otherwise.
    The value written to ``c_ptr`` is ``acc1 * sigmoid(acc1) * acc3`` cast to
    bfloat16.

    A single loop runs over the ``T`` steps and, on each step, branches on
    the tile's mask:
      * no row set   -> only ``u1``/``u3`` are loaded;
      * every row set -> only ``w1``/``w3`` are loaded;
      * mixed        -> both sets are loaded and merged per row via
        ``tl.where``.
    All five weight/activation pointers are advanced on every step
    regardless of the branch taken.

    Args:
        a_ptr: base pointer of the activations; advanced by ``stride_at``
            after every step.
        w1_ptr, w3_ptr: base pointers of the weights used for rows whose
            mask is set; advanced by ``stride_wt`` after every step.
        u1_ptr, u3_ptr: base pointers of the weights used for rows whose
            mask is not set; advanced by ``stride_wt`` after every step.
        c_ptr: base pointer of the output.
        m_ptr: base pointer of the per-row mask, indexed by row.
        M: number of rows; used to wrap load indices and bound the store.
        stride_at, stride_am: per-step and per-row strides of ``a``. The
            innermost (TILE_K) axis is addressed with unit stride.
        stride_wt, stride_wn: per-step and per-output-column strides of the
            weights. The innermost (TILE_K) axis is addressed with unit
            stride.
        stride_cm: row stride of the output.
        stride_cn: accepted for interface compatibility but not used; output
            columns are addressed with unit stride.
        T: number of accumulation steps.
        N: number of output columns.
        TILE_M, TILE_N, TILE_K: tile extents.
        GROUP_SIZE: number of row tiles per group in the program ordering.
    """
    # Grouped ordering of program ids over the (row-tile, col-tile) grid.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped indices identify the elements this program owns; they are
    # used for the bounds-checked store at the end.
    row_ids = pid_m * TILE_M + tl.arange(0, TILE_M)
    col_ids = pid_n * TILE_N + tl.arange(0, TILE_N)
    # Wrapped indices keep every load in bounds without needing a load mask.
    m_range = row_ids % M
    n_range = col_ids % N
    k_range = tl.arange(0, TILE_K)

    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    c_offs = row_ids[:, None] * stride_cm + col_ids[None, :]

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    # The mask does not change across steps, so reduce it once; the branch
    # itself is still evaluated inside the loop.
    none_set = tl.max(m_data) == 0
    all_set = tl.min(m_data) == 1

    for t in range(0, T):
        # `a` is needed by every branch, so it is loaded exactly once here.
        a = tl.load(a_ptr + a_offs)
        if none_set:
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)
        elif all_set:
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)
        else:
            # Mixed tile: compute both products and pick per row.
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))
        # Advance every pointer each step, whichever branch ran.
        a_ptr += stride_at
        w1_ptr += stride_wt
        w3_ptr += stride_wt
        u1_ptr += stride_wt
        u3_ptr += stride_wt

    # Gate: acc1 * sigmoid(acc1), then multiply by the second projection.
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3

    # Mask on the unwrapped indices. A mask built from the wrapped indices is
    # always true, which makes tail tiles rewrite rows/columns that belong to
    # other programs.
    c_mask = (row_ids[:, None] < M) & (col_ids[None, :] < N)
    tl.store(
        c_ptr + c_offs,
        acc1.to(tl.bfloat16),
        mask=c_mask)
Given the redundant per-iteration work in the inside version, I would expect it to be slower, as it would be in any other programming language, but that is not what happens with Triton.
Do you have some ideas? Appreciate it!
Environment details
Triton: 3.1.0
GPU: A100