Skip to content

Commit 9815e32

Browse files
authored
Merge pull request #15 from mrava87/master
New implementation of MDC
2 parents 9475728 + 4a27da1 commit 9815e32

File tree

3 files changed

+112
-21
lines changed

3 files changed

+112
-21
lines changed

pylops_distributed/LinearOperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def matvec(self, x):
7676
An array with shape ``(N, )`` or ``(N, 1)``
7777
7878
"""
79-
if self.todask[0]:
79+
if self.todask[0] and not isinstance(x, da.core.Array):
8080
x = da.asarray(x)
8181
if self.Op is None:
8282
y = self._matvec(x)
@@ -105,7 +105,7 @@ def rmatvec(self, x):
105105
An array with shape ``(M, )`` or ``(M, 1)``
106106
107107
"""
108-
if self.todask[1]:
108+
if self.todask[1] and not isinstance(x, da.core.Array):
109109
x = da.asarray(x)
110110
if self.Op is None:
111111
y = self._rmatvec(x)

pylops_distributed/waveeqprocessing/marchenko.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def apply_onepoint(self, trav, dist=None, G0=None, nfft=None,
256256
f1_inv_minus = f1_inv_tot[:self.nt2].T
257257
f1_inv_plus = f1_inv_tot[self.nt2:].T
258258
if greens:
259-
g_inv = np.real(g_inv) # cast to real as Gop is a complex operator
260259
g_inv_minus, g_inv_plus = -g_inv[:self.nt2].T, \
261260
np.fliplr(g_inv[self.nt2:].T)
262261

@@ -432,7 +431,6 @@ def apply_multiplepoints(self, trav, dist=None, G0=None, nfft=None,
432431
f1_inv_minus = f1_inv_tot[:self.nt2].transpose(1, 2, 0)
433432
f1_inv_plus = f1_inv_tot[self.nt2:].transpose(1, 2, 0)
434433
if greens:
435-
g_inv = np.real(g_inv) # cast to real as Gop is a complex operator
436434
g_inv_minus = -g_inv[:self.nt2].transpose(1, 2, 0)
437435
g_inv_plus = np.flip(g_inv[self.nt2:], axis=0).transpose(1, 2, 0)
438436

pylops_distributed/waveeqprocessing/mdd.py

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import logging
22
import numpy as np
33
import dask.array as da
4+
from math import sqrt
45
from pylops.waveeqprocessing.mdd import _MDC
56

7+
from pylops_distributed import LinearOperator
68
from pylops_distributed.utils import dottest as Dottest
79
from pylops_distributed import Identity, Transpose
810
from pylops_distributed.signalprocessing import FFT, Fredholm1
@@ -11,9 +13,7 @@
1113
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)
1214

1315

14-
def MDC(G, nt, nv, dt=1., dr=1., twosided=True,
15-
saveGt=False, conj=False, prescaled=False,
16-
compute=(False, False), todask=(False, False)):
16+
class MDC(LinearOperator):
1717
r"""Multi-dimensional convolution.
1818
1919
Apply multi-dimensional convolution between two datasets.
@@ -56,28 +56,120 @@ def MDC(G, nt, nv, dt=1., dr=1., twosided=True,
5656
todask : :obj:`tuple`, optional
5757
Apply :func:`dask.array.from_array` to model and data before applying
5858
forward and adjoint respectively
59+
dtype : :obj:`str`, optional
60+
Type of elements in input array. If ``None``, automatically inferred
61+
from ``G``
5962
6063
Notes
6164
-----
6265
Refer to :class:`pylops.waveeqprocessing.MDC` for implementation
6366
details.
6467
6568
"""
66-
return _MDC(G, nt, nv, dt=dt, dr=dr, twosided=twosided,
67-
transpose=False, saveGt=saveGt, conj=conj, prescaled=prescaled,
68-
_Identity=Identity, _Transpose=Transpose,
69-
_FFT=FFT, _Fredholm1=Fredholm1,
70-
args_Fredholm1={'chunks': ((G.chunks[0], G.shape[2], nv),
71-
(G.chunks[0], G.shape[1], nv))},
72-
args_FFT={'chunks': ((nt, G.shape[2], nv),
73-
(nt, G.shape[2], nv)),
74-
'todask':(todask[0], False),
75-
'compute': (False, compute[1])},
76-
args_FFT1={'chunks': ((nt, G.shape[1], nv),
77-
(nt, G.shape[1], nv)),
78-
'todask': (todask[1], False),
79-
'compute':(False, compute[0])})
69+
def __init__(self, G, nt, nv, dt=1., dr=1., twosided=True,
70+
saveGt=False, conj=False, prescaled=False,
71+
chunks=(None, None), compute=(False, False),
72+
todask=(False, False), dtype=None):
8073

74+
if twosided and nt % 2 == 0:
75+
raise ValueError('nt must be odd number')
76+
77+
# store G
78+
self.G = G
79+
self.nfmax, self.ns, self.nr = self.G.shape
80+
self.saveGt = saveGt
81+
if self.saveGt:
82+
self.GT = (G.transpose((0, 2, 1)).conj()).persist()
83+
84+
# ensure that nfmax is not bigger than allowed
85+
self.nfft = int(np.ceil((nt + 1) / 2))
86+
if self.nfmax > self.nfft:
87+
self.nfmax = self.nfft
88+
logging.warning('nfmax set equal to ceil[(nt+1)/2=%d]' % self.nfmax)
89+
90+
# store other input parameters
91+
self.nt, self.nv = nt, nv
92+
self.dt, self.dr = dt, dr
93+
self.twosided = twosided
94+
self.conj = conj
95+
self.prescaled = prescaled
96+
self.dims = (self.nt, self.nr, self.nv)
97+
self.dimsd = (self.nt, self.ns, self.nv)
98+
self.dimsdf = (self.nfft, self.ns, self.nv)
99+
100+
# find out dtype of G
101+
self.cdtype = self.G[0, 0, 0].dtype
102+
if dtype is None:
103+
self.dtype = np.real(np.ones(1, dtype=self.cdtype)).dtype
104+
else:
105+
self.dtype = dtype
106+
107+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
108+
self.compute = compute
109+
self.chunks = chunks
110+
self.todask = todask
111+
self.Op = None
112+
self.explicit = False
113+
114+
def _matvec(self, x):
115+
# apply forward fft
116+
x = da.reshape(x, self.dims)
117+
if self.twosided:
118+
x = da.fft.ifftshift(x, axes=0)
119+
y = sqrt(1. / self.nt) * da.fft.rfft(x, n=self.nt, axis=0)
120+
y = y.astype(self.cdtype)
121+
y = y[:self.nfmax]
122+
123+
# apply batched matrix mult
124+
y = y.rechunk((self.G.chunks[0], self.nr, self.nv))
125+
if self.conj:
126+
y = y.conj()
127+
y = da.matmul(self.G, y)
128+
if self.conj:
129+
y = y.conj()
130+
if not self.prescaled:
131+
y *= self.dr * self.dt * np.sqrt(self.nt)
132+
133+
# apply inverse fft
134+
y = da.pad(y, ((0, self.nfft - self.nfmax), (0, 0), (0, 0)), mode='constant')
135+
y = y.rechunk(self.dimsdf)
136+
y = sqrt(self.nt) * da.fft.irfft(y, n=self.nt, axis=0)
137+
y = y.astype(self.dtype)
138+
y = da.real(y)
139+
return y.ravel()
140+
141+
def _rmatvec(self, x):
142+
# apply forward fft
143+
x = da.reshape(x, self.dimsd)
144+
y = sqrt(1. / self.nt) * da.fft.rfft(x, n=self.nt, axis=0)
145+
y = y.astype(self.cdtype)
146+
y = y[:self.nfmax]
147+
148+
# apply batched matrix mult
149+
y = y.rechunk((self.G.chunks[0], self.nr, self.nv))
150+
if self.saveGt:
151+
if self.conj:
152+
y = y.conj()
153+
y = da.matmul(self.GT, y)
154+
if self.conj:
155+
y = y.conj()
156+
else:
157+
if self.conj:
158+
y = da.matmul(y.transpose(0, 2, 1), self.G).transpose(0, 2, 1)
159+
else:
160+
y = da.matmul(y.transpose(0, 2, 1).conj(), self.G).transpose(0, 2, 1).conj()
161+
if not self.prescaled:
162+
y *= self.dr * self.dt * np.sqrt(self.nt)
163+
164+
# apply inverse fft
165+
y = da.pad(y, ((0, self.nfft - self.nfmax), (0, 0), (0, 0)), mode='constant')
166+
y = y.rechunk(self.dimsdf)
167+
y = sqrt(self.nt) * da.fft.irfft(y, n=self.nt, axis=0)
168+
if self.twosided:
169+
y = da.fft.fftshift(y, axes=0)
170+
y = y.astype(self.dtype)
171+
y = da.real(y)
172+
return y.ravel()
81173

82174

83175
def MDD(G, d, dt=0.004, dr=1., nfmax=None, wav=None,
@@ -178,6 +270,7 @@ def MDD(G, d, dt=0.004, dr=1., nfmax=None, wav=None,
178270
d = da.concatenate((da.zeros((nt - 1, ns)), d), axis=0)
179271
else:
180272
d = da.concatenate((da.zeros((nt - 1, ns, nv)), d), axis=0)
273+
d = d.rechunk(d.shape)
181274

182275
# Define MDC linear operator
183276
MDCop = MDC(G, nt2, nv=nv, dt=dt, dr=dr,

0 commit comments

Comments
 (0)