From d809020fc876da70f2e2546abcef4b2989ad17b5 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Fri, 3 Mar 2023 10:08:20 +0000 Subject: [PATCH] Triton kernel for bsr @ dense (#94823) Pull Request resolved: https://github.com/pytorch/pytorch/pull/94823 Approved by: https://github.com/cpuhrsch, https://github.com/malfet --- mypy.ini | 3 + test/test_sparse_csr.py | 56 +++- torch/sparse/_triton_ops.py | 608 ++++++++++++++++++++++++++++++++++++ 3 files changed, 666 insertions(+), 1 deletion(-) create mode 100644 torch/sparse/_triton_ops.py diff --git a/mypy.ini b/mypy.ini index 380f432c4805e4..0636600a98a52c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -185,6 +185,9 @@ ignore_errors = True # Third party dependencies that don't have types. # +[mypy-triton.*] +ignore_missing_imports = True + [mypy-tensorflow.*] ignore_missing_imports = True diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 3e38ce6f7bd0d1..41a12c0932e105 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC from torch.testing._internal.common_utils import \ (TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize, - subtest, skipIfTorchDynamo) + subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU) from torch.testing._internal.common_device_type import \ (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric, precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan) @@ -1462,6 +1462,60 @@ def run_test_block_addmm_addmv(self, self.assertEqual(actual, out) self.assertEqual(actual, expected) + @parametrize("block_size", [16, 32, 64]) + @parametrize("index_dtype", [torch.int32, torch.int64]) + @onlyCUDA + @skipIfRocm + @dtypes(torch.half, torch.bfloat16) + @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else []) + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") + def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): + from functools import partial + + from torch._inductor.utils import has_triton + from torch.sparse._triton_ops import bsr_dense_mm + + if not has_triton(): + self.skipTest("Triton is not available.") + + # Note that each value in a non-zero block is in range block_size * [low^2, high^2). + tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5) + + # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. + batches = [(), (2,)] + size = [128, 256, 0] + + # Whether to make inputs orthogonal so that the product is zero + make_orthogonal = [True, False] + + for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal): + bsr = tensor(bs + (m, k)) + # NOTE: do not get confused, it will be transposed + dense = tensor(bd + (n, k)) + + if is_ortho: + bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1) + dense = torch.cat((torch.zeros_like(dense), dense), dim=-1) + + bsr = bsr.to_sparse_bsr(block_size) + + res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1)) + res_dense = bsr.to_dense() @ dense.transpose(-2, -1) + self.assertEqual(res_tri, res_dense) + + # check whether bsr_dense_mm handles different grid sizes + # None means max possible grid size which is CUDA-dependent. 
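+            # max_grid caps the launch grid per axis as (block rows, block cols,
+            # batches); whenever a dimension of the full problem exceeds its cap,
+            # bsr_dense_mm slices the inputs and launches the kernel several times,
+            # so every cap combination below must reproduce the dense reference.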
+ grid_size = (None, 2, 4) + grid_gen = itertools.product(grid_size, repeat=3) + for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen): + res_tri = bsr_dense_mm( + bsr, + dense.transpose(-2, -1), + max_grid=grid, + is_sparse_rowspace_mode=is_sparse_rowspace + ) + self.assertEqual(res_tri, res_dense) + # TODO: block_size 1 is broken @parametrize("block_size", [2, 3]) @parametrize("index_dtype", [torch.int32, torch.int64]) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py new file mode 100644 index 00000000000000..d7b34f34905d74 --- /dev/null +++ b/torch/sparse/_triton_ops.py @@ -0,0 +1,608 @@ +import torch +from torch._inductor.cuda_properties import get_device_capability + +def _has_triton(): + if not torch.cuda.is_available(): + return False + try: + import triton + + return triton is not None and get_device_capability() >= (7, 0) + except ImportError: + return False + +def compressed_indices_to_plain_indices(cidx, pidx): + nnz = pidx.shape[-1] + cdim = cidx.shape[-1] - 1 + batch_numel = cidx.shape[0] + batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[ + :, None + ] + + cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset + cidx_linear = torch.empty( + (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device + ) + cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1) + cidx_linear[-1] = nnz * batch_numel + + idx_linear = torch._convert_indices_from_csr_to_coo( + cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32) + ).select(0, 0) + + return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset) + + +def slicer(dim, slice_range, *tensors): + for t in tensors: + slices = [slice(None)] * t.dim() + slices[dim] = slice_range + yield t[slices] + +if _has_triton(): + import triton + import triton.language as tl + from typing import Optional, Tuple + + @triton.jit + def _bsr_strided_dense_rowspace_kernel( + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + GROUP_SIZE_ROW: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. 
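+        # crow_indices follows the CSR convention: the column indices and value
+        # blocks of block row i live in the slice [crow_indices[i], crow_indices[i + 1])
+        # of the nnz dimension, so the difference of two consecutive entries is
+        # the number of non-zero blocks in this row.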
+ row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * col_block_arange[:, None] + + dense_col_block_stride * row_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * row_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32) + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx) + + # do block mm + output_acc_block += tl.dot(values_block, dense_block) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + + + @triton.jit + def _bsr_strided_sparse_rowspace_kernel( + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + batch_idx_ptr, + row_idx_ptr, + nnz_per_row_ptr, + nnz_per_row_cumsum_ptr, + col_indices_ptr, + col_indices_stride, + # values prologue + values_ptr, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + GROUP_SIZE_ROW: tl.constexpr, + ): + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + batch_idx = tl.load(batch_idx_ptr + row_block_pid) + row_idx = tl.load(row_idx_ptr + row_block_pid) + row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid) + row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid) + row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. 
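+        # Unlike the dense-rowspace kernel there is no batch stride here: in
+        # sparse-rowspace mode the host code flattens the batch and nnz dimensions
+        # of values (and of col_indices) into a single leading dimension, so
+        # row_idx_nnz_offset already points at the right batch.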
+ values_block_ptrs = ( + values_ptr + + values_nnz_stride * row_idx_nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_idx + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * col_block_arange[:, None] + + dense_col_block_stride * row_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_idx + + output_tiled_row_stride * row_idx + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * row_block_arange[None, :] + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32) + col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride + for _ in range(row_idx_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx) + + # do block mm + output_acc_block += tl.dot(values_block, dense_block) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + + + def _run_sparse_rowspace_kernel( + blocksize, values, crow_indices, col_indices, dense, output, max_grid + ): + # Compute a vector of non-zero elements numbers per each row. + # We want to ultimately iterate over non-zero rows. + nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1] + + # Compute indices of non-zero counts. + # batch_idx maps to a broadcasted batch index, while + # row_idx tracks non-zero rows of the sparse argument + # and rows of the output that get modified. + batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True) + + # Compress the vector of counts to hold only non-zero values. + nnz_per_row = nnz_per_row[batch_idx, row_idx] + # Compute cumulative counts which along with nnz_per_row + # are used to compute offsets into nnz values. 
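+        # Illustrative example: nnz_per_row = [2, 1, 3] gives
+        # nnz_per_row_cumsum = [2, 3, 6]; the kernel then recovers each row's
+        # starting offset into the nnz dimension as cumsum - nnz, i.e. [0, 2, 3].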
+ nnz_per_row_cumsum = nnz_per_row.cumsum(-1) + + n_nnz_block_rows = row_idx.size(-1) + n_block_cols = dense.size(-3) + max_n_nnz_block_rows, max_n_block_cols = max_grid[:2] + + for c_start in range(0, n_block_cols, max_n_block_cols): + c_dense, c_output = slicer( + -3, slice(c_start, c_start + max_n_block_cols), dense, output + ) + c_grid = min(n_block_cols - c_start, max_n_block_cols) + + for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows): + r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer( + 0, + slice(r_start, r_start + max_n_nnz_block_rows), + batch_idx, + row_idx, + nnz_per_row, + nnz_per_row_cumsum, + ) + r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows) + + _bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)]( + *blocksize, + r_batch_idx, + r_row_idx, + r_nnz_per_row, + r_nnz_per_row_cumsum, + col_indices, + *col_indices.stride(), + values, + *values.stride(), + c_dense, + *c_dense.stride(), + c_output, + *c_output.stride(), + GROUP_SIZE_ROW=4, + num_stages=4, + num_warps=4, + ) + + + def _run_dense_rowspace_kernel( + blocksize, values, crow_indices, col_indices, dense, output, max_grid + ): + # Launch kernel + n_batches = dense.size(0) + n_block_rows = crow_indices.size(-1) - 1 + n_block_cols = dense.size(-3) + max_n_block_rows, max_n_block_cols, max_n_batches = max_grid + + for b_start in range(0, n_batches, max_n_batches): + b_v, b_crow, b_col, b_d, b_o = slicer( + 0, + slice(b_start, b_start + max_n_batches), + values, + crow_indices, + col_indices, + dense, + output, + ) + b_grid = min(n_batches - b_start, max_n_batches) + + for c_start in range(0, n_block_cols, max_n_block_cols): + bc_d, bc_o = slicer( + -3, slice(c_start, c_start + max_n_block_cols), b_d, b_o + ) + c_grid = min(n_block_cols - c_start, max_n_block_cols) + + for r_start in range(0, n_block_rows, max_n_block_rows): + r_slice = slice(r_start, r_start + max_n_block_rows) + br_crow = next(slicer(-1, r_slice, b_crow)) + brc_o = next(slicer(-4, r_slice, bc_o)) + r_grid = min(n_block_rows - r_start, max_n_block_rows) + + _bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)]( + *blocksize, + b_v, + *b_v.stride(), + br_crow, + *br_crow.stride(), + b_col, + *b_col.stride(), + bc_d, + *bc_d.stride(), + brc_o, + *brc_o.stride(), + GROUP_SIZE_ROW=4, + num_stages=4, + num_warps=4, + ) + + + def bsr_dense_mm( + bsr: torch.Tensor, + dense: torch.Tensor, + *, + skip_checks: bool = False, + is_sparse_rowspace_mode: Optional[bool] = None, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + out: Optional[torch.Tensor] = None, + ): + m, kl = bsr.shape[-2:] + kr, n = dense.shape[-2:] + + def check(cond, msg): + if not cond: + raise ValueError(msg) + + if not skip_checks: + check( + bsr.layout == torch.sparse_bsr, + "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.", + ) + + check( + bsr.device == dense.device and bsr.device.type == "cuda", + "bsr_dense_mm(): all inputs are expected to be on the same GPU device.", + ) + + check( + bsr.dtype == dense.dtype + and bsr.dtype in (torch.half, torch.bfloat16, torch.float), + "bsr_dense_mm(): all inputs are expected to be of the same dtype " + "and one of (half, bfloat16, float32), " + f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.", + ) + + check( + bsr.dim() >= 2 and dense.dim() >= 2, + "bsr_dense_mm(): all inputs are expected to be at least 2D, " + f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.", + ) + + check( + kl == kr, + 
"bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, " + f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.", + ) + + row_block = bsr.values().shape[-2] + check( + not n % row_block, + f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by " + f"blocksize[0] == {row_block}.", + ) + + # Required to undo the fake batch dimension insertion. + original_batch_dims_broadcasted = torch.broadcast_shapes( + bsr.shape[:-2], dense.shape[:-2] + ) + + if out is not None and not skip_checks: + expected_out_shape = original_batch_dims_broadcasted + (m, n) + check( + out.shape == expected_out_shape, + "bsr_dense_mm(): `out` argument has wrong shape, " + f"expected {expected_out_shape}, but got {out.shape}.", + ) + check( + out.is_contiguous() or out.transpose(-2, -1).is_contiguous(), + "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, " + "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) " + "should be True.", + ) + + # Short circuit if lhs is zero + if bsr._nnz() == 0: + return dense.new_zeros(original_batch_dims_broadcasted + (m, n)) + + # TODO: insert switch + if is_sparse_rowspace_mode is None: + is_sparse_rowspace_mode = False + + # Introduce fake batch dimension if not present for convenience. + def unsqueeze_batch_dim(t, n_non_batch_dims): + if t.dim() > n_non_batch_dims: + return t + else: + return t.unsqueeze(0) + + def make_triton_contiguous(t): + # Triton does not distinguish between row- and col-majorness + # and will be fast as long as there is a contiguous dimension. + if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()): + return t.contiguous() + else: + return t + + crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1) + col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1) + values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3)) + dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2)) + nnz = values.shape[-3] + blocksize = values.shape[-2:] + + # Compute broadcasted batch dimension + bsr_batch_dims = values.shape[:-3] + dense_batch_dims = dense.shape[:-2] + batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims) + + # Allocate out + if out is None: + out = dense.new_zeros(batch_dims_broadcasted + (m, n)) + + # Broadcast batch dimensions and squash + def batch_broadcast_and_squash(t, batch_dims, invariant_dims): + return t.broadcast_to(batch_dims + invariant_dims).flatten( + 0, len(batch_dims) - 1 + ) + + crow_indices = batch_broadcast_and_squash( + crow_indices, batch_dims_broadcasted, (-1,) + ) + + if is_sparse_rowspace_mode: + # Flatten batch dimension with nnz dimension + # as required by the sparse rowspace kernel. 
+ col_indices = batch_broadcast_and_squash( + col_indices, batch_dims_broadcasted + (-1,), () + ) + values = batch_broadcast_and_squash( + values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:] + ) + else: + col_indices = batch_broadcast_and_squash( + col_indices, batch_dims_broadcasted, (-1,) + ) + values = batch_broadcast_and_squash( + values, batch_dims_broadcasted, values.shape[-3:] + ) + + dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:]) + + # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view + out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:]) + + # NOTE: this function will ALWAYS create a view + def tile_to_blocksize(t, blocksize): + *rest, m, n = t.shape + new_shape = rest + [ + m // blocksize[0], + blocksize[0], + n // blocksize[1], + blocksize[1], + ] + return t.reshape(new_shape).transpose(-3, -2) + + # "Blockify" the row dimension of dense with blocksize[1] + # since dense is on the rhs of matmul + dense = tile_to_blocksize(dense, blocksize[::-1]) + # "Blockify" the row dimension of out with blocksize[0] + # which is inherited from the bsr input. + # NOTE: tile_to_blocksize will create a view. + # NOTE: out.blocksize[-1] == dense.blocksize[-1], + # so it could be any value in [1, dense.shape[-1]). + # We need to probably use the largest possible blocksize + # so that it fits into SRAM. + out = tile_to_blocksize(out, (blocksize[0], blocksize[0])) + + # Launch kernel + if is_sparse_rowspace_mode: + kernel = _run_sparse_rowspace_kernel + else: + kernel = _run_dense_rowspace_kernel + + # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1) + cuda_max_grid = (2147483647, 65535, 65535) + if max_grid is None: + max_grid = cuda_max_grid + else: + + def valid_grid_dim(g, mg): + if g is None: + return mg + else: + # grid must be at least 1 and no greater than mg + return max(1, min(g, mg)) + + max_grid = tuple( + valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid) + ) # type: ignore[assignment] + + kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid) + + # Block dims need to rejoin with the corresponding block dimensions + # prior to reshape so that blocks do not end up being transposed. + # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point + return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n)) # type: ignore[union-attr] +else: + bsr_dense_mm = None # type: ignore[assignment] + + +if __name__ == "__main__": + from torch._inductor.utils import has_triton + + if has_triton(): + torch.manual_seed(13) + dtype = torch.float32 + p = 0.5 + mask_size = (8, 8) + block_size = (64, 64) + size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1]) + + n_exp = 512 + diff = torch.ones(n_exp, device="cuda", dtype=torch.float32) + for i in range(n_exp): + mask = torch.rand(*mask_size, device="cuda") < p + x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10 + x = ( + (mask[:, :, None, None] * x) + .transpose(-3, -2) + .reshape(*size) + .to_sparse_bsr(*block_size) + ) + y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10 + res_dense = x.to_dense() @ y + res = bsr_dense_mm(x, y) + diff[i] = (res - res_dense).abs().max() + print(f"mean: {diff.mean()}, std: {diff.std()}") + print(f"max diff: {diff.max()}")
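+        # Minimal usage sketch (illustrative only; assumes a CUDA device with
+        # Triton available and shapes divisible by the block size):
+        #   lhs = torch.randn(512, 512, device="cuda").to_sparse_bsr(64)
+        #   rhs = torch.randn(512, 256, device="cuda")
+        #   out = bsr_dense_mm(lhs, rhs)  # matches lhs.to_dense() @ rhs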