
Commit 882b255

Fix: DataParallel does not support boolean tensors -> removed the unused block_mask buffer from the object.
1 parent: c023ab4
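
For context, the failure this commit works around: DataParallel replicates a module's registered buffers to every GPU, and on the PyTorch versions targeted here that broadcast fails for torch.bool tensors. Below is a minimal, hypothetical repro sketch, not part of the commit; it assumes at least two visible CUDA devices, and names like BoolBufferModule are made up for illustration.

# Hypothetical sketch: a module that registers a boolean buffer, then gets
# wrapped in DataParallel. Replicating the bool buffer across GPUs is what
# fails, which is why block_mask is no longer kept as a registered buffer.
import torch
import torch.nn


class BoolBufferModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 128)
        # The torch.bool buffer below is the problematic part.
        self.register_buffer("mask", torch.zeros(4, 8, dtype=torch.bool))

    def forward(self, x):
        return self.linear(x)


if torch.cuda.device_count() >= 2:
    model = torch.nn.DataParallel(BoolBufferModule().cuda())
    out = model(torch.randn(8, 64).cuda())  # may raise while broadcasting the bool buffer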

File tree

2 files changed: +34, -15 lines


pytorch_block_sparse/block_sparse.py

Lines changed: 15 additions & 15 deletions
@@ -10,6 +10,7 @@ class BlockSparseMatrix(torch.nn.Module):
     # Data is (len(cols), block_shape, block_shape)
     def __init__(self, shape, block_mask, data, block_shape=(16, 16)):
         super(BlockSparseMatrix, self).__init__()
+        self.int_type = torch.int32
 
         if block_mask.device != data.device:
             raise Exception("block_mask and data should have same device, got %s and %s" % (block_mask.device, data.device))
@@ -34,7 +35,7 @@ def __init__(self, shape, block_mask, data, block_shape=(16, 16)):
 
 
         self.data = torch.nn.Parameter(data)
-        for name in ("block_mask", "cols_a", "row_start_ends_a", "rows_b", "col_start_ends_b", "blocks"):
+        for name in ("cols_a", "row_start_ends_a", "rows_b", "col_start_ends_b", "blocks"):
             self.register_buffer(name, locals()[name])
 
         self.sanity_check(self.cols_a, self.row_start_ends_a, self.shape, self.block_shape)
@@ -79,15 +80,15 @@ def build_indices_(self, block_mask, nnzt, transpose_indices):
         row_start_ends = torch.zeros((X + 1,), dtype=torch.long, device = device)
 
         row_start_ends.index_add_(0, rows + 1, torch.ones(size=(cols.shape[0],), dtype=torch.long, device = device))
-        row_start_ends = row_start_ends.cumsum(0).int()
+        row_start_ends = row_start_ends.cumsum(0).to(dtype=self.int_type)
 
-        cols = torch.stack([cols, block_shuffle], 1).int()
+        cols = torch.stack([cols, block_shuffle], 1).to(dtype=self.int_type)
 
         return cols, row_start_ends
 
     def build_indices(self, block_mask):
         nnz = block_mask.nonzero()
-        blocks = nnz.flip(-1).flatten().to(dtype=torch.int32)
+        blocks = nnz.flip(-1).flatten().to(dtype=self.int_type)
 
         nnzt = nnz.transpose(0, 1)
         cols_a, row_start_ends_a = self.build_indices_(block_mask, nnzt, False)
@@ -129,7 +130,7 @@ def zeros(cls, shape, n_blocks = None, blocks = None, block_shape=(32, 32), devi
         positions = torch.tensor(positions, dtype=torch.int64, device = device).sort()[0]
 
         block_mask = torch.zeros(X * Y, dtype=torch.bool, device = device)
-        block_mask[positions] = True
+        block_mask[positions] = 1
         block_mask = block_mask.view(X, Y)
         data = torch.zeros((n_blocks * block_shape[0], block_shape[1]), dtype=torch.float, device = device)
 
@@ -177,16 +178,16 @@ def build_coo_block_index(self):
         device = self.cols_a.device
         # Build a tensor to store the row indices.
         # It's one element too long for the moment, we'll trim it later
-        rows = torch.zeros((self.cols_a.shape[0] + 1), dtype=torch.int32, device=device)
+        rows = torch.zeros((self.cols_a.shape[0] + 1), dtype=self.int_type, device=device)
 
         # Change self.row_start_ends_a to the right type
         row_end_prepare = self.row_start_ends_a[1:].long()
 
         # Add ones to the start position of each new row
-        rows.index_add_(0, row_end_prepare, torch.ones(size=row_end_prepare.shape, dtype=torch.int32, device=device))
+        rows.index_add_(0, row_end_prepare, torch.ones(size=row_end_prepare.shape, dtype=self.int_type, device=device))
 
         # Accumulate those start positions to fill the remaining positions
-        rows = rows.cumsum(0).int()
+        rows = rows.cumsum(0).to(dtype=self.int_type)
 
         # Trim the last element: it's just a left over
         rows = rows[:-1]
@@ -219,9 +220,8 @@ def sanity_check(self, cols, row_end, shape, block_shape):
         row_end = row_end[1:]
         if len(cols.shape) != 2:
             raise Exception("cols should be bidimensional, not of shape %s" % cols.shape)
-        if cols.dtype != torch.int32:
-
-            raise Exception("cols should be int32, not of type %s" % cols.dtype)
+        if cols.dtype != self.int_type:
+            raise Exception("cols should be %s, not of type %s" % (self.int_type, cols.dtype))
         max_col = cols[:,0].max()
         if max_col > shape[1] / block_shape[1]:
             raise Exception("cols max element (%d) cannot be larger than shape[1]/block_shape[1] (%d)" % (max_col, shape[1] / block_shape[1]))
@@ -230,8 +230,8 @@ def sanity_check(self, cols, row_end, shape, block_shape):
             raise Exception("row_end should be unidimensional, not of shape %s" % row_end.shape)
         if row_end.shape[0] != shape[0] / block_shape[0]:
             raise Exception("row_end.shape[0] (%d) should be equal to shape[0]/block_shape[0] (%d)" % (row_end.shape[0], shape[0] / block_shape[0]))
-        if row_end.dtype != torch.int32:
-            raise Exception("row_end should be int32, not of type %s" % row_end.dtype)
+        if row_end.dtype != self.int_type:
+            raise Exception("row_end should be %s, not of type %s" % (self.int_type, row_end.dtype))
 
         max_row_end = row_end.max()
         if max_row_end > cols.shape[0]:
@@ -305,9 +305,9 @@ def reverse_matmul_(self, dense_a, transpose = True):
         assert(out.is_contiguous())
 
         assert(ptr_b.is_contiguous())
-        assert(ptr_b.dtype == torch.int32)
+        assert(ptr_b.dtype == self.int_type)
         assert(indices_b.is_contiguous())
-        assert(indices_b.dtype == torch.int32)
+        assert(indices_b.dtype == self.int_type)
 
         assert(ptr_b.shape[0] == self.blocks_count()[dim] + 1)
 
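
With block_mask no longer registered as a buffer and all index buffers built with self.int_type (torch.int32), everything DataParallel has to broadcast is an int or float tensor. A quick sanity-check sketch follows, assuming a CUDA device and the pytorch_block_sparse package installed; the check itself is illustrative, not part of the commit.

# Illustrative check: list the buffers of a BlockSparseLinear layer and confirm
# none of them are torch.bool, the dtype DataParallel could not broadcast.
import torch
from pytorch_block_sparse import BlockSparseLinear

linear = BlockSparseLinear(64, 128, False).to("cuda")
for name, buf in linear.named_buffers():
    print(name, buf.dtype, tuple(buf.shape))
    assert buf.dtype != torch.bool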

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import unittest
+from unittest import TestCase
+import torch
+import torch.nn
+from pytorch_block_sparse import BlockSparseLinear
+
+class TestFun(TestCase):
+
+    def test1(self):
+        linear = BlockSparseLinear(64, 128, False).to("cuda")
+        model = torch.nn.DataParallel(linear)
+
+        input_tensor = torch.randn(64, 64).cuda()
+
+        output = model(input_tensor)
+
+
+if __name__ == '__main__':
+    unittest.main()
