|
1 | 1 | import logging |
2 | 2 | import numpy as np |
3 | 3 | import dask.array as da |
| 4 | +from math import sqrt |
4 | 5 | from pylops.waveeqprocessing.mdd import _MDC |
5 | 6 |
|
| 7 | +from pylops_distributed import LinearOperator |
6 | 8 | from pylops_distributed.utils import dottest as Dottest |
7 | 9 | from pylops_distributed import Identity, Transpose |
8 | 10 | from pylops_distributed.signalprocessing import FFT, Fredholm1 |
|
11 | 13 | logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING) |
12 | 14 |
|
13 | 15 |
|
14 | | -def MDC(G, nt, nv, dt=1., dr=1., twosided=True, |
15 | | - saveGt=False, conj=False, prescaled=False, |
16 | | - compute=(False, False), todask=(False, False)): |
| 16 | +class MDC(LinearOperator): |
17 | 17 | r"""Multi-dimensional convolution. |
18 | 18 |
|
19 | 19 | Apply multi-dimensional convolution between two datasets. |
@@ -56,28 +56,120 @@ def MDC(G, nt, nv, dt=1., dr=1., twosided=True, |
56 | 56 | todask : :obj:`tuple`, optional |
57 | 57 | Apply :func:`dask.array.from_array` to model and data before applying |
58 | 58 | forward and adjoint respectively |
| 59 | + dtype : :obj:`str`, optional |
| 60 | + Type of elements in input array. If ``None``, automatically inferred |
| 61 | + from ``G`` |
59 | 62 |
|
60 | 63 | Notes |
61 | 64 | ----- |
62 | 65 | Refer to :class:`pylops.waveeqprocessing.MDC` for implementation |
63 | 66 | details. |
64 | 67 |
|
65 | 68 | """ |
66 | | - return _MDC(G, nt, nv, dt=dt, dr=dr, twosided=twosided, |
67 | | - transpose=False, saveGt=saveGt, conj=conj, prescaled=prescaled, |
68 | | - _Identity=Identity, _Transpose=Transpose, |
69 | | - _FFT=FFT, _Fredholm1=Fredholm1, |
70 | | - args_Fredholm1={'chunks': ((G.chunks[0], G.shape[2], nv), |
71 | | - (G.chunks[0], G.shape[1], nv))}, |
72 | | - args_FFT={'chunks': ((nt, G.shape[2], nv), |
73 | | - (nt, G.shape[2], nv)), |
74 | | - 'todask':(todask[0], False), |
75 | | - 'compute': (False, compute[1])}, |
76 | | - args_FFT1={'chunks': ((nt, G.shape[1], nv), |
77 | | - (nt, G.shape[1], nv)), |
78 | | - 'todask': (todask[1], False), |
79 | | - 'compute':(False, compute[0])}) |
| 69 | + def __init__(self, G, nt, nv, dt=1., dr=1., twosided=True, |
| 70 | + saveGt=False, conj=False, prescaled=False, |
| 71 | + chunks=(None, None), compute=(False, False), |
| 72 | + todask=(False, False), dtype=None): |
80 | 73 |
|
| 74 | + if twosided and nt % 2 == 0: |
| 75 | + raise ValueError('nt must be odd number') |
| 76 | + |
| 77 | + # store G |
| 78 | + self.G = G |
| 79 | + self.nfmax, self.ns, self.nr = self.G.shape |
| 80 | + self.saveGt = saveGt |
| 81 | + if self.saveGt: |
| 82 | + self.GT = (G.transpose((0, 2, 1)).conj()).persist() |
| 83 | + |
| 84 | + # ensure that nfmax is not bigger than allowed |
| 85 | + self.nfft = int(np.ceil((nt + 1) / 2)) |
| 86 | + if self.nfmax > self.nfft: |
| 87 | + self.nfmax = self.nfft |
| 88 | + logging.warning('nfmax set equal to ceil[(nt+1)/2=%d]' % self.nfmax) |
| 89 | + |
| 90 | + # store other input parameters |
| 91 | + self.nt, self.nv = nt, nv |
| 92 | + self.dt, self.dr = dt, dr |
| 93 | + self.twosided = twosided |
| 94 | + self.conj = conj |
| 95 | + self.prescaled = prescaled |
| 96 | + self.dims = (self.nt, self.nr, self.nv) |
| 97 | + self.dimsd = (self.nt, self.ns, self.nv) |
| 98 | + self.dimsdf = (self.nfft, self.ns, self.nv) |
| 99 | + |
| 100 | + # find out dtype of G |
| 101 | + self.cdtype = self.G[0, 0, 0].dtype |
| 102 | + if dtype is None: |
| 103 | + self.dtype = np.real(np.ones(1, dtype=self.cdtype)).dtype |
| 104 | + else: |
| 105 | + self.dtype = dtype |
| 106 | + |
| 107 | + self.shape = (np.prod(self.dimsd), np.prod(self.dims)) |
| 108 | + self.compute = compute |
| 109 | + self.chunks = chunks |
| 110 | + self.todask = todask |
| 111 | + self.Op = None |
| 112 | + self.explicit = False |
| 113 | + |
| 114 | + def _matvec(self, x): |
| 115 | + # apply forward fft |
| 116 | + x = da.reshape(x, self.dims) |
| 117 | + if self.twosided: |
| 118 | + x = da.fft.ifftshift(x, axes=0) |
| 119 | + y = sqrt(1. / self.nt) * da.fft.rfft(x, n=self.nt, axis=0) |
| 120 | + y = y.astype(self.cdtype) |
| 121 | + y = y[:self.nfmax] |
| 122 | + |
| 123 | + # apply batched matrix mult |
| 124 | + y = y.rechunk((self.G.chunks[0], self.nr, self.nv)) |
| 125 | + if self.conj: |
| 126 | + y = y.conj() |
| 127 | + y = da.matmul(self.G, y) |
| 128 | + if self.conj: |
| 129 | + y = y.conj() |
| 130 | + if not self.prescaled: |
| 131 | + y *= self.dr * self.dt * np.sqrt(self.nt) |
| 132 | + |
| 133 | + # apply inverse fft |
| 134 | + y = da.pad(y, ((0, self.nfft - self.nfmax), (0, 0), (0, 0)), mode='constant') |
| 135 | + y = y.rechunk(self.dimsdf) |
| 136 | + y = sqrt(self.nt) * da.fft.irfft(y, n=self.nt, axis=0) |
| 137 | + y = y.astype(self.dtype) |
| 138 | + y = da.real(y) |
| 139 | + return y.ravel() |
| 140 | + |
| 141 | + def _rmatvec(self, x): |
| 142 | + # apply forward fft |
| 143 | + x = da.reshape(x, self.dimsd) |
| 144 | + y = sqrt(1. / self.nt) * da.fft.rfft(x, n=self.nt, axis=0) |
| 145 | + y = y.astype(self.cdtype) |
| 146 | + y = y[:self.nfmax] |
| 147 | + |
| 148 | + # apply batched matrix mult |
| 149 | + y = y.rechunk((self.G.chunks[0], self.nr, self.nv)) |
| 150 | + if self.saveGt: |
| 151 | + if self.conj: |
| 152 | + y = y.conj() |
| 153 | + y = da.matmul(self.GT, y) |
| 154 | + if self.conj: |
| 155 | + y = y.conj() |
| 156 | + else: |
| 157 | + if self.conj: |
| 158 | + y = da.matmul(y.transpose(0, 2, 1), self.G).transpose(0, 2, 1) |
| 159 | + else: |
| 160 | + y = da.matmul(y.transpose(0, 2, 1).conj(), self.G).transpose(0, 2, 1).conj() |
| 161 | + if not self.prescaled: |
| 162 | + y *= self.dr * self.dt * np.sqrt(self.nt) |
| 163 | + |
| 164 | + # apply inverse fft |
| 165 | + y = da.pad(y, ((0, self.nfft - self.nfmax), (0, 0), (0, 0)), mode='constant') |
| 166 | + y = y.rechunk(self.dimsdf) |
| 167 | + y = sqrt(self.nt) * da.fft.irfft(y, n=self.nt, axis=0) |
| 168 | + if self.twosided: |
| 169 | + y = da.fft.fftshift(y, axes=0) |
| 170 | + y = y.astype(self.dtype) |
| 171 | + y = da.real(y) |
| 172 | + return y.ravel() |
81 | 173 |
|
82 | 174 |
|
83 | 175 | def MDD(G, d, dt=0.004, dr=1., nfmax=None, wav=None, |
@@ -178,6 +270,7 @@ def MDD(G, d, dt=0.004, dr=1., nfmax=None, wav=None, |
178 | 270 | d = da.concatenate((da.zeros((nt - 1, ns)), d), axis=0) |
179 | 271 | else: |
180 | 272 | d = da.concatenate((da.zeros((nt - 1, ns, nv)), d), axis=0) |
| 273 | + d = d.rechunk(d.shape) |
181 | 274 |
|
182 | 275 | # Define MDC linear operator |
183 | 276 | MDCop = MDC(G, nt2, nv=nv, dt=dt, dr=dr, |
|
0 commit comments