Skip to content

Commit ef5ab88

Browse files
committed
Ensure that dtype chosen dtype is enforced to all variables created in Marchenko class
1 parent 6fc2e33 commit ef5ab88

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

pylops_distributed/waveeqprocessing/marchenko.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,14 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
160160
# Create window
161161
trav_off = trav - self.toff
162162
trav_off = np.round(trav_off / self.dt).astype(np.int)
163-
w = np.zeros((self.nr, self.nt))
163+
w = np.zeros((self.nr, self.nt), dtype=self.dtype)
164164
for ir in range(self.nr):
165165
w[ir, :trav_off[ir]] = 1
166166
w = np.hstack((np.fliplr(w), w[:, 1:]))
167167
if self.nsmooth > 0:
168168
smooth = np.ones(self.nsmooth) / self.nsmooth
169169
w = filtfilt(smooth, 1, w)
170+
w = w.astype(self.dtype)
170171

171172
# Create operators
172173
Rop = MDC(self.Rtwosided_fft, self.nt2, nv=1, dt=self.dt, dr=self.dr,
@@ -202,16 +203,19 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
202203
if G0 is None:
203204
if self.wav is not None and nfft is not None:
204205
G0 = (directwave(self.wav, trav, self.nt,
205-
self.dt, nfft=nfft, dist=dist,
206+
self.dt, nfft=nfft,
207+
derivative=True, dist=dist,
206208
kind='2d' if dist is None else '3d')).T
207209
else:
208210
logging.error('wav and/or nfft are not provided. '
209211
'Provide either G0 or wav and nfft...')
210212
raise ValueError('wav and/or nfft are not provided. '
211213
'Provide either G0 or wav and nfft...')
214+
G0 = G0.astype(self.dtype)
212215

213216
fd_plus = np.concatenate((np.fliplr(G0).T,
214-
np.zeros((self.nt - 1, self.nr))))
217+
np.zeros((self.nt - 1, self.nr),
218+
dtype=self.dtype)))
215219
fd_plus = da.from_array(fd_plus)
216220

217221
# Run standard redatuming as benchmark
@@ -222,12 +226,14 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
222226
# Create data and inverse focusing functions
223227
d = Wop * Rop * fd_plus.flatten()
224228
d = da.concatenate((d.reshape(self.nt2, self.ns),
225-
da.zeros((self.nt2, self.ns))))
229+
da.zeros((self.nt2, self.ns),
230+
dtype = self.dtype)))
226231

227232
# Invert for focusing functions
228233
f1_inv = cgls(Mop, d.flatten(), **kwargs_cgls)[0]
229234
f1_inv = f1_inv.reshape(2 * self.nt2, self.nr)
230-
f1_inv_tot = f1_inv + da.concatenate((da.zeros((self.nt2, self.nr)),
235+
f1_inv_tot = f1_inv + da.concatenate((da.zeros((self.nt2, self.nr),
236+
dtype=self.dtype),
231237
fd_plus))
232238
# Create Green's functions
233239
if greens:
@@ -325,14 +331,15 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
325331
trav_off = trav - self.toff
326332
trav_off = np.round(trav_off / self.dt).astype(np.int)
327333

328-
w = np.zeros((self.nr, nvs, self.nt))
334+
w = np.zeros((self.nr, nvs, self.nt), dtype=self.dtype)
329335
for ir in range(self.nr):
330336
for ivs in range(nvs):
331337
w[ir, ivs, :trav_off[ir, ivs]] = 1
332338
w = np.concatenate((np.flip(w, axis=-1), w[:, :, 1:]), axis=-1)
333339
if self.nsmooth > 0:
334340
smooth = np.ones(self.nsmooth) / self.nsmooth
335341
w = filtfilt(smooth, 1, w)
342+
w = w.astype(self.dtype)
336343

337344
# Create operators
338345
Rop = MDC(self.Rtwosided_fft, self.nt2, nv=nvs, dt=self.dt,
@@ -367,20 +374,22 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
367374
# Create input focusing function
368375
if G0 is None:
369376
if self.wav is not None and nfft is not None:
370-
G0 = np.zeros((self.nr, nvs, self.nt))
377+
G0 = np.zeros((self.nr, nvs, self.nt), dtype=self.dtype)
371378
for ivs in range(nvs):
372379
G0[:, ivs] = (directwave(self.wav, trav[:, ivs],
373-
self.nt, self.dt, nfft=nfft)).T
374-
# dist=dist,
375-
# kind='2d' if dist is None else '3d')).T
380+
self.nt, self.dt, nfft=nfft,
381+
derivative=True, dist=dist,
382+
kind='2d' if dist is None else '3d')).T
376383
else:
377384
logging.error('wav and/or nfft are not provided. '
378385
'Provide either G0 or wav and nfft...')
379386
raise ValueError('wav and/or nfft are not provided. '
380387
'Provide either G0 or wav and nfft...')
388+
G0 = G0.astype(self.dtype)
381389

382390
fd_plus = np.concatenate((np.flip(G0, axis=-1).transpose(2, 0, 1),
383-
np.zeros((self.nt - 1, self.nr, nvs))))
391+
np.zeros((self.nt - 1, self.nr, nvs),
392+
dtype=self.dtype)))
384393
fd_plus = da.from_array(fd_plus).rechunk(fd_plus.shape)
385394

386395
# Run standard redatuming as benchmark
@@ -392,14 +401,15 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
392401
# Create data and inverse focusing functions
393402
d = Wop * Rop * fd_plus.flatten()
394403
d = da.concatenate((d.reshape(self.nt2, self.ns, nvs),
395-
da.zeros((self.nt2, self.ns, nvs))))
404+
da.zeros((self.nt2, self.ns, nvs),
405+
dtype=self.dtype)))
396406

397407
# Invert for focusing functions
398408
f1_inv = cgls(Mop, d.flatten(), **kwargs_cgls)[0]
399409
f1_inv = f1_inv.reshape(2 * self.nt2, self.nr, nvs)
400410
f1_inv_tot = \
401-
f1_inv + da.concatenate((np.zeros((self.nt2, self.nr, nvs)),
402-
fd_plus))
411+
f1_inv + da.concatenate((da.zeros((self.nt2, self.nr, nvs),
412+
dtype=self.dtype), fd_plus))
403413
if greens:
404414
# Create Green's functions
405415
g_inv = Gop * f1_inv_tot.flatten()

0 commit comments

Comments
 (0)