fix: bug where some specific shapes make inference on CUDA graphs crash #177

Merged: 7 commits, Nov 15, 2022
114 changes: 59 additions & 55 deletions src/kernl/implementations/attention.py
@@ -126,7 +126,7 @@ def _fwd_kernel(
┌────────────┐
│ │ │
M Dimension│ ├────────────┤ ┌───┐
size_m │ │ │ │ │ BLOCK_M
│ ├────────────┤ └───┘
│ │ │ BLOCK_N
│ │ │
@@ -180,41 +180,43 @@ def _fwd_kernel(
head_idx = tl.program_id(1)

# offsets
offs_m = m_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) # rows offsets on M axis
offs_n = tl.arange(0, BLOCK_N) # First block on N dimension
offs_d = tl.arange(0, BLOCK_DHEAD) # Full head
range_offs_m = tl.arange(0, BLOCK_M) # first block on M dimension
range_offs_n = tl.arange(0, BLOCK_N) # first block on N dimension
range_offs_d = tl.arange(0, BLOCK_DHEAD) # full head

offs_m = m_block_idx * BLOCK_M + range_offs_m # rows offsets on M axis

current_batch_idx = head_idx // heads
current_head_idx = head_idx % heads
# memory offsets matrices on whole Q, K, V matrices
# Offsets for the current block on matrix Q
off_q = (
offs_q = (
current_batch_idx * q_batch_stride
+ current_head_idx * q_head_stride
+ (offs_m[:, None] * q_m_stride + offs_d[None, :] * q_k_stride)
+ (offs_m[:, None] * q_m_stride + range_offs_d[None, :] * q_k_stride)
)

# Offsets for the first block on matrix K
off_k = (
offs_k = (
current_batch_idx * k_batch_stride
+ current_head_idx * k_head_stride
+ (offs_n[:, None] * k_n_stride + offs_d[None, :] * k_k_stride)
+ (range_offs_n[:, None] * k_n_stride + range_offs_d[None, :] * k_k_stride)
)

# Offsets for the first block on matrix V
off_v = (
offs_v = (
current_batch_idx * v_batch_stride
+ current_head_idx * v_head_stride
+ (offs_n[:, None] * v_k_stride + offs_d[None, :] * v_n_stride)
+ (range_offs_n[:, None] * v_k_stride + range_offs_d[None, :] * v_n_stride)
)

# pointers to blocks in Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
ptrs_q = Q + offs_q
ptrs_k = K + offs_k
ptrs_v = V + offs_v

# Temporary pointer to memory to fix bug in triton compiler
t_ptrs = TMP + head_idx * size_m_rounded + offs_m
ptrs_t = TMP + head_idx * size_m_rounded + offs_m

# initialize pointer to m and d used to compute normalizer for softmax
l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) - float("inf")
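As context for the renaming above: the kernel builds a (BLOCK_M, BLOCK_DHEAD) grid of element offsets for each Q tile from the in-block ranges (`range_offs_m`, `range_offs_d`) and the absolute row offsets (`offs_m`). A minimal NumPy sketch of the same arithmetic, with toy strides and purely illustrative values (not the kernel's real configuration), might look like this:

```python
import numpy as np

# Toy sizes and strides for a single (batch, head); all numbers here are
# assumptions for illustration only.
BLOCK_M, BLOCK_DHEAD = 4, 8
q_batch_stride, q_head_stride = 0, 0     # single batch / head, both collapse to 0
q_m_stride, q_k_stride = BLOCK_DHEAD, 1  # row-major (seq_len, d_head) slice

m_block_idx = 2                          # the block of rows this "program" handles
range_offs_m = np.arange(BLOCK_M)        # in-block row indices, 0..BLOCK_M-1
range_offs_d = np.arange(BLOCK_DHEAD)    # in-block head-dim indices
offs_m = m_block_idx * BLOCK_M + range_offs_m  # absolute row indices in Q

# Same formula as offs_q above: one linear element offset per (row, dim) pair
offs_q = (
    q_batch_stride
    + q_head_stride
    + offs_m[:, None] * q_m_stride
    + range_offs_d[None, :] * q_k_stride
)
print(offs_q.shape)   # (4, 8): one offset per element of the Q tile
print(offs_q[0, :4])  # [64 65 66 67]: first row of the third block
```

Adding the base pointer `Q` to this offset matrix yields `ptrs_q`, which is why a mask of the form `offs_m[:, None] < size_m` is enough to guard the rows that fall past the end of the sequence.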
@@ -227,9 +229,9 @@ def _fwd_kernel(
# load q, a block of full rows of matrix q
# it will stay in SRAM throughout
if NEED_LOAD_MASK_SIZE_M:
q = tl.load(q_ptrs, mask=offs_m[:, None] < size_m, other=0.0)
q = tl.load(ptrs_q, mask=offs_m[:, None] < size_m, other=0.0)
else:
q = tl.load(q_ptrs)
q = tl.load(ptrs_q)

n_end = size_n
if IS_CAUSAL:
@@ -247,52 +249,57 @@ def _fwd_kernel(
offs_base_mask = mask_batch_idx * attention_mask_batch_stride + mask_head_idx * attention_mask_head_stride

# loop over k, v and update accumulator
# n_row_offset is the row offset on dimension N of the current block
# block_start_index_n is the row offset on dimension N of the current block
# It's used for both the N dimension of K and V because they are handled at the same time
for n_row_offset in range(0, n_end, BLOCK_N):
n_row_offset = tl.multiple_of(n_row_offset, BLOCK_N)
for block_start_index_n in range(0, n_end, BLOCK_N):
block_start_index_n = tl.multiple_of(block_start_index_n, BLOCK_N)
offs_n = block_start_index_n + range_offs_n
# We load the current block in K in SRAM
# We do the first multiplication between the block in Q and the current block in K
# We finish with the scaling (sqrt(BLOCK_DHEAD) in Vaswani et al. but sm_scale here)
if NEED_LOAD_MASK_SIZE_N:
load_mask = (n_row_offset + offs_n)[:, None] < size_n
k = tl.load(k_ptrs + n_row_offset * k_n_stride, mask=load_mask, other=0.0)
load_mask = offs_n[:, None] < size_n
k = tl.load(ptrs_k + block_start_index_n * k_n_stride, mask=load_mask, other=0.0)
else:
k = tl.load(k_ptrs + n_row_offset * k_n_stride)
k = tl.load(ptrs_k + block_start_index_n * k_n_stride)
qk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# required to fix a Triton compiler bug, if not done, there is a precision issue
if NEED_LOAD_MASK_SIZE_N:
qk = tl.where(offs_n[None, :] < size_n, qk, float("-inf"))
qk = tl.where(range_offs_n[None, :] < size_n, qk, float("-inf"))
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf"))
qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))

if HAS_MASK:
offs_mask = offs_base_mask + (offs_n[None, :] + n_row_offset) * attention_mask_k_stride
if NEED_LOAD_MASK_SIZE_N:
attention_load_mask = (n_row_offset + offs_n)[None, :] < size_n
# If it's a broadcast we only load vector size BLOCK_N else a matrix size (BLOCK_M, BLOCK_N)
if MASK_M_SIZE == 1:
if NEED_LOAD_MASK_SIZE_N:
m = tl.load(attention_mask + offs_mask, mask=attention_load_mask, other=float("-inf"))
else:
m = tl.load(attention_mask + offs_mask)
else:
# we assume mask has a vector shape
offs_mask = offs_base_mask + offs_n[None, :] * attention_mask_k_stride
if MASK_M_SIZE != 1: # mask has a square shape, we load (BLOCK_M, BLOCK_N) elements
offs_mask += offs_m[:, None] * attention_mask_m_stride
if NEED_LOAD_MASK_SIZE_N:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first", # The mask matrix is never reused
mask=attention_load_mask,
other=float("-inf"),
)
else:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first",
)

if NEED_LOAD_MASK_SIZE_N & MASK_M_SIZE == 1: # mask has a vector shape need a load mask
attention_load_mask = offs_n[None, :] < size_n
if MASK_M_SIZE != 1: # mask has a matrix shape
if NEED_LOAD_MASK_SIZE_M & (not NEED_LOAD_MASK_SIZE_N): # load mask on M axis
attention_load_mask = offs_m[:, None] < size_m
elif (not NEED_LOAD_MASK_SIZE_M) & NEED_LOAD_MASK_SIZE_N: # load mask on N axis
attention_load_mask = offs_n[None, :] < size_n
elif NEED_LOAD_MASK_SIZE_M & NEED_LOAD_MASK_SIZE_N: # load mask on both axis
attention_load_mask = (offs_n[None, :] < size_n) & (offs_m[:, None] < size_m)

if NEED_LOAD_MASK_SIZE_M | NEED_LOAD_MASK_SIZE_N:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first",
mask=attention_load_mask,
other=float("-inf"),
)
else:
m = tl.load(
attention_mask + offs_mask,
eviction_policy="evict_first",
)
# Avoids NaN
m = tl.where(m == float("-inf"), min_clamp_value, m)
qk += m
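The block above is the core of this change: the boundary load mask for the attention-mask tile is now chosen explicitly per axis, depending on whether size_m and/or size_n are multiples of the block sizes. A hedged NumPy mock of that selection logic (illustrative names and sizes, not the actual Triton code):

```python
import numpy as np

# Toy tile and sequence sizes chosen so that both axes need a boundary mask.
BLOCK_M, BLOCK_N = 4, 4
size_m, size_n = 3, 6
NEED_LOAD_MASK_SIZE_M = size_m % BLOCK_M != 0
NEED_LOAD_MASK_SIZE_N = size_n % BLOCK_N != 0

offs_m = np.arange(BLOCK_M)           # first (and only) M block
offs_n = 4 + np.arange(BLOCK_N)       # second N block, block_start_index_n = 4

load_mask = None
if NEED_LOAD_MASK_SIZE_M and not NEED_LOAD_MASK_SIZE_N:
    load_mask = offs_m[:, None] < size_m
elif not NEED_LOAD_MASK_SIZE_M and NEED_LOAD_MASK_SIZE_N:
    load_mask = offs_n[None, :] < size_n
elif NEED_LOAD_MASK_SIZE_M and NEED_LOAD_MASK_SIZE_N:
    load_mask = (offs_m[:, None] < size_m) & (offs_n[None, :] < size_n)

# Out-of-bounds positions are filled with -inf (so they vanish in the softmax),
# then -inf is clamped to a large negative finite value to avoid NaN, as above.
tile = np.where(load_mask, 0.0, -np.inf)
min_clamp_value = np.finfo(np.float32).min
tile = np.where(np.isinf(tile), min_clamp_value, tile)
print(tile.shape)  # (4, 4), broadcast from the two 1-D index vectors
```

The per-axis split mirrors the three branches above; it only guards the loads and does not change the math for fully in-bounds tiles.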
@@ -327,32 +334,29 @@ def _fwd_kernel(

# This isn't useful in the algorithm, simply to fix a compiler bug
# BUG: have to store and immediately load
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs)
tl.store(ptrs_t, acc_scale)
acc_scale = tl.load(ptrs_t)

# acc scaling
acc = acc * acc_scale[:, None]

# We now apply the last operation, the multiplication by a block of matrix V
if NEED_LOAD_MASK_SIZE_N:
load_mask = (n_row_offset + offs_n)[:, None] < size_n # repeated otherwise triton segfault
v = tl.load(v_ptrs + n_row_offset * v_k_stride, mask=load_mask, other=0.0)
load_mask = offs_n[:, None] < size_n # repeated otherwise triton segfault
v = tl.load(ptrs_v + block_start_index_n * v_k_stride, mask=load_mask, other=0.0)
else:
v = tl.load(v_ptrs + n_row_offset * v_k_stride)
v = tl.load(ptrs_v + block_start_index_n * v_k_stride)
qk_softmax = qk_softmax.to(Q.dtype.element_ty)
acc += tl.dot(qk_softmax, v)

# We update the normalizer for the next iteration
d_i = d_new
l_i = l_new

# For some reason we need to re-init this variable
# The other variables in the original implementations seem not needed
offs_n = tl.arange(0, BLOCK_DHEAD)
off_o = (
current_batch_idx * o_batch_stride
+ current_head_idx * o_head_stride
+ (offs_m[:, None] * o_m_stride + offs_n[None, :] * o_n_stride)
+ (offs_m[:, None] * o_m_stride + range_offs_d[None, :] * o_n_stride)
)

out_ptrs = output + off_o
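For readers unfamiliar with the accumulation scheme this kernel relies on (a running row max, a running normalizer, and a rescaling of the partial output before each new block), here is a compact NumPy reference of the same online-softmax attention, assuming no row is ever fully masked; function and variable names are illustrative, not part of the kernel.

```python
import numpy as np

def online_softmax_attention(qk_blocks, v_blocks):
    """qk_blocks: list of (M, Bn) scaled score tiles; v_blocks: matching (Bn, D) value tiles."""
    M = qk_blocks[0].shape[0]
    D = v_blocks[0].shape[1]
    m_i = np.full(M, -np.inf)      # running row-wise max of the scores
    l_i = np.zeros(M)              # running row-wise softmax normalizer
    acc = np.zeros((M, D))         # running (already normalized) partial output
    for qk, v in zip(qk_blocks, v_blocks):
        m_new = np.maximum(m_i, qk.max(axis=1))
        p = np.exp(qk - m_new[:, None])                  # block's unnormalized softmax
        l_new = l_i * np.exp(m_i - m_new) + p.sum(axis=1)
        acc_scale = l_i * np.exp(m_i - m_new) / l_new    # rescale previous blocks
        acc = acc * acc_scale[:, None] + (p / l_new[:, None]) @ v
        m_i, l_i = m_new, l_new
    return acc

# Sanity check against the naive softmax over the concatenated blocks
rng = np.random.default_rng(0)
qk = rng.standard_normal((4, 8))
v = rng.standard_normal((8, 5))
blockwise = online_softmax_attention([qk[:, :4], qk[:, 4:]], [v[:4], v[4:]])
full = np.exp(qk - qk.max(axis=1, keepdims=True))
full = (full / full.sum(axis=1, keepdims=True)) @ v
print(np.allclose(blockwise, full))  # True
```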
3 changes: 1 addition & 2 deletions test/test_torchdynamo.py
@@ -66,8 +66,7 @@ def reference_fp32(request):
)
@pytest.mark.parametrize(
"shape",
# TODO add shape 32x32 which may be unstable with T5
[(bs, seq_l) for bs in [1, 8, 32] for seq_l in [16, 33, 128, 256, 384, 512] if bs * seq_l < 10000],
[(bs, seq_l) for bs in [1, 8, 16, 32] for seq_l in [16, 32, 33, 128, 256, 384, 512] if bs * seq_l < 10000],
ids=lambda x: f"{x[0]}x{x[1]}",
)
@pytest.mark.parametrize("implementation", implementations, ids=lambda v: v.name)
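The test-side change widens the shape grid, adding batch size 16 and sequence length 32 and dropping the TODO about the 32x32 case. As a rough illustration of what the new parametrization expands to (a sketch, not part of the test file):

```python
# Reproduces the parametrization grid from test_torchdynamo.py; the 10000-element
# cap drops the largest (batch, seq_len) combinations.
shapes = [
    (bs, seq_l)
    for bs in [1, 8, 16, 32]
    for seq_l in [16, 32, 33, 128, 256, 384, 512]
    if bs * seq_l < 10000
]
print(len(shapes))         # 26 shapes after the cap
print((32, 32) in shapes)  # True: the case mentioned in the removed TODO is now covered
```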