Skip to content

Commit d129290

Browse files
authored
Merge pull request #9 from mrava87/master
Avoid repeating matvec/rmatvec in compuond operators (sum, prod, etc.)
2 parents 27ed544 + fd45f17 commit d129290

File tree

3 files changed

+65
-20
lines changed

3 files changed

+65
-20
lines changed

pylops_distributed/LinearOperator.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,12 @@ def adjoint(self):
201201
def _adjoint(self):
202202
"""Default implementation of _adjoint; defers to rmatvec."""
203203
shape = (self.shape[1], self.shape[0])
204-
return _CustomLinearOperator(shape, matvec=self.rmatvec,
205-
rmatvec=self.matvec,
204+
return _CustomLinearOperator(shape, matvec=self._rmatvec,
205+
rmatvec=self._matvec,
206+
compute=(self.compute[1],
207+
self.compute[0]),
208+
todask=(self.todask[1],
209+
self.todask[0]),
206210
dtype=self.dtype)
207211

208212
def div1(self, y, niter=100):
@@ -286,7 +290,7 @@ def _adjoint(self):
286290
matvec=self.__rmatvec_impl,
287291
rmatvec=self.__matvec_impl,
288292
dtype=self.dtype,
289-
compute=self.compute,
293+
compute=(self.compute[1], self.compute[0]),
290294
todask=(self.todask[1], self.todask[0]))
291295

292296

@@ -319,13 +323,13 @@ def __init__(self, A, B):
319323
self.args = (Ac, Bc)
320324

321325
def _matvec(self, x):
322-
return self.args[0].matvec(x) + self.args[1].matvec(x)
326+
return self.args[0]._matvec(x) + self.args[1]._matvec(x)
323327

324328
def _rmatvec(self, x):
325-
return self.args[0].rmatvec(x) + self.args[1].rmatvec(x)
329+
return self.args[0]._rmatvec(x) + self.args[1]._rmatvec(x)
326330

327331
def _matmat(self, x):
328-
return self.args[0].matmat(x) + self.args[1].matmat(x)
332+
return self.args[0]._matmat(x) + self.args[1]._matmat(x)
329333

330334
def _adjoint(self):
331335
A, B = self.args
@@ -345,8 +349,8 @@ def __init__(self, A, B):
345349
dtype=A.dtype, Op=None,
346350
explicit=A.explicit and
347351
B.explicit,
348-
compute=(B.compute[0],
349-
A.compute[1]),
352+
compute=(A.compute[0],
353+
B.compute[1]),
350354
todask=(B.todask[0],
351355
A.todask[1]))
352356
# Force compute and todask not to be applied to individual operators
@@ -359,13 +363,13 @@ def __init__(self, A, B):
359363
self.args = (Ac, Bc)
360364

361365
def _matvec(self, x):
362-
return self.args[0].matvec(self.args[1].matvec(x))
366+
return self.args[0]._matvec(self.args[1]._matvec(x))
363367

364368
def _rmatvec(self, x):
365-
return self.args[1].rmatvec(self.args[0].rmatvec(x))
369+
return self.args[1]._rmatvec(self.args[0]._rmatvec(x))
366370

367371
def _matmat(self, x):
368-
return self.args[0].matmat(self.args[1].matmat(x))
372+
return self.args[0]._matmat(self.args[1]._matmat(x))
369373

370374
def _adjoint(self):
371375
A, B = self.args
@@ -390,13 +394,13 @@ def __init__(self, A, alpha):
390394
self.args = (Ac, alpha)
391395

392396
def _matvec(self, x):
393-
return self.args[1] * self.args[0].matvec(x)
397+
return self.args[1] * self.args[0]._matvec(x)
394398

395399
def _rmatvec(self, x):
396-
return np.conj(self.args[1]) * self.args[0].rmatvec(x)
400+
return np.conj(self.args[1]) * self.args[0]._rmatvec(x)
397401

398402
def _matmat(self, x):
399-
return self.args[1] * self.args[0].matmat(x)
403+
return self.args[1] * self.args[0]._matmat(x)
400404

401405
def _adjoint(self):
402406
A, alpha = self.args
@@ -428,13 +432,13 @@ def _power(self, fun, x, compute):
428432
return res
429433

430434
def _matvec(self, x):
431-
return self._power(self.args[0].matvec, x, self.compute[0])
435+
return self._power(self.args[0]._matvec, x, self.compute[0])
432436

433437
def _rmatvec(self, x):
434-
return self._power(self.args[0].rmatvec, x, self.compute[1])
438+
return self._power(self.args[0]._rmatvec, x, self.compute[1])
435439

436440
def _matmat(self, x):
437-
return self._power(self.args[0].matmat, x, self.compute[0])
441+
return self._power(self.args[0]._matmat, x, self.compute[0])
438442

439443
def _adjoint(self):
440444
A, p = self.args
@@ -489,4 +493,4 @@ def aslinearoperator(Op):
489493
return Op
490494
else:
491495
return LinearOperator(Op.shape, Op.dtype, Op, explicit=Op.explicit,
492-
compute=Op.compute, todask=Op.todask)
496+
compute=Op.compute, todask=Op.todask)

pylops_distributed/waveeqprocessing/mdd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,4 @@ def MDC(G, nt, nv, dt=1., dr=1., twosided=True,
7474
args_FFT1={'chunks': ((nt, G.shape[1], nv),
7575
(nt, G.shape[1], nv)),
7676
'todask': (todask[1], False),
77-
'compute':(compute[0], False)})
78-
77+
'compute':(False, compute[0])})

pytests/test_waveeqprocessing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,48 @@
6060
par8['nfmax'] = int(np.ceil((PAR['nt']+1.)/2))-30
6161

6262

63+
@pytest.mark.parametrize("par", [(par1)])
64+
def test_MDC_compute(par):
65+
"""Ensure that forward and adjoint of MDC return numpy array when
66+
compute=True
67+
"""
68+
par['nt2'] = par['nt']
69+
v = 1500
70+
it0_m = 25
71+
t0_m = it0_m * par['dt']
72+
theta_m = 0
73+
amp_m = 1.
74+
75+
it0_G = np.array([25, 50, 75])
76+
t0_G = it0_G * par['dt']
77+
theta_G = (0, 0, 0)
78+
phi_G = (0, 0, 0)
79+
amp_G = (1., 0.6, 2.)
80+
81+
# Create axis
82+
t, _, x, y = makeaxis(par)
83+
84+
# Create wavelet
85+
wav = ricker(t[:41], f0=par['f0'])[0]
86+
87+
# Generate model
88+
_, mwav = linear2d(x, t, v, t0_m, theta_m, amp_m, wav)
89+
# Generate operator
90+
_, Gwav = linear3d(x, y, t, v, t0_G, theta_G, phi_G, amp_G, wav)
91+
92+
# Define MDC linear operator
93+
Gwav_fft = np.fft.fft(Gwav, par['nt2'], axis=-1)
94+
Gwav_fft = Gwav_fft[..., :par['nfmax']]
95+
96+
dMDCop = dMDC(da.from_array(Gwav_fft.transpose(2, 0, 1)),
97+
nt=par['nt2'], nv=1, dt=par['dt'], dr=par['dx'],
98+
twosided=par['twosided'], todask=(True, True),
99+
compute=(True, True))
100+
101+
assert isinstance(dMDCop.matvec(np.ones(dMDCop.shape[1])), np.ndarray)
102+
assert isinstance(dMDCop.rmatvec(np.ones(dMDCop.shape[0])), np.ndarray)
103+
104+
63105
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4),
64106
(par5), (par6), (par7), (par8)])
65107
def test_MDC_1virtualsource(par):

0 commit comments

Comments
 (0)