Skip to content

Commit 1e60d7f

Browse files
authored
Merge pull request #11 from mrava87/master
Ensure correct dtype in FFT and Marchenko operators
2 parents 270bfa2 + f51e0ba commit 1e60d7f

File tree

3 files changed

+57
-33
lines changed

3 files changed

+57
-33
lines changed

pylops_distributed/signalprocessing/FFT.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class FFT(LinearOperator):
7979
"""
8080
def __init__(self, dims, dir=0, nfft=None, sampling=1.,
8181
real=False, fftshift=False, compute=(False, False),
82-
chunks=(None, None), todask=(None, None), dtype='complex128'):
82+
chunks=(None, None), todask=(None, None), dtype='float64'):
8383
if isinstance(dims, int):
8484
dims = (dims,)
8585
if dir > len(dims) - 1:
@@ -106,7 +106,12 @@ def __init__(self, dims, dir=0, nfft=None, sampling=1.,
106106
self.shape = (int(np.prod(dims) * (self.nfft // 2 + 1 if self.real
107107
else self.nfft) / self.dims[dir]),
108108
int(np.prod(dims)))
109+
# Find types to enforce to forward and adjoint outputs. This is
110+
# required as np.fft.fft always returns complex128 even if input is
111+
# float32 or less
109112
self.dtype = np.dtype(dtype)
113+
self.cdtype = (np.ones(1, dtype=self.dtype) +
114+
1j*np.ones(1, dtype=self.dtype)).dtype
110115
self.compute = compute
111116
self.chunks = chunks
112117
self.todask = todask
@@ -138,6 +143,7 @@ def _matvec(self, x):
138143
y = sqrt(1. / self.nfft) * da.fft.fft(x, n=self.nfft,
139144
axis=self.dir)
140145
y = y.ravel()
146+
y = y.astype(self.cdtype)
141147
return y
142148

143149
def _rmatvec(self, x):
@@ -169,4 +175,5 @@ def _rmatvec(self, x):
169175
if self.fftshift:
170176
y = da.fft.fftshift(y, axes=self.dir)
171177
y = y.ravel()
172-
return y
178+
y = y.astype(self.dtype)
179+
return y

pylops_distributed/waveeqprocessing/marchenko.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,14 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
160160
# Create window
161161
trav_off = trav - self.toff
162162
trav_off = np.round(trav_off / self.dt).astype(np.int)
163-
w = np.zeros((self.nr, self.nt))
163+
w = np.zeros((self.nr, self.nt), dtype=self.dtype)
164164
for ir in range(self.nr):
165165
w[ir, :trav_off[ir]] = 1
166166
w = np.hstack((np.fliplr(w), w[:, 1:]))
167167
if self.nsmooth > 0:
168168
smooth = np.ones(self.nsmooth) / self.nsmooth
169169
w = filtfilt(smooth, 1, w)
170+
w = w.astype(self.dtype)
170171

171172
# Create operators
172173
Rop = MDC(self.Rtwosided_fft, self.nt2, nv=1, dt=self.dt, dr=self.dr,
@@ -202,16 +203,19 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
202203
if G0 is None:
203204
if self.wav is not None and nfft is not None:
204205
G0 = (directwave(self.wav, trav, self.nt,
205-
self.dt, nfft=nfft, dist=dist,
206+
self.dt, nfft=nfft,
207+
derivative=True, dist=dist,
206208
kind='2d' if dist is None else '3d')).T
207209
else:
208210
logging.error('wav and/or nfft are not provided. '
209211
'Provide either G0 or wav and nfft...')
210212
raise ValueError('wav and/or nfft are not provided. '
211213
'Provide either G0 or wav and nfft...')
214+
G0 = G0.astype(self.dtype)
212215

213216
fd_plus = np.concatenate((np.fliplr(G0).T,
214-
np.zeros((self.nt - 1, self.nr))))
217+
np.zeros((self.nt - 1, self.nr),
218+
dtype=self.dtype)))
215219
fd_plus = da.from_array(fd_plus)
216220

217221
# Run standard redatuming as benchmark
@@ -222,12 +226,14 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
222226
# Create data and inverse focusing functions
223227
d = Wop * Rop * fd_plus.flatten()
224228
d = da.concatenate((d.reshape(self.nt2, self.ns),
225-
da.zeros((self.nt2, self.ns))))
229+
da.zeros((self.nt2, self.ns),
230+
dtype = self.dtype)))
226231

227232
# Invert for focusing functions
228233
f1_inv = cgls(Mop, d.flatten(), **kwargs_cgls)[0]
229234
f1_inv = f1_inv.reshape(2 * self.nt2, self.nr)
230-
f1_inv_tot = f1_inv + da.concatenate((da.zeros((self.nt2, self.nr)),
235+
f1_inv_tot = f1_inv + da.concatenate((da.zeros((self.nt2, self.nr),
236+
dtype=self.dtype),
231237
fd_plus))
232238
# Create Green's functions
233239
if greens:
@@ -325,14 +331,15 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
325331
trav_off = trav - self.toff
326332
trav_off = np.round(trav_off / self.dt).astype(np.int)
327333

328-
w = np.zeros((self.nr, nvs, self.nt))
334+
w = np.zeros((self.nr, nvs, self.nt), dtype=self.dtype)
329335
for ir in range(self.nr):
330336
for ivs in range(nvs):
331337
w[ir, ivs, :trav_off[ir, ivs]] = 1
332338
w = np.concatenate((np.flip(w, axis=-1), w[:, :, 1:]), axis=-1)
333339
if self.nsmooth > 0:
334340
smooth = np.ones(self.nsmooth) / self.nsmooth
335341
w = filtfilt(smooth, 1, w)
342+
w = w.astype(self.dtype)
336343

337344
# Create operators
338345
Rop = MDC(self.Rtwosided_fft, self.nt2, nv=nvs, dt=self.dt,
@@ -367,20 +374,22 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
367374
# Create input focusing function
368375
if G0 is None:
369376
if self.wav is not None and nfft is not None:
370-
G0 = np.zeros((self.nr, nvs, self.nt))
377+
G0 = np.zeros((self.nr, nvs, self.nt), dtype=self.dtype)
371378
for ivs in range(nvs):
372379
G0[:, ivs] = (directwave(self.wav, trav[:, ivs],
373-
self.nt, self.dt, nfft=nfft)).T
374-
# dist=dist,
375-
# kind='2d' if dist is None else '3d')).T
380+
self.nt, self.dt, nfft=nfft,
381+
derivative=True, dist=dist,
382+
kind='2d' if dist is None else '3d')).T
376383
else:
377384
logging.error('wav and/or nfft are not provided. '
378385
'Provide either G0 or wav and nfft...')
379386
raise ValueError('wav and/or nfft are not provided. '
380387
'Provide either G0 or wav and nfft...')
388+
G0 = G0.astype(self.dtype)
381389

382390
fd_plus = np.concatenate((np.flip(G0, axis=-1).transpose(2, 0, 1),
383-
np.zeros((self.nt - 1, self.nr, nvs))))
391+
np.zeros((self.nt - 1, self.nr, nvs),
392+
dtype=self.dtype)))
384393
fd_plus = da.from_array(fd_plus).rechunk(fd_plus.shape)
385394

386395
# Run standard redatuming as benchmark
@@ -392,14 +401,15 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
392401
# Create data and inverse focusing functions
393402
d = Wop * Rop * fd_plus.flatten()
394403
d = da.concatenate((d.reshape(self.nt2, self.ns, nvs),
395-
da.zeros((self.nt2, self.ns, nvs))))
404+
da.zeros((self.nt2, self.ns, nvs),
405+
dtype=self.dtype)))
396406

397407
# Invert for focusing functions
398408
f1_inv = cgls(Mop, d.flatten(), **kwargs_cgls)[0]
399409
f1_inv = f1_inv.reshape(2 * self.nt2, self.nr, nvs)
400410
f1_inv_tot = \
401-
f1_inv + da.concatenate((np.zeros((self.nt2, self.nr, nvs)),
402-
fd_plus))
411+
f1_inv + da.concatenate((da.zeros((self.nt2, self.nr, nvs),
412+
dtype=self.dtype), fd_plus))
403413
if greens:
404414
# Create Green's functions
405415
g_inv = Gop * f1_inv_tot.flatten()

pytests/test_ffts.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010

1111
par1 = {'nt': 101, 'nx': 31, 'ny': 10,
1212
'nfft': None, 'real': False,
13-
'ffthshift': False} # nfft=nt, complex input
13+
'ffthshift': False, 'dtype':np.complex128} # nfft=nt, complex input
1414
par2 = {'nt': 101, 'nx': 31, 'ny': 10,
1515
'nfft': 256, 'real': False,
16-
'ffthshift': False} # nfft>nt, complex input
16+
'ffthshift': False, 'dtype':np.complex64} # nfft>nt, complex input
1717
par3 = {'nt': 101, 'nx': 31, 'ny': 10,
1818
'nfft': None, 'real': True,
19-
'ffthshift': False} # nfft=nt, real input
19+
'ffthshift': False, 'dtype':np.float64} # nfft=nt, real input
2020
par4 = {'nt': 101, 'nx': 31, 'ny': 10,
2121
'nfft': 256, 'real': True,
22-
'ffthshift': False} # nfft>nt, real input
22+
'ffthshift': False, 'dtype':np.float32} # nfft>nt, real input
2323
par5 = {'nt': 101, 'nx': 31, 'ny': 10,
2424
'nfft': 256, 'real': True,
25-
'ffthshift': True} # nfft>nt, real input and fftshift
25+
'ffthshift': True, 'dtype':np.float32} # nfft>nt, real input and fftshift
2626

2727
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5)])
2828
def test_FFT_1dsignal(par):
@@ -33,8 +33,9 @@ def test_FFT_1dsignal(par):
3333
x = da.from_array(np.sin(2 * np.pi * f0 * t))
3434
nfft = par['nt'] if par['nfft'] is None else par['nfft']
3535
dFFTop = dFFT(dims=[par['nt']], nfft=nfft, sampling=dt, real=par['real'],
36-
chunks = (par['nt'], nfft))
37-
FFTop = FFT(dims=[par['nt']], nfft=nfft, sampling=dt, real=par['real'])
36+
chunks=(par['nt'], nfft), dtype=par['dtype'])
37+
FFTop = FFT(dims=[par['nt']], nfft=nfft, sampling=dt, real=par['real'],
38+
dtype=par['dtype'])
3839

3940
# FFT with real=True cannot pass dot-test neither be inverted correctly,
4041
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT
@@ -71,8 +72,9 @@ def test_FFT_2dsignal(par):
7172
nfft = par['nt'] if par['nfft'] is None else par['nfft']
7273
dFFTop = dFFT(dims=(nt, nx), dir=0, nfft=nfft,
7374
sampling=dt, real=par['real'],
74-
chunks=((nt, nx), (nfft, nx)))
75-
FFTop = FFT(dims=(nt, nx), dir=0, nfft=nfft, sampling=dt)
75+
chunks=((nt, nx), (nfft, nx)), dtype=par['dtype'])
76+
FFTop = FFT(dims=(nt, nx), dir=0, nfft=nfft, sampling=dt,
77+
dtype=par['dtype'])
7678

7779
# FFT with real=True cannot pass dot-test neither be inverted correctly,
7880
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT
@@ -96,8 +98,10 @@ def test_FFT_2dsignal(par):
9698
# 2nd dimension
9799
nfft = par['nx'] if par['nfft'] is None else par['nfft']
98100
dFFTop = dFFT(dims=(nt, nx), dir=1, nfft=nfft, sampling=dt,
99-
real=par['real'], chunks=((nt, nx), (nt, nfft)))
100-
FFTop = FFT(dims=(nt, nx), dir=1, nfft=nfft, sampling=dt)
101+
real=par['real'], chunks=((nt, nx), (nt, nfft)),
102+
dtype=par['dtype'])
103+
FFTop = FFT(dims=(nt, nx), dir=1, nfft=nfft, sampling=dt,
104+
dtype=par['dtype'])
101105

102106
# FFT with real=True cannot pass dot-test neither be inverted correctly,
103107
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT
@@ -134,9 +138,10 @@ def test_FFT_3dsignal(par):
134138
# 1st dimension
135139
nfft = par['nt'] if par['nfft'] is None else par['nfft']
136140
dFFTop = dFFT(dims=(nt, nx, ny), dir=0, nfft=nfft, sampling=dt,
137-
real=par['real'], chunks=((nt, nx, ny), (nfft, nx, ny)))
141+
real=par['real'], chunks=((nt, nx, ny), (nfft, nx, ny)),
142+
dtype=par['dtype'])
138143
FFTop = FFT(dims=(nt, nx, ny), dir=0, nfft=nfft, sampling=dt,
139-
real=par['real'])
144+
real=par['real'], dtype=par['dtype'])
140145

141146
# FFT with real=True cannot pass dot-test neither be inverted correctly,
142147
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT
@@ -160,9 +165,10 @@ def test_FFT_3dsignal(par):
160165
# 2nd dimension
161166
nfft = par['nx'] if par['nfft'] is None else par['nfft']
162167
dFFTop = dFFT(dims=(nt, nx, ny), dir=1, nfft=nfft, sampling=dt,
163-
real=par['real'], chunks=((nt, nx, ny), (nt, nfft, ny)))
168+
real=par['real'], chunks=((nt, nx, ny), (nt, nfft, ny)),
169+
dtype=par['dtype'])
164170
FFTop = FFT(dims=(nt, nx, ny), dir=1, nfft=nfft, sampling=dt,
165-
real=par['real'])
171+
real=par['real'], dtype=par['dtype'])
166172

167173
# FFT with real=True cannot pass dot-test neither be inverted correctly,
168174
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT
@@ -186,9 +192,10 @@ def test_FFT_3dsignal(par):
186192
# 3rd dimension
187193
nfft = par['ny'] if par['nfft'] is None else par['nfft']
188194
dFFTop = dFFT(dims=(nt, nx, ny), dir=2, nfft=nfft, sampling=dt,
189-
real=par['real'], chunks=((nt, nx, ny), (nt, ny, nfft)))
195+
real=par['real'], chunks=((nt, nx, ny), (nt, ny, nfft)),
196+
dtype=par['dtype'])
190197
FFTop = FFT(dims=(nt, nx, ny), dir=2, nfft=nfft, sampling=dt,
191-
real=par['real'])
198+
real=par['real'], dtype=par['dtype'])
192199

193200
# FFT with real=True cannot pass dot-test neither be inverted correctly,
194201
# see FFT documentation for a detailed explanation. We thus test FFT.H*FFT

0 commit comments

Comments
 (0)