[PYTHON] Added shape & device checks for inputs to sparse matmul op (t…
norabelrose authored and ptillet committed Jul 27, 2021
1 parent bfc0a75 commit d7f8792
Showing 1 changed file with 110 additions and 36 deletions.
python/triton/ops/blocksparse/matmul.py: 110 additions & 36 deletions
@@ -197,7 +197,7 @@ class _matmul(torch.autograd.Function):
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
def load_balance(sizes):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
@@ -273,33 +273,41 @@ def make_sdd_lut(layout, block, device):

@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):

# (A * B)^T = (B^T * A^T)
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)

# Shape check
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
if a_inner != b_inner:
raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
if a_inner % 16 != 0:
raise ValueError('Reduction size for SDD must be a multiple of 16')

batch_size = a.size(0)
a_outer = a.size(3 if trans_a else 2)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')

# create kernel
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
'SDD': True, 'DSD': False, 'DDS': False}
# create output
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
_kernel[grid](
a,
b,
@@ -316,9 +324,9 @@ def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks,
c.stride(0),
c.stride(2),
c.stride(3),
AS2,
AS2,
AS3,
a_outer,
a_outer,
a_inner,
off_width,
lut,
locks,
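
Read on its own, the inner-dimension check added above is just a comparison of the reduced dims of A and B. Below is a minimal standalone sketch (separate from the patch itself) of that logic, using plain torch tensors and a hypothetical helper name:

import torch

def check_sdd_inner_dims(a, b, trans_a=False, trans_b=False):
    # The reduced ("inner") dim of A is its last dim unless A is transposed;
    # for B it is the second-to-last dim unless B is transposed.
    a_dim = -2 if trans_a else -1
    b_dim = -1 if trans_b else -2
    a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
    if a_inner != b_inner:
        raise ValueError(f"Inner dims differ: {a_inner} vs {b_inner}")
    if a_inner % 16 != 0:
        raise ValueError("Reduction size for SDD must be a multiple of 16")

a = torch.empty(1, 1, 64, 32)   # [..., M, K]
b = torch.empty(1, 1, 32, 128)  # [..., K, N]
check_sdd_inner_dims(a, b)      # K == 32 on both sides and divisible by 16, so this passes
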
@@ -353,7 +361,7 @@ def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
@@ -433,7 +441,7 @@ def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks,
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
'SDD': False, 'DSD': False, 'DDS': True}
# output
CS0 = AS0
@@ -480,7 +488,7 @@ def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks,
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
'SDD': False, 'DSD': True, 'DDS': False}
# output
CS0 = BS0
@@ -599,8 +607,8 @@ def make_lut(self, dtype, device):
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]

@@ -610,32 +618,98 @@ def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
# look-up table cache
self.lut_cache = dict()
# attributes
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout

# Kernel assumes that all tensors are 4 dimensional
@staticmethod
def _pad_shape(x):
# Add extra batch dimensions if needed
for i in range(4 - x.ndim):
x = x.unsqueeze(0)
layout_dim = layout.ndim
assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"

if not mode == 'sdd':
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)

return x
# Inner dim of the dense input should be equal to the inner dim of the sparse input
self.dense_inner_size = layout.shape[sparse_inner] * block
# Expected shape for sparse inputs
self.sparse_shape = (layout.sum().item(), block, block)

# Support using the same layout across attention heads etc.
if layout_dim == 2:
layout = layout.unsqueeze(0)

layout = layout.long() # Above code assumes the layout tensor is an integral type
self.layout = layout
self.spdims = layout.shape
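
For orientation, here is a small worked example (not taken from the commit) of what the bookkeeping above computes, under an assumed 'dsd' configuration with a hypothetical all-ones layout:

import torch

# Assumed setup: 2 heads, 4 block-rows, 8 block-columns, block size 16,
# 'dsd' mode (A is the sparse operand, B is dense), no transposes.
layout = torch.ones(2, 4, 8, dtype=torch.long)
block, trans_a, trans_b = 16, False, False

# Same bookkeeping as above, specialized to 'dsd'.
trans_dense, trans_sparse, sparse_inner = trans_b, trans_a, -1
dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)

print(dense_inner_dim)                       # -2: B must be reduced along its rows
print(layout.shape[sparse_inner] * block)    # 128: required size of B at dim -2
print((layout.sum().item(), block, block))   # (64, 16, 16): expected trailing shape of sparse A
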

def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a)
b = matmul._pad_shape(b)

# If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
# and potential illegal memory accesses
original_dims = max(a.ndim, b.ndim)
a, b = self._validate_inputs(a, b)

# execute
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
)
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
dims_to_trim = c.ndim - original_dims
for _ in range(dims_to_trim):
c = c.squeeze(0)

return c

def _validate_inputs(self, a, b):
if a.device != b.device:
raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
f"and {b.device} for tensor B")
if not a.is_cuda:
raise ValueError("Only GPU devices are supported for now")

# When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
if torch.is_autocast_enabled():
a, b = a.half(), b.half()
elif a.dtype != b.dtype:
raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")

mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
if mode != 'sdd':
# One input is sparse
dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
dense_inner = dense.shape[self.dense_inner_dim]
if dense_inner != self.dense_inner_size:
raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")

if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
f"{sparse_name}, got {sparse.shape}")

def add_extra_dims(x):
# Add extra leading singleton dimensions if needed
dims_needed = 4 - x.ndim
if dims_needed > 0:
singletons = [1] * dims_needed
x = x.view(*singletons, *x.shape)
elif dims_needed < 0:
raise ValueError("Tensors with more than 4 dimensions are not currently supported")

return x

# Pad shapes with leading singleton dimensions
a = add_extra_dims(a)
b = add_extra_dims(b)

return a, b
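
As a side note (not part of the diff), the padding done by add_extra_dims above and the squeeze loop in __call__ form a simple round trip. A minimal sketch, with a plain 3-D tensor standing in for the inputs and the 4-D kernel output:

import torch

def add_extra_dims(x):
    # Pad to 4-D with leading singleton dimensions, as above.
    dims_needed = 4 - x.ndim
    if dims_needed > 0:
        x = x.view(*([1] * dims_needed), *x.shape)
    return x

a = torch.empty(3, 64, 64)        # 3-D input
original_dims = a.ndim
a4 = add_extra_dims(a)            # shape (1, 3, 64, 64): what the kernel actually sees

c = torch.empty(1, 3, 64, 128)    # stand-in for the kernel's 4-D output
for _ in range(c.ndim - original_dims):
    c = c.squeeze(0)              # drop the leading singletons that were added
print(c.shape)                    # torch.Size([3, 64, 128])
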

def _dim_to_name(x):
# assert x in (-1, -2)
return "last" if x == -1 else "second to last"
