@@ -1,8 +1,9 @@
 import torch
 import torch.autograd
 import torch.nn as nn
+
 from .block_sparse import BlockSparseMatrix
-import typing
+
 
 class BlockSparseLinearFunction(torch.autograd.Function):
     @staticmethod
@@ -16,14 +17,20 @@ def forward(ctx, input, weight_data, weight):
         if verbose:
             stride = 8
             print("BlockSparseLinearFunction.forward input\n", input[::stride, ::stride])
-            print("BlockSparseLinearFunction.forward dense_weight\n", dense_weight[::stride, ::stride])
-            print("BlockSparseLinearFunction.forward weight\n", weight.data[::stride, ::stride])
+            print(
+                "BlockSparseLinearFunction.forward dense_weight\n",
+                dense_weight[::stride, ::stride],
+            )
+            print(
+                "BlockSparseLinearFunction.forward weight\n",
+                weight.data[::stride, ::stride],
+            )
 
-        assert(isinstance(weight, BlockSparseMatrix))
+        assert isinstance(weight, BlockSparseMatrix)
 
         ctx.save_for_backward(input, weight_data)
         ctx.weight = weight
-        output = weight.reverse_matmul(input, transpose=True)
+        output = weight.reverse_matmul(input, transpose=True)
         if check:
             dense = weight.to_dense()
             output1 = input.matmul(dense.t())
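
# A standalone sketch (plain PyTorch, no CUDA kernels) of what the `check` branch
# in forward() verifies: for a linear map y = x @ W.T, the block-sparse product
# weight.reverse_matmul(input, transpose=True) is expected to match the dense
# reference input.matmul(weight.to_dense().t()). Only the dense side is shown here;
# shapes are arbitrary examples.
import torch

x = torch.randn(4, 64)            # (batch, in_features)
W = torch.randn(32, 64)           # (out_features, in_features), stands in for weight.to_dense()
y_ref = x.matmul(W.t())           # dense reference output, shape (batch, out_features)
assert y_ref.shape == (4, 32)
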
@@ -42,17 +49,23 @@ def forward(ctx, input, weight_data, weight):
     def backward(ctx, grad_output):
         check = False
         verbose = False
-        input, weight_data = ctx.saved_tensors
+        input, weight_data = ctx.saved_tensors
         weight = ctx.weight
-        assert(isinstance(weight, BlockSparseMatrix))
+        assert isinstance(weight, BlockSparseMatrix)
 
         if verbose or check:
             dense_weight = weight.to_dense()
 
         if verbose:
             stride = 8
             print("input\n", input[::stride, ::stride])
-            print("grad_output\n", grad_output.stride(), grad_output.storage, grad_output.layout, grad_output[::stride, ::stride])
+            print(
+                "grad_output\n",
+                grad_output.stride(),
+                grad_output.storage,
+                grad_output.layout,
+                grad_output[::stride, ::stride],
+            )
             print("dense_weight\n", dense_weight[::stride, ::stride])
             print("weight\n", weight.data[::stride, ::stride])
 
@@ -61,15 +74,27 @@ def backward(ctx, grad_output):
 
             if verbose or check:
                 grad_input0 = grad_output.matmul(dense_weight)
+                atol = 1e-4
 
             if check:
                 if not grad_input0.isclose(grad_input1).all():
                     print(f"grad_output.shape={grad_output.shape}, grad_output.stride={grad_output.stride()}")
-                    print("grad_input0/1 comparison\n", (grad_input0 - grad_input1)[1::32,1::32,1::32])
-                    print("grad_input0/1 comparison\n", (grad_input0 - grad_input1).abs().max())
-                    print("grad_input0/1 comparison: count of differences\n", ((grad_input0 - grad_input1).abs() > atol).sum())
-                    print("grad_input0/1 comparison: position of differences\n",
-                          ((grad_input0 - grad_input1).abs() > atol).nonzero())
+                    print(
+                        "grad_input0/1 comparison\n",
+                        (grad_input0 - grad_input1)[1::32, 1::32, 1::32],
+                    )
+                    print(
+                        "grad_input0/1 comparison\n",
+                        (grad_input0 - grad_input1).abs().max(),
+                    )
+                    print(
+                        "grad_input0/1 comparison: count of differences\n",
+                        ((grad_input0 - grad_input1).abs() > atol).sum(),
+                    )
+                    print(
+                        "grad_input0/1 comparison: position of differences\n",
+                        ((grad_input0 - grad_input1).abs() > atol).nonzero(),
+                    )
 
                     print("grad_input0 max\n", grad_input0.abs().max())
                     print("grad_input1 max\n", grad_input1.abs().max())
@@ -81,7 +106,7 @@ def backward(ctx, grad_output):
 
             if verbose:
                 grad_input2 = weight.reverse_matmul(torch.ones_like(grad_output), transpose=False)
-                print("grad_input0\n", grad_input0[::stride,::stride])
+                print("grad_input0\n", grad_input0[::stride, ::stride])
                 print("grad_input1\n", grad_input1[::stride, ::stride])
                 print("grad_input2\n", grad_input2[::stride, ::stride])
         else:
@@ -90,7 +115,11 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[1]:
             grad_weight1 = weight.matmul_with_output_sparse_support(grad_output, input)
             if verbose or check:
-                grad_weight0 = grad_output.reshape(-1, grad_output.shape[-1]).transpose(-1,-2).matmul(input.reshape(-1, input.shape[-1]))
+                grad_weight0 = (
+                    grad_output.reshape(-1, grad_output.shape[-1])
+                    .transpose(-1, -2)
+                    .matmul(input.reshape(-1, input.shape[-1]))
+                )
             if check:
                 grad_weight1b = weight.to_dense(data_replace=grad_weight1)
                 grad_weight1mask = weight.to_dense(data_replace=torch.ones_like(grad_weight1))
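
# A standalone sketch of the dense reference gradients that the check/verbose
# branches above compare against: for y = x @ W.T, dL/dx = dL/dy @ W (grad_input0)
# and dL/dW = dL/dy.T @ x with leading dims flattened (grad_weight0). Verified here
# on a tiny dense case with autograd; shapes and names are illustrative only.
import torch

x = torch.randn(5, 8, requires_grad=True)   # (batch, in_features)
W = torch.randn(6, 8, requires_grad=True)   # (out_features, in_features)
y = x.matmul(W.t())
grad_y = torch.randn_like(y)
y.backward(grad_y)

assert torch.allclose(x.grad, grad_y.matmul(W), atol=1e-6)
assert torch.allclose(W.grad, grad_y.t().matmul(x), atol=1e-6)
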
@@ -110,35 +139,43 @@ def backward(ctx, grad_output):
         else:
             grad_weight1 = None
 
-        if grad_weight1 != None:
-            assert(not (grad_weight1 == 0).all())
-        if grad_input1 != None:
-            assert(grad_input1.shape == input.shape)
+        if grad_weight1 is not None:
+            assert not (grad_weight1 == 0).all()
+        if grad_input1 is not None:
+            assert grad_input1.shape == input.shape
 
         return grad_input1, grad_weight1, None
 
+
 class BlockSparseLinear(nn.Module):
-    BLOCK_SIZE = 32
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 density:float = 0.5,
-                 torch_nn_linear=None,
-                 verbose=False):
+    BLOCK_SIZE = 32
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        density: float = 0.5,
+        torch_nn_linear=None,
+        verbose=False,
+    ):
         super(BlockSparseLinear, self).__init__()
         self.fn = BlockSparseLinearFunction.apply
         self.verbose = verbose
 
-        if torch_nn_linear != None:
+        if torch_nn_linear is not None:
             in_features = torch_nn_linear.in_features
             out_features = torch_nn_linear.out_features
             bias = torch_nn_linear.bias is not None
 
         if in_features % self.BLOCK_SIZE != 0:
-            raise Exception(f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}")
+            raise Exception(
+                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}"
+            )
         if out_features % self.BLOCK_SIZE != 0:
-            raise Exception(f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}")
+            raise Exception(
+                f"BlockSparseLinear invalid out_features={out_features}, should be multiple of {self.BLOCK_SIZE}"
+            )
 
         if density < 0 or density > 1:
             raise Exception(f"BlockSparseLinear invalid density={density}")
@@ -153,20 +190,22 @@ def __init__(self,
             with torch.no_grad():
                 weight = BlockSparseMatrix.from_dense(torch_nn_linear.weight, block_shape, self.block_count)
         else:
-            weight = BlockSparseMatrix.randn((out_features, in_features),
-                                             self.block_count,
-                                             blocks=None,
-                                             block_shape=block_shape,
-                                             device="cuda")
+            weight = BlockSparseMatrix.randn(
+                (out_features, in_features),
+                self.block_count,
+                blocks=None,
+                block_shape=block_shape,
+                device="cuda",
+            )
         self.weight = weight
 
         if bias:
-            self.bias = nn.Parameter(torch.zeros(out_features, device="cuda"))
+            self.bias = nn.Parameter(torch.zeros(out_features, device="cuda"))
             if torch_nn_linear is not None:
                 with torch.no_grad():
                     self.bias.copy_(torch_nn_linear.bias)
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
     def forward(self, x):
         x = self.fn(x, self.weight.get_differentiable_data(), self.weight)
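
# A hypothetical usage sketch for the module defined above. It assumes a CUDA
# device and that the package exposes the class as pytorch_block_sparse.BlockSparseLinear
# (adjust the import to your installation); sizes and density are arbitrary examples.
import torch
import torch.nn as nn
from pytorch_block_sparse import BlockSparseLinear

dense = nn.Linear(256, 512).cuda()
# Shapes are taken from `dense` when torch_nn_linear is given; density keeps ~25% of the blocks.
sparse = BlockSparseLinear(256, 512, density=0.25, torch_nn_linear=dense)
x = torch.randn(16, 256, device="cuda")
print(sparse(x).shape)  # expected: torch.Size([16, 512])
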
@@ -177,6 +216,7 @@ def forward(self, x):
 
 class PseudoBlockSparseLinear(torch.nn.Module):
     """For debugging purposes mostly: emulate a BlockSparseLinear with only PyTorch primitives."""
+
     def __init__(self, block_sparse_linear):
         super(PseudoBlockSparseLinear, self).__init__()
 
@@ -186,9 +226,9 @@ def __init__(self, block_sparse_linear):
         if block_sparse_linear.bias is not None:
             self.bias = torch.nn.Parameter(block_sparse_linear.bias)
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
-        self.register_buffer('mask', mask)
+        self.register_buffer("mask", mask)
         self.in_features = block_sparse_linear.in_features
         self.out_features = block_sparse_linear.out_features
         self.density = mask.sum().item() / (mask.shape[0] * mask.shape[1])
@@ -198,7 +238,6 @@ def forward(self, input):
         return torch.nn.functional.linear(input, weight, self.bias)
 
     def extra_repr(self):
-        return 'in_features={}, out_features={}, bias={}, fill_ratio={}'.format(
+        return "in_features={}, out_features={}, bias={}, fill_ratio={}".format(
             self.in_features, self.out_features, self.bias is not None, self.density
         )
-
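
# A hypothetical debugging sketch following the docstring of PseudoBlockSparseLinear:
# wrap a BlockSparseLinear in its pure-PyTorch emulation and check that both paths
# agree. Assumes a CUDA device and that this module is importable as
# pytorch_block_sparse.block_sparse_linear; the tolerance is an arbitrary choice.
import torch
from pytorch_block_sparse import BlockSparseLinear
from pytorch_block_sparse.block_sparse_linear import PseudoBlockSparseLinear

sparse = BlockSparseLinear(256, 512, density=0.25)
pseudo = PseudoBlockSparseLinear(sparse)
x = torch.randn(8, 256, device="cuda")
print(torch.allclose(sparse(x), pseudo(x), atol=1e-5))  # should print True if the two paths agree
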