
Commit a19ee39

Preliminary work for block sparse emulation code.

1 parent 0985083 · commit a19ee39

5 files changed: +93, -21 lines changed

MANIFEST.in

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 include README.md
-graft pytorch_block_sparse/cutlass/*.h
-graft pytorch_block_sparse/native/*.h
-graft pytorch_block_sparse/tests/.py
+graft pytorch_block_sparse/cutlass/
+graft pytorch_block_sparse/native/
+graft pytorch_block_sparse/tests/
 global-exclude *.py[cod]
 global-exclude *~
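A note on this fix (context, not part of the diff): in setuptools' MANIFEST.in syntax, graft takes a directory pattern and includes everything beneath it, so the old graft .../cutlass/*.h entries were treating it as a file glob and likely matched nothing. Grafting the directories is the straightforward fix; if per-extension filtering were wanted instead, the usual command is recursive-include, roughly as follows (the extension list is inferred from the sources named in setup.py below):

    recursive-include pytorch_block_sparse/cutlass *.h
    recursive-include pytorch_block_sparse/native *.h *.cpp *.cu
    recursive-include pytorch_block_sparse/tests *.py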

pytorch_block_sparse/block_sparse.py

Lines changed: 26 additions & 10 deletions
@@ -4,19 +4,21 @@
 import warnings
 import math
 
-class BlockSparseMatrix(torch.nn.Module):
+class BlockSparseMatrixBase(torch.nn.Module):
     # cols is a list of nonzero block column indexes (int32)
     # row_start is a index into cols (int32)
     # Data is (len(cols), block_shape, block_shape)
-    def __init__(self, shape, block_mask, data, block_shape=(16, 16)):
-        super(BlockSparseMatrix, self).__init__()
+    def __init__(self, shape, block_mask, data, block_shape=(32, 32)):
+        super(BlockSparseMatrixBase, self).__init__()
         self.int_type = torch.int32
 
-        if len(shape) != 2 or shape[0] % 16 != 0 or shape[1] % 16 != 0:
-            raise Exception("shape should be a tuple of 2 multiples of 16")
+        if len(shape) != 2:
+            raise Exception("shape should be a tuple of 2 ints")
+
         self.shape = torch.Size(shape)
-        if len(block_shape) != 2 or block_shape[0] % 16 != 0 or block_shape[1] % 16 != 0:
-            raise Exception("block_shape should be a tuple of 2 multiples of 16")
+        if len(block_shape) != 2:
+            raise Exception("block_shape should be a tuple of 2 ints")
+
         self.block_shape = tuple(block_shape)
 
         self.data = torch.nn.Parameter(data)
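The relaxed checks here are not lost: the multiple-of-32 requirements move down into the new BlockSparseMatrix subclass added at the end of this file (last hunk below), so the base class can also back a CPU emulator that accepts arbitrary shapes.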
@@ -74,15 +76,15 @@ def build_indices_(self, block_mask, block_ptr, nnzt, transpose_indices):
         # Reorganize the indexes with transposed ordering
         block_indices = block_indices.reshape(X, Y).t().reshape(X * Y)
         # Only keeps the non zero, and substract 1 to find back the right block index
-        block_ptr = block_indices[block_indices.nonzero()] - 1
+        block_ptr = block_indices[torch.nonzero(block_indices, as_tuple=False)] - 1
         # Remove spurious dimension
         block_ptr = block_ptr.squeeze(-1)
 
         X, Y = Y, X
 
         rows = cols
 
-        nnztt = block_mask.t().nonzero()
+        nnztt = torch.nonzero(block_mask.t(), as_tuple=False)
         cols = nnztt[:,1]
 
         row_start_ends = torch.zeros((X + 1,), dtype=torch.long, device = device)
@@ -100,7 +102,7 @@ def build_indices(self, block_mask, block_ptr = None):
         # assume that the content of block_ptr is just from 0..n_blocks
         # Used to recycle blocks
 
-        nnz = block_mask.nonzero()
+        nnz = torch.nonzero(block_mask, as_tuple=False)
 
         if block_ptr == None:
             block_ptr = torch.arange(0, nnz.shape[0], device=block_mask.device)
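The three nonzero() edits in this file are the same mechanical change: recent PyTorch releases (around 1.5) warn when nonzero() is called without the as_tuple argument, and torch.nonzero(t, as_tuple=False) keeps the old behavior, an (nnz, ndim) index matrix, without the deprecation warning. A minimal illustration:

    import torch

    mask = torch.tensor([[True, False],
                         [False, True]])

    # idx = mask.nonzero()  # old call form, warns on recent PyTorch
    idx = torch.nonzero(mask, as_tuple=False)
    print(idx)  # tensor([[0, 0], [1, 1]]), shape (nnz, ndim) = (2, 2)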
@@ -510,3 +512,17 @@ def matmul_with_output_sparse_support(self, dense_a, dense_b, overwrite_data = F
         ret = self.matmul_with_output_sparse_support_(rewritten_a, rewritten_b, overwrite_data)
 
         return ret
+
+
+class BlockSparseMatrix(BlockSparseMatrixBase):
+    # cols is a list of nonzero block column indexes (int32)
+    # row_start is a index into cols (int32)
+    # Data is (len(cols), block_shape, block_shape)
+    def __init__(self, shape, block_mask, data, block_shape=(32, 32)):
+        if len(shape) != 2 or shape[0] % 32 != 0 or shape[1] % 32 != 0:
+            raise Exception("shape should be a tuple of 2 multiples of 32")
+
+        if len(block_shape) != 2 or block_shape[0] % 32 != 0 or block_shape[1] % 32 != 0:
+            raise Exception("block_shape should be a tuple of 2 multiples of 32")
+
+        super(BlockSparseMatrix, self).__init__(shape, block_mask, data, block_shape)
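With the split in place, the hard sizing contract lives only in the CUDA-backed subclass. A hypothetical sketch of the resulting behavior (keyword names mirror the help_randn test helper below, so treat them as inferred, not verified):

    import torch
    from pytorch_block_sparse.block_sparse import BlockSparseMatrix

    # The CUDA-backed class keeps the constraint: both dims multiples of 32.
    bsm = BlockSparseMatrix.randn((64, 96), 3, block_shape=(32, 32), device="cuda")

    # This now raises inside BlockSparseMatrix.__init__, while the relaxed
    # BlockSparseMatrixBase would accept it (which is what the emulator needs):
    # BlockSparseMatrix.randn((3, 5), 2, block_shape=(1, 1), device="cpu")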
pytorch_block_sparse/block_sparse_emulate.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import torch
+import torch.nn
+from . import block_sparse
+
+class BlockSparseMatrixEmulator(block_sparse.BlockSparseMatrixBase):
+    # cols is a list of nonzero block column indexes (int32)
+    # row_start is a index into cols (int32)
+    # Data is (len(cols), block_shape, block_shape)
+    def __init__(self, shape, block_mask, data, block_shape):
+        super(BlockSparseMatrixEmulator, self).__init__(shape, block_mask, data, block_shape)
+
+    def rebuild(self, block_mask, block_ptr=None):
+        super().rebuild(block_mask, block_ptr)
+        self._dense = self.to_dense()
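The emulator's only override so far keeps a dense mirror (self._dense) in sync each time the sparsity pattern is rebuilt, using the base class's to_dense(). A hedged usage sketch, assuming randn is inherited from BlockSparseMatrixBase and rebuild runs during construction (as the new test below suggests):

    from pytorch_block_sparse.block_sparse_emulate import BlockSparseMatrixEmulator

    # Works on CPU, and with shapes the CUDA-backed class would reject:
    emul = BlockSparseMatrixEmulator.randn((3, 5), 2, block_shape=(1, 1), device="cpu")
    dense = emul.to_dense()  # full (3, 5) tensor, zero outside the 2 nonzero blocks
    assert dense.shape == (3, 5)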
New test file

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from unittest import TestCase
+import torch
+import unittest
+import torch.optim as optim
+from pytorch_block_sparse.block_sparse import BlockSparseMatrix
+from pytorch_block_sparse.block_sparse_emulate import BlockSparseMatrixEmulator
+from pytorch_block_sparse.block_sparse_linear import PseudoBlockSparseLinear
+
+class TestFun(TestCase):
+    def help_contruct(self, shape, block_mask, data, block_shape=(16, 16)):
+        try:
+            real = BlockSparseMatrix(shape, block_mask, data, block_shape)
+        except:
+            real = None
+        emul = BlockSparseMatrixEmulator(shape, block_mask, data, block_shape)
+
+        return real, emul
+
+    def help_randn(cls, shape, n_blocks, blocks=None, block_shape=(32, 32), device="cuda", positive=False):
+        try:
+            real = BlockSparseMatrix.randn(shape, n_blocks, blocks, block_shape, device=device, positive=positive)
+        except:
+            real = None
+        emul = BlockSparseMatrixEmulator.randn(shape, n_blocks, blocks, block_shape, device=device, positive=positive)
+
+        return real, emul
+
+    def test0(self):
+        d = dict
+        test_sizes = [d(nb=2, s=(3,5), bs=(1,1))]
+        map = d(nb= "n_blocks", s="shape", bs="block_shape")
+
+        for ts in test_sizes:
+            ts = {map[k]:v for k,v in ts.items()}
+            self.help_randn(**ts, device="cpu")
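The try/except wrappers let each helper return a (real, emul) pair in which real degrades to None whenever CUDA-backed construction is impossible: no GPU, or sizes like test0's shape (3, 5) with block_shape (1, 1) that fail the multiple-of-32 checks. Nothing compares the two yet; a hypothetical comparison helper (not in this commit) shows where this presumably heads:

    def assertSameDense(self, real, emul, atol=1e-6):
        # Hypothetical: compare the two backends only when both could be built.
        if real is not None:
            self.assertTrue(torch.allclose(real.to_dense(), emul.to_dense(), atol=atol))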

setup.py

Lines changed: 15 additions & 8 deletions
@@ -6,6 +6,20 @@
 
 version = "0.1.2"
 
+ext_modules = []
+
+import torch
+if torch.cuda.is_available():
+    ext = CUDAExtension('block_sparse_native',
+                        ['pytorch_block_sparse/native/block_sparse_native.cpp',
+                         'pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu',
+                         'pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu'],
+                        extra_compile_args=['-I', '%s/pytorch_block_sparse' % rootdir]
+                        )
+    ext_modules = [ext]
+else:
+    print("WARNING: torch cuda seems unavailable, emulated features only will be available.")
+
 setup(name='pytorch_block_sparse',
       version=version,
       description='PyTorch extension for fast block sparse matrices computation, drop in replacement for torch.nn.Linear.',
@@ -25,14 +39,7 @@
       install_requires=[],
       include_package_data=True,
       zip_safe=False,
-      ext_modules=[
-          CUDAExtension('block_sparse_native',
-                        ['pytorch_block_sparse/native/block_sparse_native.cpp',
-                         'pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu',
-                         'pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu'],
-                        extra_compile_args=['-I', '%s/pytorch_block_sparse' % rootdir]
-                        ),
-      ],
+      ext_modules=ext_modules,
       cmdclass={
           'build_ext': BuildExtension
       }
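One build-time caveat worth noting (an observation, not part of the commit): torch.cuda.is_available() needs a working driver at build time, so building on a machine that has the CUDA toolkit but no visible GPU (a common docker scenario) would also fall back to the emulator, and the new top-level import torch makes torch a hard build-time dependency. A common alternative guard keys off the toolkit instead; a hypothetical sketch:

    import os
    from torch.utils.cpp_extension import CUDA_HOME

    # Hypothetical alternative: build the native extension whenever the CUDA
    # toolkit is present, even if no GPU is visible at build time.
    build_native = CUDA_HOME is not None and os.path.exists(CUDA_HOME)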
