From d7f87929fa112c9efaecdd36066eb88a63688773 Mon Sep 17 00:00:00 2001
From: Nora Belrose <39116809+norabelrose@users.noreply.github.com>
Date: Mon, 26 Apr 2021 12:25:16 -0700
Subject: [PATCH] [PYTHON] Added shape & device checks for inputs to sparse
 matmul op (#93)

---
 python/triton/ops/blocksparse/matmul.py | 146 ++++++++++++++++++------
 1 file changed, 110 insertions(+), 36 deletions(-)

diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index 926239b40e62..c1787eaa867b 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -197,7 +197,7 @@ class _matmul(torch.autograd.Function):
     # performs load-balancing to achieve more smaller reductions
     # between `seg_size` elements
     @staticmethod
-    def load_balance(sizes, block):
+    def load_balance(sizes):
         # segment size
         # heuristics taken from OpenAI blocksparse code
         # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
@@ -273,33 +273,41 @@ def make_sdd_lut(layout, block, device):
 
     @staticmethod
     def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
-
+        # (A * B)^T = (B^T * A^T)
         if trans_c:
             a, b = b, a
             trans_a, trans_b = not trans_b, not trans_a
-        AS0 = a.size(0)
-        AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
+
+        # Shape check
+        a_dim = -2 if trans_a else -1
+        b_dim = -1 if trans_b else -2
+        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
+        if a_inner != b_inner:
+            raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
+                             f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
+        if a_inner % 16 != 0:
+            raise ValueError('Reduction size for SDD must be a multiple of 16')
+
+        batch_size = a.size(0)
+        a_outer = a.size(3 if trans_a else 2)
         dtype = a.dtype
         device = a.device
-        is_16_multiple = AS3 % 16 == 0
-        if not is_16_multiple:
-            raise ValueError('Reduction size for SDD must be a multiple of 16')
+
         # create kernel
         total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
-        c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
+        c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
         for lut, width, pack in zip(luts, widths, packs):
             num_lock = 1
-            meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
+            meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
                     'SDD': True, 'DSD': False, 'DDS': False}
             # create output
-            locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
+            locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
             # maximum grid size is 65535
             # so operation might be decomposed into multiple
             # kernel calls
             max_width = 49152
             for off_width in range(0, width, max_width):
-                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
+                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
                 _kernel[grid](
                     a,
                     b,
@@ -316,9 +324,9 @@ def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks,
                     c.stride(0),
                     c.stride(2),
                     c.stride(3),
-                    AS2,
-                    AS2,
-                    AS3,
+                    a_outer,
+                    a_outer,
+                    a_inner,
                     off_width,
                     lut,
                     locks,
@@ -353,7 +361,7 @@ def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
                 sizes = torch.sum(layout[z, :, :], 1)
             else:
                 sizes = torch.sum(layout[z, :, :], 0)
-            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
+            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
             z_depth = z * torch.ones_like(z_segments)
             z_lockid[z_lockid > 0] += current_maxid
             current_maxid = z_lockid.max()
@@ -433,7 +441,7 @@ def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks,
         BS2 = block * spdims[1 if trans_b else 2]
         dtype = a.dtype
         # kernel
-        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
+        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                 'SDD': False, 'DSD': False, 'DDS': True}
         # output
         CS0 = AS0
@@ -480,7 +488,7 @@ def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks,
         BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         # kernel
-        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
+        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                 'SDD': False, 'DSD': True, 'DDS': False}
         # output
         CS0 = BS0
@@ -599,8 +607,8 @@ def make_lut(self, dtype, device):
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
         elif self.mode == 'dds':
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
-        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
-                               da_lut, da_num_locks, da_width, da_packs,\
+        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
+                               da_lut, da_num_locks, da_width, da_packs,
                                db_lut, db_num_locks, db_width, db_packs)
         return self.lut_cache[key]
 
@@ -610,32 +618,98 @@ def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
         # look-up table cache
         self.lut_cache = dict()
         # attributes
+        self.block = block
+        self.mode = mode
         self.trans_a = trans_a
         self.trans_b = trans_b
-        self.mode = mode
-        self.spdims = layout.shape
-        self.block = block
-        self.layout = layout
 
-    # Kernel assumes that all tensors are 4 dimensional
-    @staticmethod
-    def _pad_shape(x):
-        # Add extra batch dimensions if needed
-        for i in range(4 - x.ndim):
-            x = x.unsqueeze(0)
+        layout_dim = layout.ndim
+        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
+
+        if not mode == 'sdd':
+            # Dims to be reduced on the 'inside' of the matmul, either -1 or -2
+            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
+            self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
+            sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
 
-        return x
+            # Inner dim of the dense input should be equal to the inner dim of the sparse input
+            self.dense_inner_size = layout.shape[sparse_inner] * block
+            # Expected shape for sparse inputs
+            self.sparse_shape = (layout.sum().item(), block, block)
+
+        # Support using the same layout across attention heads etc.
+        if layout_dim == 2:
+            layout = layout.unsqueeze(0)
+
+        layout = layout.long()  # Above code assumes the layout tensor is an integral type
+        self.layout = layout
+        self.spdims = layout.shape
 
     def __call__(self, a, b):
         c_lut, c_num_locks, c_width, c_packs,\
         da_lut, da_num_locks, da_width, da_packs,\
         db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
-        # pad shapes with ones
-        a = matmul._pad_shape(a)
-        b = matmul._pad_shape(b)
+
+        # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
+        # and potential illegal memory accesses
+        original_dims = max(a.ndim, b.ndim)
+        a, b = self._validate_inputs(a, b)
+
         # execute
         c = _matmul.apply(
-            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
-            da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
+            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
+            c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
         )
+        # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
+        dims_to_trim = c.ndim - original_dims
+        for _ in range(dims_to_trim):
+            c = c.squeeze(0)
+
         return c
+
+    def _validate_inputs(self, a, b):
+        if a.device != b.device:
+            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
+                             f"and {b.device} for tensor B")
+        if not a.is_cuda:
+            raise ValueError("Only GPU devices are supported for now")
+
+        # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
+        if torch.is_autocast_enabled():
+            a, b = a.half(), b.half()
+        elif a.dtype != b.dtype:
+            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
+
+        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
+        if mode != 'sdd':
+            # One input is sparse
+            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
+            dense_inner = dense.shape[self.dense_inner_dim]
+            if dense_inner != self.dense_inner_size:
+                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
+                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
+
+            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
+                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
+                                 f"{sparse_name}, got {sparse.shape}")
+
+        def add_extra_dims(x):
+            # Add extra leading singleton dimensions if needed
+            dims_needed = 4 - x.ndim
+            if dims_needed > 0:
+                singletons = [1] * dims_needed
+                x = x.view(*singletons, *x.shape)
+            elif dims_needed < 0:
+                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
+
+            return x
+
+        # Pad shapes with leading singleton dimensions
+        a = add_extra_dims(a)
+        b = add_extra_dims(b)
+
+        return a, b
+
+def _dim_to_name(x):
+    # assert x in (-1, -2)
+    return "last" if x == -1 else "second to last"
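
Illustrative usage (not part of the patch; the layout, sizes, and device below
are assumptions made for the example, and a CUDA device is required). With the
checks above in place, a dense operand with the wrong inner dimension now fails
fast with a descriptive ValueError before any kernel launch, instead of
triggering an illegal memory access:

    import torch
    from triton.ops.blocksparse import matmul

    block = 16
    # 2D layout of 0s and 1s; __init__ now broadcasts it across heads via unsqueeze(0)
    layout = torch.tril(torch.ones(4, 4, dtype=torch.int64))
    dot = matmul(layout, block, 'dsd')  # 'dsd': A is sparse, B is dense

    nnz = layout.sum().item()  # 10 nonzero blocks in the lower triangle
    a = torch.rand(1, nnz, block, block, device='cuda', dtype=torch.float16)
    b = torch.rand(1, 1, 32, 32, device='cuda', dtype=torch.float16)
    # B's second-to-last dim must be layout.shape[-1] * block = 64, so this raises:
    # ValueError: Expected tensor B to have size 64 at dim 2, got 32.
    c = dot(a, b)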