44import warnings
55import math
66
7- class BlockSparseMatrix (torch .nn .Module ):
7+ class BlockSparseMatrixBase (torch .nn .Module ):
88 # cols is a list of nonzero block column indexes (int32)
99 # row_start is a index into cols (int32)
1010 # Data is (len(cols), block_shape, block_shape)
11- def __init__ (self , shape , block_mask , data , block_shape = (16 , 16 )):
12- super (BlockSparseMatrix , self ).__init__ ()
11+ def __init__ (self , shape , block_mask , data , block_shape = (32 , 32 )):
12+ super (BlockSparseMatrixBase , self ).__init__ ()
1313 self .int_type = torch .int32
1414
15- if len (shape ) != 2 or shape [0 ] % 16 != 0 or shape [1 ] % 16 != 0 :
16- raise Exception ("shape should be a tuple of 2 multiples of 16" )
15+ if len (shape ) != 2 :
16+ raise Exception ("shape should be a tuple of 2 ints" )
17+
1718 self .shape = torch .Size (shape )
18- if len (block_shape ) != 2 or block_shape [0 ] % 16 != 0 or block_shape [1 ] % 16 != 0 :
19- raise Exception ("block_shape should be a tuple of 2 multiples of 16" )
19+ if len (block_shape ) != 2 :
20+ raise Exception ("block_shape should be a tuple of 2 ints" )
21+
2022 self .block_shape = tuple (block_shape )
2123
2224 self .data = torch .nn .Parameter (data )
@@ -74,15 +76,15 @@ def build_indices_(self, block_mask, block_ptr, nnzt, transpose_indices):
7476 # Reorganize the indexes with transposed ordering
7577 block_indices = block_indices .reshape (X , Y ).t ().reshape (X * Y )
7678 # Only keeps the non zero, and substract 1 to find back the right block index
77- block_ptr = block_indices [block_indices .nonzero ()] - 1
79+ block_ptr = block_indices [torch .nonzero (block_indices , as_tuple = False )] - 1
7880 # Remove spurious dimension
7981 block_ptr = block_ptr .squeeze (- 1 )
8082
8183 X , Y = Y , X
8284
8385 rows = cols
8486
85- nnztt = block_mask .t (). nonzero ( )
87+ nnztt = torch . nonzero ( block_mask .t (), as_tuple = False )
8688 cols = nnztt [:,1 ]
8789
8890 row_start_ends = torch .zeros ((X + 1 ,), dtype = torch .long , device = device )
@@ -100,7 +102,7 @@ def build_indices(self, block_mask, block_ptr = None):
100102 # assume that the content of block_ptr is just from 0..n_blocks
101103 # Used to recycle blocks
102104
103- nnz = block_mask .nonzero ()
105+ nnz = torch .nonzero (block_mask , as_tuple = False )
104106
105107 if block_ptr == None :
106108 block_ptr = torch .arange (0 , nnz .shape [0 ], device = block_mask .device )
@@ -510,3 +512,17 @@ def matmul_with_output_sparse_support(self, dense_a, dense_b, overwrite_data = F
510512 ret = self .matmul_with_output_sparse_support_ (rewritten_a , rewritten_b , overwrite_data )
511513
512514 return ret
515+
516+
517+ class BlockSparseMatrix (BlockSparseMatrixBase ):
518+ # cols is a list of nonzero block column indexes (int32)
519+ # row_start is a index into cols (int32)
520+ # Data is (len(cols), block_shape, block_shape)
521+ def __init__ (self , shape , block_mask , data , block_shape = (32 , 32 )):
522+ if len (shape ) != 2 or shape [0 ] % 32 != 0 or shape [1 ] % 32 != 0 :
523+ raise Exception ("shape should be a tuple of 2 multiples of 32" )
524+
525+ if len (block_shape ) != 2 or block_shape [0 ] % 32 != 0 or block_shape [1 ] % 32 != 0 :
526+ raise Exception ("block_shape should be a tuple of 2 multiples of 32" )
527+
528+ super (BlockSparseMatrix , self ).__init__ (shape , block_mask , data , block_shape )
0 commit comments