Skip to content

Commit 83579cb

Browse files
committed
Fixed bug in block sparse linear backward function.
1 parent 26e708e commit 83579cb

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pytorch_block_sparse/block_sparse_linear.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ def backward(ctx, grad_output):
110110
else:
111111
grad_weight1 = None
112112

113-
assert(not (grad_weight1 == 0).all())
114-
assert(grad_input1.shape == input.shape)
113+
if grad_weight1 != None:
114+
assert(not (grad_weight1 == 0).all())
115+
if grad_input1 != None:
116+
assert(grad_input1.shape == input.shape)
117+
115118
return grad_input1, grad_weight1, None
116119

117120
class BlockSparseLinear(nn.Module):

0 commit comments

Comments
 (0)