@@ -79,7 +79,7 @@ class FFT(LinearOperator):
7979 """
8080 def __init__ (self , dims , dir = 0 , nfft = None , sampling = 1. ,
8181 real = False , fftshift = False , compute = (False , False ),
82- chunks = (None , None ), todask = (None , None ), dtype = 'complex128 ' ):
82+ chunks = (None , None ), todask = (None , None ), dtype = 'float64 ' ):
8383 if isinstance (dims , int ):
8484 dims = (dims ,)
8585 if dir > len (dims ) - 1 :
@@ -106,7 +106,12 @@ def __init__(self, dims, dir=0, nfft=None, sampling=1.,
106106 self .shape = (int (np .prod (dims ) * (self .nfft // 2 + 1 if self .real
107107 else self .nfft ) / self .dims [dir ]),
108108 int (np .prod (dims )))
109+ # Find types to enforce to forward and adjoint outputs. This is
110+ # required as np.fft.fft always returns complex128 even if input is
111+ # float32 or less
109112 self .dtype = np .dtype (dtype )
113+ self .cdtype = (np .ones (1 , dtype = self .dtype ) +
114+ 1j * np .ones (1 , dtype = self .dtype )).dtype
110115 self .compute = compute
111116 self .chunks = chunks
112117 self .todask = todask
@@ -138,6 +143,7 @@ def _matvec(self, x):
138143 y = sqrt (1. / self .nfft ) * da .fft .fft (x , n = self .nfft ,
139144 axis = self .dir )
140145 y = y .ravel ()
146+ y = y .astype (self .cdtype )
141147 return y
142148
143149 def _rmatvec (self , x ):
@@ -169,4 +175,5 @@ def _rmatvec(self, x):
169175 if self .fftshift :
170176 y = da .fft .fftshift (y , axes = self .dir )
171177 y = y .ravel ()
172- return y
178+ y = y .astype (self .dtype )
179+ return y
0 commit comments