Skip to content

Commit 8aeb057

Browse files
committed
Ensure that chosen dtype is enforced to the output of ifft in FFT operator
1 parent ef5ab88 commit 8aeb057

File tree

1 file changed

+9
-2
lines changed
  • pylops_distributed/signalprocessing

1 file changed

+9
-2
lines changed

pylops_distributed/signalprocessing/FFT.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class FFT(LinearOperator):
7979
"""
8080
def __init__(self, dims, dir=0, nfft=None, sampling=1.,
8181
real=False, fftshift=False, compute=(False, False),
82-
chunks=(None, None), todask=(None, None), dtype='complex128'):
82+
chunks=(None, None), todask=(None, None), dtype='float64'):
8383
if isinstance(dims, int):
8484
dims = (dims,)
8585
if dir > len(dims) - 1:
@@ -106,7 +106,12 @@ def __init__(self, dims, dir=0, nfft=None, sampling=1.,
106106
self.shape = (int(np.prod(dims) * (self.nfft // 2 + 1 if self.real
107107
else self.nfft) / self.dims[dir]),
108108
int(np.prod(dims)))
109+
# Find types to enforce to forward and adjoint outputs. This is
110+
# required as np.fft.fft always returns complex128 even if input is
111+
# float32 or less
109112
self.dtype = np.dtype(dtype)
113+
self.cdtype = (np.ones(1, dtype=self.dtype) +
114+
1j*np.ones(1, dtype=self.dtype)).dtype
110115
self.compute = compute
111116
self.chunks = chunks
112117
self.todask = todask
@@ -138,6 +143,7 @@ def _matvec(self, x):
138143
y = sqrt(1. / self.nfft) * da.fft.fft(x, n=self.nfft,
139144
axis=self.dir)
140145
y = y.ravel()
146+
y = y.astype(self.cdtype)
141147
return y
142148

143149
def _rmatvec(self, x):
@@ -169,4 +175,5 @@ def _rmatvec(self, x):
169175
if self.fftshift:
170176
y = da.fft.fftshift(y, axes=self.dir)
171177
y = y.ravel()
172-
return y
178+
y = y.astype(self.dtype)
179+
return y

0 commit comments

Comments
 (0)