@@ -1,5 +1,3 @@
-import math
-
 import numpy
 import torch
 import torch.nn
@@ -24,12 +22,15 @@ def __init__(self, shape, block_mask, data, block_shape=(32, 32)):
 
         self.data = torch.nn.Parameter(data)
 
-        self.rebuild(block_mask)
+        self.rebuild(block_mask, callback=False)
+
+    def updated_data(self):
+        pass
 
     def get_differentiable_data(self):
         return self.data
 
-    def rebuild(self, block_mask, block_ptr=None):
+    def rebuild(self, block_mask, block_ptr=None, callback=True):
         data = self.data
         block_shape = self.block_shape
 
@@ -71,6 +72,8 @@ def rebuild(self, block_mask, block_ptr=None):
             (self.block_shape[1], self.block_shape[0]),
         )
         self.check_ = False
+        if callback:
+            self.updated_data()
 
     @staticmethod
     def blocks_count_(shape, block_shape):
@@ -219,25 +222,26 @@ def block_replace(self, block_replacements):
     def zeros(cls, shape, n_blocks=None, blocks=None, block_shape=(32, 32), device="cuda"):
         for i in range(2):
             if shape[i] % block_shape[i] != 0:
-                raise Exception(
-                    f"Invalid shape: shape[{i}]({shape[i]}) %% block_shape[{i}]({block_shape[i]}) is not 0."
-                )
+                raise Exception(f"Invalid shape: shape[{i}]={shape[i]} %% block_shape[{i}]={block_shape[i]} is not 0.")
+
+        X, Y = cls.blocks_count_(shape, block_shape)
+
         if n_blocks is None:
-            assert blocks is not None
-            for b in blocks:
-                for i in range(2):
-                    if b[i] * block_shape[i] >= shape[i]:
-                        raise Exception(
-                            f"Invalid block definition: block[{i}] = {b[i]} : should be < {shape[i] // block_shape[i]}"
-                        )
-            n_blocks = len(blocks)
+            if blocks is not None:
+                for b in blocks:
+                    for i in range(2):
+                        if b[i] * block_shape[i] >= shape[i]:
+                            raise Exception(
+                                f"Invalid block definition: block[{i}] = {b[i]} : should be < {shape[i] // block_shape[i]}"
+                            )
+                n_blocks = len(blocks)
+            else:
+                n_blocks = X * Y
         else:
             assert blocks is None
         if len(shape) != 2 or shape[0] % block_shape[0] != 0 or shape[1] % block_shape[1] != 0:
             raise Exception("shape should be a tuple of 2 multiples of block_shape")
 
-        X, Y = cls.blocks_count_(shape, block_shape)
-
         if n_blocks > X * Y:
             raise Exception("Too many blocks : %d > %d * %d = %d" % (n_blocks, X, Y, X * Y))
         if blocks is not None:
@@ -257,11 +261,27 @@ def zeros(cls, shape, n_blocks=None, blocks=None, block_shape=(32, 32), device="
 
         return cls(shape, block_mask, data, block_shape)
 
+    @classmethod
+    def ones(
+        cls,
+        shape,
+        n_blocks=None,
+        blocks=None,
+        block_shape=(32, 32),
+        device="cuda",
+        positive=False,
+    ):
+        ret = cls.zeros(shape, n_blocks, blocks, block_shape, device)
+        with torch.no_grad():
+            ret.data += 1
+        ret.updated_data()
+        return ret
+
     @classmethod
     def randn(
         cls,
         shape,
-        n_blocks,
+        n_blocks=None,
         blocks=None,
         block_shape=(32, 32),
         device="cuda",
@@ -273,35 +293,63 @@ def randn(
             ret.data.normal_().abs_()
         else:
             ret.data.normal_()
+        ret.updated_data()
         return ret
 
     @classmethod
-    def from_dense(cls, dense, block_shape=(32, 32), block_count=None):
-        dense_block_count = (dense.shape[0] * dense.shape[1]) // (block_shape[0] * block_shape[1])
-        if block_count is None:
-            block_count = dense_block_count
-
-        ret = cls.zeros(
-            dense.shape,
-            n_blocks=block_count,
-            block_shape=block_shape,
-            device=dense.device,
-        )
+    def from_dense(cls, dense, block_shape=(32, 32), block_count=None, blocks=None, slow=False, out=None):
+        if out is None:
+            if blocks is None:
+                dense_block_count = (dense.shape[0] * dense.shape[1]) // (block_shape[0] * block_shape[1])
+                if block_count is None:
+                    block_count = dense_block_count
+            else:
+                block_count = None
+
+            ret = cls.zeros(
+                dense.shape,
+                n_blocks=block_count,
+                block_shape=block_shape,
+                blocks=blocks,
+                device=dense.device,
+            )
+        else:
+            ret = out
 
-        if block_count == dense_block_count:
-            # TODO : use some pytorch dimensions transposition to speed up this block by block copy
+        if out is not None or blocks is not None or block_count == dense_block_count:
+            # In case we keep the full matrix (block_count == dense_block_count), we make sure the
+            # order is the right one, mostly for testing purposes.
             coo = ret.build_coo_block_index().long()
-
-            for i in range(coo.shape[1]):
-                r, c = coo[0][i], coo[1][i]
-                bs = ret.block_shape
-                ret.data[i * bs[0] : (i + 1) * bs[0], :] = dense[
-                    r * bs[0] : (r + 1) * bs[0], c * bs[1] : (c + 1) * bs[1]
-                ].t()
+            if slow:
+                # Legacy version, used for testing only
+                for i in range(coo.shape[1]):
+                    r, c = coo[0][i], coo[1][i]
+                    bs = ret.block_shape
+                    part = dense[r * bs[0] : (r + 1) * bs[0], c * bs[1] : (c + 1) * bs[1]]
+                    part = part.t().reshape(block_shape[0], block_shape[1])
+                    with torch.no_grad():
+                        ret.data[i * bs[0] : (i + 1) * bs[0]] = part
+            else:
+                dense2 = dense.reshape(
+                    dense.shape[0] // block_shape[0], block_shape[0], dense.shape[1] // block_shape[1], block_shape[1]
+                )
+                dense2 = dense2.transpose(1, 2)
+                dense2 = dense2.transpose(2, 3)
+                dense2 = dense2.reshape(-1, block_shape[0], block_shape[1])
+                indices = coo[0] * (dense.shape[1] // block_shape[1]) + coo[1]
+                indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, block_shape[0], block_shape[1])
+                new_data = torch.gather(dense2, 0, indices)
+                new_data = new_data.reshape(-1, block_shape[1])
+                with torch.no_grad():
+                    ret.data.copy_(new_data)
         else:
+            # We just keep the first elements in the dense matrix
+            # Of course this only captures the statistical distribution in the dense matrix
             param_count = ret.data.numel()
-            density = block_count / dense_block_count
-            ret.data.copy_(dense.flatten()[:param_count].reshape(ret.data.shape) / math.sqrt(density))
+            with torch.no_grad():
+                ret.data.copy_(dense.flatten()[:param_count].reshape(ret.data.shape))
+
+        ret.updated_data()
 
         return ret
 
@@ -315,6 +363,11 @@ def __repr__(self):
             self.block_shape,
         )
 
+    def multiply_(self, factor):
+        with torch.no_grad():
+            self.data.multiply_(factor)
+        self.updated_data()
+
     def build_coo_block_index(self):
         device = self.cols_a.device
         # Build a tensor to store the row indices.
@@ -356,11 +409,13 @@ def to_sparse(self, data_replace=None):
             data = data_replace
         else:
             data = self.data
-        data = data.reshape(-1, *self.block_shape).transpose(1, 2)
+        data = data.reshape(-1, self.block_shape[1], self.block_shape[0])
+        data = data.transpose(1, 2)
         out = torch.sparse.FloatTensor(
             coo,
             data,
-            (self.shape[0] // self.block_shape[0], self.shape[1] // self.block_shape[1]) + self.block_shape,
+            (self.shape[0] // self.block_shape[0], self.shape[1] // self.block_shape[1])
+            + (self.block_shape[0], self.block_shape[1]),
         )
 
         return out
@@ -489,7 +544,8 @@ def reverse_matmul_(self, dense_a, transpose=True):
         data_b = data.reshape(-1, block_shape[1]).contiguous()
 
         if not dense_a.is_contiguous():
-            # warnings.warn(f"pytorch_block_sparse.BlockSparseMatrix.reverse_matmul: DEGRADED performance, dense_a is not contiguous {dense_a.stride()}")
+            # warnings.warn(f"pytorch_block_sparse.BlockSparseMatrix.reverse_matmul:"
+            #               f" DEGRADED performance, dense_a is not contiguous {dense_a.stride()}")
             dense_a = dense_a.contiguous()
 
         verbose = False
@@ -572,12 +628,15 @@ def matmul_with_output_sparse_support_(self, dense_a, dense_b, overwrite_data=Fa
         else:
             data = torch.zeros_like(self.data)
 
-        message = "pytorch_block_sparse.BlockSparseMatrix.matmul_with_output_sparse_support: DEGRADED performance, dense_%s is not contiguous"
+        message = (
+            "pytorch_block_sparse.BlockSparseMatrix.matmul_with_output_sparse_support:"
+            " DEGRADED performance, dense_%s is not contiguous"
+        )
         prepared_a, transpose_a = self.tensor_prepare(dense_a, message % "a", True)
         prepared_b, transpose_b = self.tensor_prepare(dense_b, message % "b", False)
 
-        # We interpret a as transposed, so we pass shape_a[1], shape_a[0] as a shape,
-        # and transpose_a will be set correctly too (for a "normal" contiguous pytorch matrix a, transpose_a will be true)
+        # We interpret a as transposed, so we pass shape_a[1], shape_a[0] as a shape, and transpose_a
+        # will be set correctly too (for a "normal" contiguous pytorch matrix a, transpose_a will be true)
         block_sparse_native.blocksparse_matmul_back_cutlass(
             prepared_a,
             transpose_a,
@@ -626,18 +685,29 @@ class BlockSparseMatrixEmulator(BlockSparseMatrixBase):
     # Data is (len(cols), block_shape, block_shape)
     def __init__(self, shape, block_mask, data, block_shape):
         super(BlockSparseMatrixEmulator, self).__init__(shape, block_mask, data, block_shape)
+        self.register_parameter("_dense", None)
+        self.updated_data()
 
     def get_differentiable_data(self):
-        return self.dense_
+        return self._dense
 
-    def rebuild(self, block_mask, block_ptr=None):
-        super().rebuild(block_mask, block_ptr)
-        self._dense = self.to_dense()
-        self._mask = self.to_dense(data_replace=torch.ones_like(self.data)) == 1
+    def to_dense(self, data_replace=None):
+        if data_replace is None:
+            return self._dense
+        return data_replace * self._mask
 
-    def reverse_matmul(self, dense_a, transpose):
+    def _update_data_from_dense(self):
+        _ = self.from_dense(self._dense, out=self)
+
+    def updated_data(self):
+        with torch.no_grad():
+            self._dense = torch.nn.Parameter(super().to_dense())
+            self._mask = super().to_dense(data_replace=torch.ones_like(self.data)) == 1
+
+    def reverse_matmul(self, dense_a, transpose=True):
         m = self._dense.t() if transpose else self._dense
-        return dense_a.matmul(m * self._mask)  # The self._mask multiplication is not really needed, but ...
+        mask = self._mask.t() if transpose else self._mask
+        return dense_a.matmul(m * mask)  # The self._mask multiplication is not really needed, but ...
 
     def matmul_with_output_sparse_support(self, dense_a, dense_b, overwrite_data=False):
         """Compute c = a.t().mm(b) where c is sparse (we just keep the results where c is non_zero)."""