@@ -38,12 +38,12 @@ class BadFFTFunction(Function):
     def forward(self, input):
         numpy_input = input.numpy()
         result = abs(rfft2(numpy_input))
-        return torch.FloatTensor(result)
+        return input.new(result)

     def backward(self, grad_output):
         numpy_go = grad_output.numpy()
         result = irfft2(numpy_go)
-        return torch.FloatTensor(result)
+        return grad_output.new(result)

 # since this layer does not have any parameters, we can
 # simply declare this as a function, rather than as an nn.Module class
@@ -90,7 +90,7 @@ class ScipyConv2dFunction(Function):
     def forward(ctx, input, filter):
         result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
         ctx.save_for_backward(input, filter)
-        return torch.FloatTensor(result)
+        return input.new(result)

     @staticmethod
     def backward(ctx, grad_output):
@@ -99,8 +99,8 @@ def backward(ctx, grad_output):
         grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
         grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')

-        return Variable(torch.FloatTensor(grad_input)), \
-            Variable(torch.FloatTensor(grad_filter))
+        return Variable(grad_output.new(grad_input)), \
+            Variable(grad_output.new(grad_filter))


 class ScipyConv2d(Module):
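
Every hunk applies the same fix: instead of hard-coding torch.FloatTensor(result), the output is built with the input tensor's .new(result), so the returned tensor keeps the same type as the tensor it was computed from (a DoubleTensor input yields a DoubleTensor output rather than being silently downcast to float). Below is a minimal sketch of the pattern on a toy operation, assuming the legacy (pre-0.4) Function/Variable API that this diff targets; ScaleByTwo is a hypothetical example, not part of the change:

import torch
from torch.autograd import Function, Variable

class ScaleByTwo(Function):
    # toy op, used only to illustrate the .new(...) pattern above
    def forward(self, input):
        result = input.numpy() * 2        # leave PyTorch, compute in numpy
        return input.new(result)          # back to a tensor of input's type

    def backward(self, grad_output):
        result = grad_output.numpy() * 2  # d(2x)/dx = 2
        return grad_output.new(result)    # same type as the incoming grad

x = Variable(torch.randn(3), requires_grad=True)
y = ScaleByTwo()(x)
y.sum().backward()
print(x.grad)  # gradient of sum(2*x) is 2 everywhere

Like the torch.FloatTensor(result) call it replaces, .new(result) still copies the numpy data into a fresh tensor; only the hard-coded target type goes away.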