Skip to content

Commit ab62128

Browse files
lucasb-eyerchsasank
authored andcommitted
Use foo.new instead of torch.FloatTensor for GPU. (pytorch#211)
This replaces the calls to `torch.FloatTensor` by a call to `.new` on the input tensor, such that GPU types are respected.
1 parent 9c1bc69 commit ab62128

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

advanced_source/numpy_extensions_tutorial.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ class BadFFTFunction(Function):
3838
def forward(self, input):
3939
numpy_input = input.numpy()
4040
result = abs(rfft2(numpy_input))
41-
return torch.FloatTensor(result)
41+
return input.new(result)
4242

4343
def backward(self, grad_output):
4444
numpy_go = grad_output.numpy()
4545
result = irfft2(numpy_go)
46-
return torch.FloatTensor(result)
46+
return grad_output.new(result)
4747

4848
# since this layer does not have any parameters, we can
4949
# simply declare this as a function, rather than as an nn.Module class
@@ -90,7 +90,7 @@ class ScipyConv2dFunction(Function):
9090
def forward(ctx, input, filter):
9191
result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
9292
ctx.save_for_backward(input, filter)
93-
return torch.FloatTensor(result)
93+
return input.new(result)
9494

9595
@staticmethod
9696
def backward(ctx, grad_output):
@@ -99,8 +99,8 @@ def backward(ctx, grad_output):
9999
grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
100100
grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
101101

102-
return Variable(torch.FloatTensor(grad_input)), \
103-
Variable(torch.FloatTensor(grad_filter))
102+
return Variable(grad_output.new(grad_input)), \
103+
Variable(grad_output.new(grad_filter))
104104

105105

106106
class ScipyConv2d(Module):

0 commit comments

Comments
 (0)