
Commit e71b542

Optimized 'from_dense' function.
1 parent 6653f70 commit e71b542

7 files changed: +231 −91 lines changed


pytorch_block_sparse/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
-from .block_sparse import BlockSparseMatrix
+from .block_sparse import BlockSparseMatrix, BlockSparseMatrixEmulator
 from .block_sparse_linear import BlockSparseLinear
 from .sparse_optimizer import SparseOptimizer
 from .util import BlockSparseModelPatcher

-__all__ = [BlockSparseMatrix, BlockSparseLinear, BlockSparseModelPatcher, SparseOptimizer]
+__all__ = [BlockSparseMatrix, BlockSparseMatrixEmulator, BlockSparseLinear, BlockSparseModelPatcher, SparseOptimizer]
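
The newly exported BlockSparseMatrixEmulator shares the classmethod constructors of BlockSparseMatrix. As a hypothetical round-trip sketch (not part of this commit), keeping every block through from_dense and calling to_dense should reproduce the dense tensor; this assumes the emulator path works on CPU tensors, whereas the optimized BlockSparseMatrix kernels require CUDA.

import torch
from pytorch_block_sparse import BlockSparseMatrixEmulator

# Keep every 32x32 block of a dense matrix, then reconstruct it.
dense = torch.randn(64, 64)
bsm = BlockSparseMatrixEmulator.from_dense(dense, block_shape=(32, 32))
assert torch.allclose(bsm.to_dense(), dense)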

pytorch_block_sparse/block_sparse.py

Lines changed: 123 additions & 53 deletions
@@ -1,5 +1,3 @@
-import math
-
 import numpy
 import torch
 import torch.nn
@@ -24,12 +22,15 @@ def __init__(self, shape, block_mask, data, block_shape=(32, 32)):

         self.data = torch.nn.Parameter(data)

-        self.rebuild(block_mask)
+        self.rebuild(block_mask, callback=False)
+
+    def updated_data(self):
+        pass

     def get_differentiable_data(self):
         return self.data

-    def rebuild(self, block_mask, block_ptr=None):
+    def rebuild(self, block_mask, block_ptr=None, callback=True):
         data = self.data
         block_shape = self.block_shape

@@ -71,6 +72,8 @@ def rebuild(self, block_mask, block_ptr=None):
             (self.block_shape[1], self.block_shape[0]),
         )
         self.check_ = False
+        if callback:
+            self.updated_data()

     @staticmethod
     def blocks_count_(shape, block_shape):
@@ -219,25 +222,26 @@ def block_replace(self, block_replacements):
     def zeros(cls, shape, n_blocks=None, blocks=None, block_shape=(32, 32), device="cuda"):
         for i in range(2):
             if shape[i] % block_shape[i] != 0:
-                raise Exception(
-                    f"Invalid shape: shape[{i}]({shape[i]}) %% block_shape[{i}]({block_shape[i]}) is not 0."
-                )
+                raise Exception(f"Invalid shape: shape[{i}]={shape[i]} %% block_shape[{i}]={block_shape[i]} is not 0.")
+
+        X, Y = cls.blocks_count_(shape, block_shape)
+
         if n_blocks is None:
-            assert blocks is not None
-            for b in blocks:
-                for i in range(2):
-                    if b[i] * block_shape[i] >= shape[i]:
-                        raise Exception(
-                            f"Invalid block definition: block[{i}] = {b[i]} : should be < {shape[i] // block_shape[i]}"
-                        )
-            n_blocks = len(blocks)
+            if blocks is not None:
+                for b in blocks:
+                    for i in range(2):
+                        if b[i] * block_shape[i] >= shape[i]:
+                            raise Exception(
+                                f"Invalid block definition: block[{i}] = {b[i]} : should be < {shape[i] // block_shape[i]}"
+                            )
+                n_blocks = len(blocks)
+            else:
+                n_blocks = X * Y
         else:
             assert blocks is None
         if len(shape) != 2 or shape[0] % block_shape[0] != 0 or shape[1] % block_shape[1] != 0:
             raise Exception("shape should be a tuple of 2 multiples of block_shape")

-        X, Y = cls.blocks_count_(shape, block_shape)
-
         if n_blocks > X * Y:
             raise Exception("Too many blocks : %d > %d * %d = %d" % (n_blocks, X, Y, X * Y))
         if blocks is not None:
@@ -257,11 +261,27 @@ def zeros(cls, shape, n_blocks=None, blocks=None, block_shape=(32, 32), device="

         return cls(shape, block_mask, data, block_shape)

+    @classmethod
+    def ones(
+        cls,
+        shape,
+        n_blocks=None,
+        blocks=None,
+        block_shape=(32, 32),
+        device="cuda",
+        positive=False,
+    ):
+        ret = cls.zeros(shape, n_blocks, blocks, block_shape, device)
+        with torch.no_grad():
+            ret.data += 1
+        ret.updated_data()
+        return ret
+
     @classmethod
     def randn(
         cls,
         shape,
-        n_blocks,
+        n_blocks=None,
         blocks=None,
         block_shape=(32, 32),
         device="cuda",
@@ -273,35 +293,63 @@ def randn(
             ret.data.normal_().abs_()
         else:
             ret.data.normal_()
+        ret.updated_data()
         return ret

     @classmethod
-    def from_dense(cls, dense, block_shape=(32, 32), block_count=None):
-        dense_block_count = (dense.shape[0] * dense.shape[1]) // (block_shape[0] * block_shape[1])
-        if block_count is None:
-            block_count = dense_block_count
-
-        ret = cls.zeros(
-            dense.shape,
-            n_blocks=block_count,
-            block_shape=block_shape,
-            device=dense.device,
-        )
+    def from_dense(cls, dense, block_shape=(32, 32), block_count=None, blocks=None, slow=False, out=None):
+        if out is None:
+            if blocks is None:
+                dense_block_count = (dense.shape[0] * dense.shape[1]) // (block_shape[0] * block_shape[1])
+                if block_count is None:
+                    block_count = dense_block_count
+            else:
+                block_count = None
+
+            ret = cls.zeros(
+                dense.shape,
+                n_blocks=block_count,
+                block_shape=block_shape,
+                blocks=blocks,
+                device=dense.device,
+            )
+        else:
+            ret = out

-        if block_count == dense_block_count:
-            # TODO : use some pytorch dimensions transposition to speed up this block by block copy
+        if out is not None or blocks is not None or block_count == dense_block_count:
+            # In case we keep the full matrix (block_count == dense_block_count), we make sure the
+            # order is the right one, mostly for testing purposes.
             coo = ret.build_coo_block_index().long()
-
-            for i in range(coo.shape[1]):
-                r, c = coo[0][i], coo[1][i]
-                bs = ret.block_shape
-                ret.data[i * bs[0] : (i + 1) * bs[0], :] = dense[
-                    r * bs[0] : (r + 1) * bs[0], c * bs[1] : (c + 1) * bs[1]
-                ].t()
+            if slow:
+                # Legacy version, used for testing only
+                for i in range(coo.shape[1]):
+                    r, c = coo[0][i], coo[1][i]
+                    bs = ret.block_shape
+                    part = dense[r * bs[0] : (r + 1) * bs[0], c * bs[1] : (c + 1) * bs[1]]
+                    part = part.t().reshape(block_shape[0], block_shape[1])
+                    with torch.no_grad():
+                        ret.data[i * bs[0] : (i + 1) * bs[0]] = part
+            else:
+                dense2 = dense.reshape(
+                    dense.shape[0] // block_shape[0], block_shape[0], dense.shape[1] // block_shape[1], block_shape[1]
+                )
+                dense2 = dense2.transpose(1, 2)
+                dense2 = dense2.transpose(2, 3)
+                dense2 = dense2.reshape(-1, block_shape[0], block_shape[1])
+                indices = coo[0] * (dense.shape[1] // block_shape[1]) + coo[1]
+                indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, block_shape[0], block_shape[1])
+                new_data = torch.gather(dense2, 0, indices)
+                new_data = new_data.reshape(-1, block_shape[1])
+                with torch.no_grad():
+                    ret.data.copy_(new_data)
         else:
+            # We just keep the first elements in the dense matrix
+            # Of course this only captures the statistical distribution in the dense matrix
             param_count = ret.data.numel()
-            density = block_count / dense_block_count
-            ret.data.copy_(dense.flatten()[:param_count].reshape(ret.data.shape) / math.sqrt(density))
+            with torch.no_grad():
+                ret.data.copy_(dense.flatten()[:param_count].reshape(ret.data.shape))
+
+        ret.updated_data()

         return ret

@@ -315,6 +363,11 @@ def __repr__(self):
             self.block_shape,
         )

+    def multiply_(self, factor):
+        with torch.no_grad():
+            self.data.multiply_(factor)
+        self.updated_data()
+
     def build_coo_block_index(self):
         device = self.cols_a.device
         # Build a tensor to store the row indices.
@@ -356,11 +409,13 @@ def to_sparse(self, data_replace=None):
             data = data_replace
         else:
             data = self.data
-        data = data.reshape(-1, *self.block_shape).transpose(1, 2)
+        data = data.reshape(-1, self.block_shape[1], self.block_shape[0])
+        data = data.transpose(1, 2)
         out = torch.sparse.FloatTensor(
             coo,
             data,
-            (self.shape[0] // self.block_shape[0], self.shape[1] // self.block_shape[1]) + self.block_shape,
+            (self.shape[0] // self.block_shape[0], self.shape[1] // self.block_shape[1])
+            + (self.block_shape[0], self.block_shape[1]),
         )

         return out
@@ -489,7 +544,8 @@ def reverse_matmul_(self, dense_a, transpose=True):
         data_b = data.reshape(-1, block_shape[1]).contiguous()

         if not dense_a.is_contiguous():
-            # warnings.warn(f"pytorch_block_sparse.BlockSparseMatrix.reverse_matmul: DEGRADED performance, dense_a is not contiguous {dense_a.stride()}")
+            # warnings.warn(f"pytorch_block_sparse.BlockSparseMatrix.reverse_matmul:"
+            #               f" DEGRADED performance, dense_a is not contiguous {dense_a.stride()}")
             dense_a = dense_a.contiguous()

         verbose = False
@@ -572,12 +628,15 @@ def matmul_with_output_sparse_support_(self, dense_a, dense_b, overwrite_data=Fa
         else:
             data = torch.zeros_like(self.data)

-        message = "pytorch_block_sparse.BlockSparseMatrix.matmul_with_output_sparse_support: DEGRADED performance, dense_%s is not contiguous"
+        message = (
+            "pytorch_block_sparse.BlockSparseMatrix.matmul_with_output_sparse_support:"
+            " DEGRADED performance, dense_%s is not contiguous"
+        )
         prepared_a, transpose_a = self.tensor_prepare(dense_a, message % "a", True)
         prepared_b, transpose_b = self.tensor_prepare(dense_b, message % "b", False)

-        # We interpret a as transposed, so we pass shape_a[1], shape_a[0] as a shape,
-        # and transpose_a will be set correcly too (for a "normal" contiguous pytorch matrix a, transpose_a will be true)
+        # We interpret a as transposed, so we pass shape_a[1], shape_a[0] as a shape, and transpose_a
+        # will be set correcly too (for a "normal" contiguous pytorch matrix a, transpose_a will be true)
         block_sparse_native.blocksparse_matmul_back_cutlass(
             prepared_a,
             transpose_a,
@@ -626,18 +685,29 @@ class BlockSparseMatrixEmulator(BlockSparseMatrixBase):
     # Data is (len(cols), block_shape, block_shape)
     def __init__(self, shape, block_mask, data, block_shape):
         super(BlockSparseMatrixEmulator, self).__init__(shape, block_mask, data, block_shape)
+        self.register_parameter("_dense", None)
+        self.updated_data()

     def get_differentiable_data(self):
-        return self.dense_
+        return self._dense

-    def rebuild(self, block_mask, block_ptr=None):
-        super().rebuild(block_mask, block_ptr)
-        self._dense = self.to_dense()
-        self._mask = self.to_dense(data_replace=torch.ones_like(self.data)) == 1
+    def to_dense(self, data_replace=None):
+        if data_replace is None:
+            return self._dense
+        return data_replace * self._mask

-    def reverse_matmul(self, dense_a, transpose):
+    def _update_data_from_dense(self):
+        _ = self.from_dense(self._dense, out=self)
+
+    def updated_data(self):
+        with torch.no_grad():
+            self._dense = torch.nn.Parameter(super().to_dense())
+            self._mask = super().to_dense(data_replace=torch.ones_like(self.data)) == 1
+
+    def reverse_matmul(self, dense_a, transpose=True):
         m = self._dense.t() if transpose else self._dense
-        return dense_a.matmul(m * self._mask)  # The self._mask multiplication is not really needed, but ...
+        mask = self._mask.t() if transpose else self._mask
+        return dense_a.matmul(m * mask)  # The self._mask multiplication is not really needed, but ...

     def matmul_with_output_sparse_support(self, dense_a, dense_b, overwrite_data=False):
         """Compute c = a.t().mm(b) where c is sparse (we just keep the results where c is non_zero)."""

pytorch_block_sparse/block_sparse_linear.py

Lines changed: 32 additions & 13 deletions
@@ -1,8 +1,15 @@
+import math
+from typing import Tuple
+
 import torch
 import torch.autograd
 import torch.nn as nn

-from .block_sparse import BlockSparseMatrix
+from .block_sparse import (
+    BlockSparseMatrix,
+    BlockSparseMatrixBase,
+    BlockSparseMatrixEmulator,
+)


 class BlockSparseLinearFunction(torch.autograd.Function):
class BlockSparseLinearFunction(torch.autograd.Function):
@@ -26,7 +33,7 @@ def forward(ctx, input, weight_data, weight):
2633
weight.data[::stride, ::stride],
2734
)
2835

29-
assert isinstance(weight, BlockSparseMatrix)
36+
assert isinstance(weight, BlockSparseMatrixBase)
3037

3138
ctx.save_for_backward(input, weight_data)
3239
ctx.weight = weight
@@ -51,7 +58,7 @@ def backward(ctx, grad_output):
         verbose = False
         input, weight_data = ctx.saved_tensors
         weight = ctx.weight
-        assert isinstance(weight, BlockSparseMatrix)
+        assert isinstance(weight, BlockSparseMatrixBase)

         if verbose or check:
             dense_weight = weight.to_dense()
@@ -148,7 +155,7 @@ def backward(ctx, grad_output):


 class BlockSparseLinear(nn.Module):
-    BLOCK_SIZE = 32
+    OPTIMIZED_BLOCK_SIZE = 32

     def __init__(
         self,
@@ -157,40 +164,52 @@ def __init__(
         bias: bool = True,
         density: float = 0.5,
         torch_nn_linear=None,
-        verbose=False,
+        verbose: bool = False,
+        block_shape: Tuple[int, int] = (32, 32),
     ):
         super(BlockSparseLinear, self).__init__()
         self.fn = BlockSparseLinearFunction.apply
         self.verbose = verbose
+        self.block_shape = block_shape
+        self._optimized = (
+            self.block_shape[0] == self.OPTIMIZED_BLOCK_SIZE and self.block_shape[1] == self.OPTIMIZED_BLOCK_SIZE
+        )

         if torch_nn_linear is not None:
             in_features = torch_nn_linear.in_features
             out_features = torch_nn_linear.out_features
             bias = torch_nn_linear.bias is not None

-        if in_features % self.BLOCK_SIZE != 0:
+        if in_features % self.block_shape[1] != 0:
             raise Exception(
-                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}"
+                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.block_shape[1]}"
             )
-        if out_features % self.BLOCK_SIZE != 0:
+        if out_features % self.block_shape[0] != 0:
             raise Exception(
-                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.BLOCK_SIZE}"
+                f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.block_shape[0]}"
             )

         if density < 0 or density > 1:
             raise Exception(f"BlockSparseLinear invalid density={density}")

-        self.block_count = int(density * (in_features * out_features / (self.BLOCK_SIZE * self.BLOCK_SIZE)))
+        self.block_count = int(density * (in_features * out_features / (self.block_shape[0] * self.block_shape[1])))

         self.in_features = in_features
         self.out_features = out_features

-        block_shape = (self.BLOCK_SIZE, self.BLOCK_SIZE)
+        block_shape = self.block_shape
+
+        if self._optimized:
+            BlockSparseMatrixConstructor = BlockSparseMatrix
+        else:
+            BlockSparseMatrixConstructor = BlockSparseMatrixEmulator
+
         if torch_nn_linear is not None:
             with torch.no_grad():
-                weight = BlockSparseMatrix.from_dense(torch_nn_linear.weight, block_shape, self.block_count)
+                weight = BlockSparseMatrixConstructor.from_dense(torch_nn_linear.weight, block_shape, self.block_count)
+                weight.multiply_(1.0 / math.sqrt(density))
         else:
-            weight = BlockSparseMatrix.randn(
+            weight = BlockSparseMatrixConstructor.randn(
                 (out_features, in_features),
                 self.block_count,
                 blocks=None,