
Commit 6653f70

Major changes, but code style only.
1 parent: 47a84d4

16 files changed: +908 −539 lines

pytorch_block_sparse/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 from .block_sparse import BlockSparseMatrix
 from .block_sparse_linear import BlockSparseLinear
-from .util import BlockSparseModelPatcher
 from .sparse_optimizer import SparseOptimizer
+from .util import BlockSparseModelPatcher
+
+__all__ = [BlockSparseMatrix, BlockSparseLinear, BlockSparseModelPatcher, SparseOptimizer]
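The reordered imports plus the new `__all__` make the package's public surface explicit: BlockSparseMatrix, BlockSparseLinear, BlockSparseModelPatcher and SparseOptimizer. (As written, `__all__` holds the classes themselves; the conventional form lists their names as strings.) A short, import-only sketch against that surface, nothing assumed beyond the names the diff exports:

# Illustrative import of the names re-exported by pytorch_block_sparse/__init__.py.
from pytorch_block_sparse import (
    BlockSparseLinear,
    BlockSparseMatrix,
    BlockSparseModelPatcher,
    SparseOptimizer,
)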

pytorch_block_sparse/block_sparse.py

Lines changed: 199 additions & 110 deletions
Large diffs are not rendered by default.

pytorch_block_sparse/block_sparse_linear.py

Lines changed: 80 additions & 41 deletions
@@ -1,8 +1,9 @@
 import torch
 import torch.autograd
 import torch.nn as nn
+
 from .block_sparse import BlockSparseMatrix
-import typing
+
 
 class BlockSparseLinearFunction(torch.autograd.Function):
     @staticmethod
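BlockSparseLinearFunction is a custom torch.autograd.Function, so both passes are written as static methods and communicate through `ctx`. A minimal, dense-only sketch of that contract (illustrative; not the library's CUDA-backed kernels):

import torch


class DenseLinearFunction(torch.autograd.Function):
    """Toy illustration of the forward/backward pattern used above."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)   # stash tensors needed by backward
        return input.matmul(weight.t())        # y = x @ W^T

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)       # dL/dx = dL/dy @ W
        grad_weight = grad_output.t().matmul(input)   # dL/dW = dL/dy^T @ x
        return grad_input, grad_weight                # one gradient per forward argument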
@@ -16,14 +17,20 @@ def forward(ctx, input, weight_data, weight):
         if verbose:
             stride = 8
             print("BlockSparseLinearFunction.forward input\n", input[::stride, ::stride])
-            print("BlockSparseLinearFunction.forward dense_weight\n", dense_weight[::stride, ::stride])
-            print("BlockSparseLinearFunction.forward weight\n", weight.data[::stride, ::stride])
+            print(
+                "BlockSparseLinearFunction.forward dense_weight\n",
+                dense_weight[::stride, ::stride],
+            )
+            print(
+                "BlockSparseLinearFunction.forward weight\n",
+                weight.data[::stride, ::stride],
+            )
 
-        assert(isinstance(weight, BlockSparseMatrix))
+        assert isinstance(weight, BlockSparseMatrix)
 
         ctx.save_for_backward(input, weight_data)
         ctx.weight = weight
-        output = weight.reverse_matmul(input, transpose = True)
+        output = weight.reverse_matmul(input, transpose=True)
         if check:
             dense = weight.to_dense()
             output1 = input.matmul(dense.t())
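For context, the `check` branch compares the sparse product `weight.reverse_matmul(input, transpose=True)` against the ordinary dense linear map `input @ W.T`. A dense-only sketch of that reference computation (shapes are illustrative assumptions):

import torch

batch, in_features, out_features = 4, 64, 32
input = torch.randn(batch, in_features)
dense = torch.randn(out_features, in_features)   # stands in for weight.to_dense()

# The reference the sparse kernel is checked against: a plain linear forward.
output1 = input.matmul(dense.t())
assert output1.shape == (batch, out_features)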
@@ -42,17 +49,23 @@ def forward(ctx, input, weight_data, weight):
     def backward(ctx, grad_output):
         check = False
         verbose = False
-        input, weight_data= ctx.saved_tensors
+        input, weight_data = ctx.saved_tensors
         weight = ctx.weight
-        assert (isinstance(weight, BlockSparseMatrix))
+        assert isinstance(weight, BlockSparseMatrix)
 
         if verbose or check:
             dense_weight = weight.to_dense()
 
         if verbose:
             stride = 8
             print("input\n", input[::stride, ::stride])
-            print("grad_output\n", grad_output.stride(), grad_output.storage, grad_output.layout, grad_output[::stride, ::stride])
+            print(
+                "grad_output\n",
+                grad_output.stride(),
+                grad_output.storage,
+                grad_output.layout,
+                grad_output[::stride, ::stride],
+            )
             print("dense_weight\n", dense_weight[::stride, ::stride])
             print("weight\n", weight.data[::stride, ::stride])
 
@@ -61,15 +74,27 @@ def backward(ctx, grad_output):
 
             if verbose or check:
                 grad_input0 = grad_output.matmul(dense_weight)
+                atol = 1e-4
 
             if check:
                 if not grad_input0.isclose(grad_input1).all():
                     print(f"grad_output.shape={grad_output.shape}, grad_output.stride={grad_output.stride()}")
-                    print("grad_input0/1 comparison\n", (grad_input0 - grad_input1)[1::32,1::32,1::32])
-                    print("grad_input0/1 comparison\n", (grad_input0 - grad_input1).abs().max())
-                    print("grad_input0/1 comparison: count of differences\n", ((grad_input0 - grad_input1).abs() > atol).sum())
-                    print("grad_input0/1 comparison: position of differences\n",
-                          ((grad_input0 - grad_input1).abs() > atol).nonzero())
+                    print(
+                        "grad_input0/1 comparison\n",
+                        (grad_input0 - grad_input1)[1::32, 1::32, 1::32],
+                    )
+                    print(
+                        "grad_input0/1 comparison\n",
+                        (grad_input0 - grad_input1).abs().max(),
+                    )
+                    print(
+                        "grad_input0/1 comparison: count of differences\n",
+                        ((grad_input0 - grad_input1).abs() > atol).sum(),
+                    )
+                    print(
+                        "grad_input0/1 comparison: position of differences\n",
+                        ((grad_input0 - grad_input1).abs() > atol).nonzero(),
+                    )
 
                     print("grad_input0 max\n", grad_input0.abs().max())
                     print("grad_input1 max\n", grad_input1.abs().max())
@@ -81,7 +106,7 @@ def backward(ctx, grad_output):
 
             if verbose:
                 grad_input2 = weight.reverse_matmul(torch.ones_like(grad_output), transpose=False)
-                print("grad_input0\n", grad_input0[::stride,::stride])
+                print("grad_input0\n", grad_input0[::stride, ::stride])
                 print("grad_input1\n", grad_input1[::stride, ::stride])
                 print("grad_input2\n", grad_input2[::stride, ::stride])
         else:
@@ -90,7 +115,11 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[1]:
             grad_weight1 = weight.matmul_with_output_sparse_support(grad_output, input)
             if verbose or check:
-                grad_weight0 = grad_output.reshape(-1, grad_output.shape[-1]).transpose(-1,-2).matmul(input.reshape(-1, input.shape[-1]))
+                grad_weight0 = (
+                    grad_output.reshape(-1, grad_output.shape[-1])
+                    .transpose(-1, -2)
+                    .matmul(input.reshape(-1, input.shape[-1]))
+                )
                 if check:
                     grad_weight1b = weight.to_dense(data_replace=grad_weight1)
                     grad_weight1mask = weight.to_dense(data_replace=torch.ones_like(grad_weight1))
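The dense reference `grad_weight0` flattens the leading batch dimensions and applies the usual linear-layer identity grad_W = grad_output^T @ input; `matmul_with_output_sparse_support` is expected to reproduce it, restricted to the non-zero blocks. A dense sketch of that identity with an assumed 3-D activation shape:

import torch

batch, seq, in_features, out_features = 2, 5, 64, 32
input = torch.randn(batch, seq, in_features)
grad_output = torch.randn(batch, seq, out_features)

# Collapse (batch, seq) into one dimension, then grad_W = grad_output^T @ input.
grad_weight0 = (
    grad_output.reshape(-1, grad_output.shape[-1])
    .transpose(-1, -2)
    .matmul(input.reshape(-1, input.shape[-1]))
)
assert grad_weight0.shape == (out_features, in_features)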
@@ -110,35 +139,43 @@ def backward(ctx, grad_output):
         else:
             grad_weight1 = None
 
-        if grad_weight1 != None:
-            assert(not (grad_weight1 == 0).all())
-        if grad_input1 != None:
-            assert(grad_input1.shape == input.shape)
+        if grad_weight1 is not None:
+            assert not (grad_weight1 == 0).all()
+        if grad_input1 is not None:
+            assert grad_input1.shape == input.shape
 
         return grad_input1, grad_weight1, None
 
+
 class BlockSparseLinear(nn.Module):
-    BLOCK_SIZE=32
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 density:float = 0.5,
-                 torch_nn_linear = None,
-                 verbose = False):
+    BLOCK_SIZE = 32
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        density: float = 0.5,
+        torch_nn_linear=None,
+        verbose=False,
+    ):
         super(BlockSparseLinear, self).__init__()
         self.fn = BlockSparseLinearFunction.apply
         self.verbose = verbose
 
-        if torch_nn_linear != None:
+        if torch_nn_linear is not None:
             in_features = torch_nn_linear.in_features
             out_features = torch_nn_linear.out_features
             bias = torch_nn_linear.bias is not None
 
         if in_features % self.BLOCK_SIZE != 0:
-            raise Exception(f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}")
+            raise Exception(
+                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}"
+            )
         if out_features % self.BLOCK_SIZE != 0:
-            raise Exception(f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}")
+            raise Exception(
+                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}"
+            )
 
         if density < 0 or density > 1:
             raise Exception(f"BlockSparseLinear invalid density={density}")
@@ -153,20 +190,22 @@ def __init__(self,
             with torch.no_grad():
                 weight = BlockSparseMatrix.from_dense(torch_nn_linear.weight, block_shape, self.block_count)
         else:
-            weight = BlockSparseMatrix.randn((out_features, in_features),
-                                             self.block_count,
-                                             blocks=None,
-                                             block_shape=block_shape,
-                                             device="cuda")
+            weight = BlockSparseMatrix.randn(
+                (out_features, in_features),
+                self.block_count,
+                blocks=None,
+                block_shape=block_shape,
+                device="cuda",
+            )
         self.weight = weight
 
         if bias:
-            self.bias = nn.Parameter(torch.zeros(out_features, device = "cuda"))
+            self.bias = nn.Parameter(torch.zeros(out_features, device="cuda"))
             if torch_nn_linear is not None:
                 with torch.no_grad():
                     self.bias.copy_(torch_nn_linear.bias)
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
     def forward(self, x):
         x = self.fn(x, self.weight.get_differentiable_data(), self.weight)
@@ -177,6 +216,7 @@ def forward(self, x):
 
 class PseudoBlockSparseLinear(torch.nn.Module):
     """For debugging purposes mostly: emulate a BlockSparseLinear with only PyTorch primitives."""
+
     def __init__(self, block_sparse_linear):
         super(PseudoBlockSparseLinear, self).__init__()
 
@@ -186,9 +226,9 @@ def __init__(self, block_sparse_linear):
         if block_sparse_linear.bias is not None:
             self.bias = torch.nn.Parameter(block_sparse_linear.bias)
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
-        self.register_buffer('mask', mask)
+        self.register_buffer("mask", mask)
         self.in_features = block_sparse_linear.in_features
         self.out_features = block_sparse_linear.out_features
         self.density = mask.sum().item() / (mask.shape[0] * mask.shape[1])
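PseudoBlockSparseLinear emulates the sparse layer with dense primitives: the block pattern is kept as a 0/1 `mask` buffer and, as the `forward` in the next hunk shows, multiplied into a dense weight before a regular `torch.nn.functional.linear` call. A minimal sketch of that emulation (the mask layout below is an illustrative assumption, not the library's construction):

import torch
import torch.nn.functional as F

out_features, in_features, block = 64, 64, 32

# Illustrative 0/1 block mask: keep only the two diagonal 32x32 blocks.
mask = torch.zeros(out_features, in_features)
mask[:block, :block] = 1.0
mask[block:, block:] = 1.0

weight = torch.randn(out_features, in_features)
bias = torch.zeros(out_features)

x = torch.randn(8, in_features)
y = F.linear(x, weight * mask, bias)   # dense matmul, but only the kept blocks contribute
print(y.shape)   # torch.Size([8, 64])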
@@ -198,7 +238,6 @@ def forward(self, input):
         return torch.nn.functional.linear(input, weight, self.bias)
 
     def extra_repr(self):
-        return 'in_features={}, out_features={}, bias={}, fill_ratio={}'.format(
+        return "in_features={}, out_features={}, bias={}, fill_ratio={}".format(
             self.in_features, self.out_features, self.bias is not None, self.density
         )
-