@@ -10,6 +10,7 @@ class BlockSparseMatrix(torch.nn.Module):
     # Data is (len(cols), block_shape, block_shape)
     def __init__(self, shape, block_mask, data, block_shape=(16, 16)):
         super(BlockSparseMatrix, self).__init__()
+        self.int_type = torch.int32

         if block_mask.device != data.device:
             raise Exception("block_mask and data should have same device, got %s and %s" % (block_mask.device, data.device))
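The new `self.int_type` attribute replaces the `torch.int32` literals hard-coded throughout the class (see the later hunks), so the index dtype can be changed in one place. A minimal sketch of the idea; the int64 use case is an assumption, not something this commit enables:

```python
import torch

# Centralizing the index dtype: every index tensor is cast through one
# attribute, so switching to e.g. torch.int64 (hypothetically, for matrices
# whose block counts would overflow int32) becomes a one-line change.
int_type = torch.int32
cols = torch.tensor([[0, 0], [3, 1]]).to(dtype=int_type)
print(cols.dtype)  # torch.int32
```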
@@ -34,7 +35,7 @@ def __init__(self, shape, block_mask, data, block_shape=(16, 16)):


         self.data = torch.nn.Parameter(data)
-        for name in ("block_mask", "cols_a", "row_start_ends_a", "rows_b", "col_start_ends_b", "blocks"):
+        for name in ("cols_a", "row_start_ends_a", "rows_b", "col_start_ends_b", "blocks"):
             self.register_buffer(name, locals()[name])

         self.sanity_check(self.cols_a, self.row_start_ends_a, self.shape, self.block_shape)
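For context, `register_buffer` keeps these index tensors in the module's `state_dict` and moves them along with `.to()`/`.cuda()`, without making them trainable. A minimal sketch of that behavior (the `Toy` module is hypothetical):

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are saved/loaded with the module and follow device moves,
        # but are not returned by parameters() and receive no gradients.
        self.register_buffer("cols_a", torch.zeros(4, 2, dtype=torch.int32))

m = Toy()
print("cols_a" in m.state_dict())  # True
print(list(m.parameters()))        # [] -- buffers are not parameters
```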
@@ -79,15 +80,15 @@ def build_indices_(self, block_mask, nnzt, transpose_indices):
         row_start_ends = torch.zeros((X + 1,), dtype=torch.long, device=device)

         row_start_ends.index_add_(0, rows + 1, torch.ones(size=(cols.shape[0],), dtype=torch.long, device=device))
-        row_start_ends = row_start_ends.cumsum(0).int()
+        row_start_ends = row_start_ends.cumsum(0).to(dtype=self.int_type)

-        cols = torch.stack([cols, block_shuffle], 1).int()
+        cols = torch.stack([cols, block_shuffle], 1).to(dtype=self.int_type)

         return cols, row_start_ends

     def build_indices(self, block_mask):
         nnz = block_mask.nonzero()
-        blocks = nnz.flip(-1).flatten().to(dtype=torch.int32)
+        blocks = nnz.flip(-1).flatten().to(dtype=self.int_type)

         nnzt = nnz.transpose(0, 1)
         cols_a, row_start_ends_a = self.build_indices_(block_mask, nnzt, False)
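The `index_add_` + `cumsum` pair builds CSR-style row pointers: counting blocks per block-row at offset `row + 1` and then accumulating yields each row's start/end offsets. A worked toy example (the values are made up):

```python
import torch

# Nonzero blocks sit in block-rows [0, 0, 2] of a grid with X = 4 block rows.
rows = torch.tensor([0, 0, 2])
X = 4
row_start_ends = torch.zeros((X + 1,), dtype=torch.long)
# Count blocks per row, shifted by one: position r + 1 holds row r's count.
row_start_ends.index_add_(0, rows + 1, torch.ones((rows.shape[0],), dtype=torch.long))
# Now [0, 2, 0, 1, 0]; the cumulative sums are the row start/end offsets.
row_start_ends = row_start_ends.cumsum(0).to(dtype=torch.int32)
print(row_start_ends)  # tensor([0, 2, 2, 3, 3], dtype=torch.int32)
# Row r's blocks occupy slots [row_start_ends[r], row_start_ends[r + 1]).
```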
@@ -129,7 +130,7 @@ def zeros(cls, shape, n_blocks=None, blocks=None, block_shape=(32, 32), devi
         positions = torch.tensor(positions, dtype=torch.int64, device=device).sort()[0]

         block_mask = torch.zeros(X * Y, dtype=torch.bool, device=device)
-        block_mask[positions] = True
+        block_mask[positions] = 1
         block_mask = block_mask.view(X, Y)
         data = torch.zeros((n_blocks * block_shape[0], block_shape[1]), dtype=torch.float, device=device)
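The mask construction scatters the chosen positions into a flat boolean grid; note that assigning `1` to a `torch.bool` tensor stores `True`, so the changed line is behaviorally equivalent to the old one. A toy version with invented sizes and positions:

```python
import torch

X, Y = 2, 3
positions = torch.tensor([4, 1], dtype=torch.int64).sort()[0]
block_mask = torch.zeros(X * Y, dtype=torch.bool)
block_mask[positions] = 1  # 1 is coerced to True in a bool tensor
block_mask = block_mask.view(X, Y)
print(block_mask)
# tensor([[False,  True, False],
#         [False,  True, False]])
```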
@@ -177,16 +178,16 @@ def build_coo_block_index(self):
         device = self.cols_a.device
         # Build a tensor to store the row indices.
         # It's one element too long for the moment, we'll trim it later
-        rows = torch.zeros((self.cols_a.shape[0] + 1), dtype=torch.int32, device=device)
+        rows = torch.zeros((self.cols_a.shape[0] + 1), dtype=self.int_type, device=device)

         # Change self.row_start_ends_a to the right type
         row_end_prepare = self.row_start_ends_a[1:].long()

         # Add ones to the start position of each new row
-        rows.index_add_(0, row_end_prepare, torch.ones(size=row_end_prepare.shape, dtype=torch.int32, device=device))
+        rows.index_add_(0, row_end_prepare, torch.ones(size=row_end_prepare.shape, dtype=self.int_type, device=device))

         # Accumulate those start positions to fill the remaining positions
-        rows = rows.cumsum(0).int()
+        rows = rows.cumsum(0).to(dtype=self.int_type)

         # Trim the last element: it's just a leftover
         rows = rows[:-1]
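This is the inverse expansion: from CSR row pointers back to one explicit row index per block (COO style). Dropping a one at each row's end offset and accumulating reproduces the row ids. A worked toy example, reusing the values from the sketch above:

```python
import torch

int_type = torch.int32
# CSR pointers for 3 blocks sitting in block-rows 0, 0 and 2.
row_start_ends = torch.tensor([0, 2, 2, 3, 3], dtype=int_type)
n_blocks = 3
rows = torch.zeros((n_blocks + 1,), dtype=int_type)
row_end_prepare = row_start_ends[1:].long()  # [2, 2, 3, 3]
# Each row boundary contributes a one: crossing it bumps the row id.
rows.index_add_(0, row_end_prepare, torch.ones(row_end_prepare.shape, dtype=int_type))
rows = rows.cumsum(0).to(dtype=int_type)[:-1]  # trim the leftover element
print(rows)  # tensor([0, 0, 2], dtype=torch.int32)
```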
@@ -219,9 +220,8 @@ def sanity_check(self, cols, row_end, shape, block_shape):
         row_end = row_end[1:]
         if len(cols.shape) != 2:
             raise Exception("cols should be bidimensional, not of shape %s" % cols.shape)
-        if cols.dtype != torch.int32:
-
-            raise Exception("cols should be int32, not of type %s" % cols.dtype)
+        if cols.dtype != self.int_type:
+            raise Exception("cols should be %s, not of type %s" % (self.int_type, cols.dtype))
         max_col = cols[:, 0].max()
         if max_col > shape[1] / block_shape[1]:
             raise Exception("cols max element (%d) cannot be larger than shape[1]/block_shape[1] (%d)" % (max_col, shape[1] / block_shape[1]))
@@ -230,8 +230,8 @@ def sanity_check(self, cols, row_end, shape, block_shape):
             raise Exception("row_end should be unidimensional, not of shape %s" % row_end.shape)
         if row_end.shape[0] != shape[0] / block_shape[0]:
             raise Exception("row_end.shape[0] (%d) should be equal to shape[0]/block_shape[0] (%d)" % (row_end.shape[0], shape[0] / block_shape[0]))
-        if row_end.dtype != torch.int32:
-            raise Exception("row_end should be int32, not of type %s" % row_end.dtype)
+        if row_end.dtype != self.int_type:
+            raise Exception("row_end should be %s, not of type %s" % (self.int_type, row_end.dtype))

         max_row_end = row_end.max()
         if max_row_end > cols.shape[0]:
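The reworked messages interpolate the dtype object itself; `%s` on a `torch.dtype` prints its canonical name, so the text tracks whatever `int_type` is configured. For instance:

```python
import torch

print("row_end should be %s, not of type %s" % (torch.int32, torch.int64))
# row_end should be torch.int32, not of type torch.int64
```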
@@ -305,9 +305,9 @@ def reverse_matmul_(self, dense_a, transpose=True):
         assert(out.is_contiguous())

         assert(ptr_b.is_contiguous())
-        assert(ptr_b.dtype == torch.int32)
+        assert(ptr_b.dtype == self.int_type)
         assert(indices_b.is_contiguous())
-        assert(indices_b.dtype == torch.int32)
+        assert(indices_b.dtype == self.int_type)

         assert(ptr_b.shape[0] == self.blocks_count()[dim] + 1)
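These assertions guard the handoff to the native kernel: the pointer/index tensors must be contiguous and must use the configured index dtype. A standalone sketch of the same checks (the tensor values are hypothetical):

```python
import torch

int_type = torch.int32
ptr_b = torch.tensor([0, 2, 2, 3, 3], dtype=int_type)  # CSR row pointers
indices_b = torch.tensor([1, 3, 0], dtype=int_type)    # block column indices
assert ptr_b.is_contiguous() and ptr_b.dtype == int_type
assert indices_b.is_contiguous() and indices_b.dtype == int_type
```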