Skip to content

Commit 539fe31

Browse files
Will FengJoelMarcey
Will Feng
authored andcommitted
Fix numpy_extensions_tutorial.py
1 parent bbc6235 commit 539fe31

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

advanced_source/numpy_extensions_tutorial.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def backward(ctx, grad_output):
103103
# the previous line can be expressed equivalently as:
104104
# grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
105105
grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
106-
return torch.as_tensor(grad_input, dtype=input.dtype), torch.as_tensor(grad_filter, dtype=filter.dtype), torch.as_tensor(grad_bias, dtype=bias.dtype)
106+
return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float)
107107

108108

109109
class ScipyConv2d(Module):

0 commit comments

Comments
 (0)