-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil_mwno.py
427 lines (347 loc) · 15.9 KB
/
util_mwno.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""
Reference
----------
author: gaurav71531
source: https://github.com/gaurav71531/mwt-operator
reminder: slightly modified from the original (e.g., file paths, improved output formatting).
"""
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from functools import partial
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol, chebyshevt
from typing import List, Tuple
import math
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial ``P_k`` at ``x``.

    Uses the identity ``P_k'(x) = sum_{i = k-1, k-3, ...} (2 i + 1) P_i(x)``.

    Parameters
    ----------
    k : int
        Degree of the Legendre polynomial.
    x : float or np.ndarray
        Evaluation point(s).

    Returns
    -------
    float or np.ndarray
        ``P_k'(x)``; the integer ``0`` when ``k == 0`` (empty sum).
    """
    return sum(
        (2 * degree + 1) * eval_legendre(degree, x)
        for degree in np.arange(k - 1, -1, -2)
    )
def phi_(phi_c, x, lb = 0, ub = 1):
    """Evaluate a polynomial at ``x``, zeroing it outside ``[lb, ub]``.

    Parameters
    ----------
    phi_c : array_like
        Polynomial coefficients in ascending powers of ``x``.
    x : float or np.ndarray
        Evaluation point(s).
    lb, ub : float
        Inclusive bounds of the interval on which the polynomial is kept.

    Returns
    -------
    float or np.ndarray
        Polynomial value where ``lb <= x <= ub``, 0 elsewhere.
    """
    outside = np.logical_or(x < lb, x > ub) * 1.0
    inside = 1 - outside
    polynomial = np.polynomial.polynomial.Polynomial(phi_c)
    return polynomial(x) * inside
def get_phi_psi(k, base):
    """Build the scaling functions and multiwavelets for a polynomial base.

    Parameters
    ----------
    k : int
        Number of basis functions per family (polynomial degrees 0 .. k-1).
    base : str
        ``'legendre'`` or ``'chebyshev'``.

    Returns
    -------
    phi : list of k callables
        Scaling functions on ``[0, 1]``.
    psi1, psi2 : lists of k callables
        Left (``[0, 0.5]``) and right (``[0.5, 1]``) pieces of each
        multiwavelet. For ``'legendre'`` all three lists hold ``np.poly1d``
        objects that are NOT masked to their interval, so the caller must
        pick the right piece. For ``'chebyshev'`` they hold ``phi_`` partials
        that evaluate to 0 outside their interval.

    Raises
    ------
    ValueError
        If ``base`` is not ``'legendre'`` or ``'chebyshev'``.
    """
    # Without this guard an unknown base skipped both branches and failed
    # with an UnboundLocalError at the return statement.
    if base not in ('legendre', 'chebyshev'):
        raise ValueError('Base not supported')

    x = Symbol('x')
    # Row i holds the coefficients of the i-th function in ascending powers.
    phi_coeff = np.zeros((k,k))
    phi_2x_coeff = np.zeros((k,k))
    if base == 'legendre':
        for ki in range(k):
            # phi_ki(x) = sqrt(2ki+1) * P_ki(2x-1)
            coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs()
            phi_coeff[ki,:ki+1] = np.flip(np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))
            # sqrt(2) * phi_ki(2x) = sqrt(2) * sqrt(2ki+1) * P_ki(4x-1)
            coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs()
            phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * np.sqrt(2*ki+1) * np.array(coeff_).astype(np.float64))

        # Gram-Schmidt: start each wavelet from sqrt(2)*phi_ki(2x) and remove
        # its components along every phi_i and every earlier wavelet. The
        # integrals are exact from monomial coefficients c_n:
        #   int_0^0.5 x^n dx = 0.5**(n+1) / (n+1)
        #   int_0.5^1 x^n dx = (1 - 0.5**(n+1)) / (n+1)
        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
            for i in range(k):
                a = phi_2x_coeff[ki,:ki+1]
                b = phi_coeff[i, :i+1]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_)<1e-8] = 0
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
                psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]
            for j in range(ki):
                a = phi_2x_coeff[ki,:ki+1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_)<1e-8] = 0
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
                psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]

            # Normalize over [0, 1]: left piece on [0, 0.5], right on [0.5, 1].
            a = psi1_coeff[ki,:]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_)<1e-8] = 0
            norm1 = (prod_ * 1/(np.arange(len(prod_))+1) * np.power(0.5, 1+np.arange(len(prod_)))).sum()

            a = psi2_coeff[ki,:]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_)<1e-8] = 0
            norm2 = (prod_ * 1/(np.arange(len(prod_))+1) * (1-np.power(0.5, 1+np.arange(len(prod_))))).sum()
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki,:] /= norm_
            psi2_coeff[ki,:] /= norm_
            psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0

        # np.poly1d expects descending powers, hence the flips.
        phi = [np.poly1d(np.flip(phi_coeff[i,:])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i,:])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i,:])) for i in range(k)]

    elif base == 'chebyshev':
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki,:ki+1] = np.sqrt(2/np.pi)
                phi_2x_coeff[ki,:ki+1] = np.sqrt(2/np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs()
                phi_coeff[ki,:ki+1] = np.flip(2/np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs()
                phi_2x_coeff[ki,:ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        # phi_ masks each scaling function to [0, 1].
        phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)]

        # Quadrature on the roots of T_{2k}(2x-1) with constant weight wm.
        # kUse = 2k is even, so no node lands exactly on 0.5 (which would
        # otherwise belong to both half-intervals); no jitter is needed.
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        # Same Gram-Schmidt as the Legendre branch, with inner products
        # evaluated by quadrature instead of exact monomial integrals.
        for ki in range(k):
            psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum()
                psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
                psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
                psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]

            # Temporary masked pieces, used only to compute the norm.
            psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki,:] /= norm_
            psi2_coeff[ki,:] /= norm_
            psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0

            # Final pieces; the 1e-16 shift gives x == 0.5 to the left piece only.
            psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1)

    return phi, psi1, psi2
def get_filter(base, k):
    """Compute the two-scale multiwavelet filter matrices for a polynomial base.

    Parameters
    ----------
    base : str
        ``'legendre'`` or ``'chebyshev'``.
    k : int
        Number of basis functions (see ``get_phi_psi``).

    Returns
    -------
    H0, H1, G0, G1, PHI0, PHI1 : np.ndarray
        ``(k, k)`` matrices. ``H0``/``G0`` project the scaling/wavelet
        functions restricted to the left half-interval onto ``phi``;
        ``H1``/``G1`` do the same for the right half. ``PHI0``/``PHI1`` are
        identity for Legendre and quadrature-computed Gram matrices for
        Chebyshev. Entries with magnitude below 1e-8 are zeroed.

    Raises
    ------
    ValueError
        If ``base`` is unsupported. This is a subclass of ``Exception``, so
        existing ``except Exception`` handlers still catch it.
    """
    if base not in ('legendre', 'chebyshev'):
        raise ValueError('Base not supported')

    def psi(psi1, psi2, i, inp):
        # Evaluate the i-th wavelet on [0, 1]: left piece for inp <= 0.5,
        # right piece otherwise.
        mask = (inp<=0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1-mask)

    x = Symbol('x')
    H0 = np.zeros((k,k))
    H1 = np.zeros((k,k))
    G0 = np.zeros((k,k))
    G1 = np.zeros((k,k))
    PHI0 = np.zeros((k,k))
    PHI1 = np.zeros((k,k))
    phi, psi1, psi2 = get_phi_psi(k, base)

    # Pick quadrature nodes x_m on [0, 1] and weights wm for the chosen base.
    if base == 'legendre':
        # k-point Gauss-Legendre nodes (roots of P_k(2x-1)) with their weights
        # rescaled from [-1, 1] to [0, 1].
        roots = Poly(legendre(k, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1)
    else:
        # Roots of T_{2k}(2x-1) with a constant weight. 2k is even, so no
        # node lands on 0.5 and the left/right wavelet split is unambiguous.
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

    # Two-scale relations: inner products of phi/psi evaluated on each half
    # (x/2 -> left half, (x+1)/2 -> right half) against phi on [0, 1].
    for ki in range(k):
        for kpi in range(k):
            H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
            G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
            H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
            G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

    if base == 'legendre':
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)
    else:
        for ki in range(k):
            for kpi in range(k):
                PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2

    # Remove numerical noise from the quadrature.
    PHI0[np.abs(PHI0)<1e-8] = 0
    PHI1[np.abs(PHI1)<1e-8] = 0
    H0[np.abs(H0)<1e-8] = 0
    H1[np.abs(H1)<1e-8] = 0
    G0[np.abs(G0)<1e-8] = 0
    G1[np.abs(G1)<1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
def get_initializer(name):
    """Return an in-place ``torch.nn.init`` function selected by name.

    Parameters
    ----------
    name : str
        One of ``'xavier_normal'``, ``'kaiming_uniform'``, ``'kaiming_normal'``.

    Returns
    -------
    functools.partial
        Callable that takes a tensor and initializes it in place.

    Raises
    ------
    ValueError
        If ``name`` is not a supported initializer. Previously an unknown
        name crashed with ``UnboundLocalError``.
    """
    initializers = {
        'xavier_normal': nn.init.xavier_normal_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
    }
    try:
        init_fn = initializers[name]
    except KeyError:
        raise ValueError(
            f"Unsupported initializer {name!r}; expected one of {sorted(initializers)}"
        ) from None
    # Wrapped in partial to keep the original return type.
    return partial(init_fn)
class sparseKernel3d(nn.Module):
    """Local convolutional kernel applied at a single wavelet scale.

    Maps a tensor of shape ``(B, Nx, Ny, T, c, ich)`` to the same shape.
    The ``(c, ich)`` axes are flattened into channels and passed through a
    3x3x3 ``Conv3d`` + ReLU over ``(Nx, Ny, T)``, then a per-point
    ``nn.Linear``. For the shapes to work, ``c * ich`` must equal
    ``alpha * k**2`` (the conv width) and ``c * ich`` must equal the output
    width of ``Lo``, which is the constructor's ``c * k**2``.

    ``nl``, ``initializer`` and extra keyword arguments are accepted for
    interface compatibility but are not used.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernel3d,self).__init__()

        self.k = k
        self.conv = self.convBlock(alpha*k**2, alpha*k**2)
        self.Lo = nn.Linear(alpha*k**2, c*k**2)

    def forward(self, x):
        B, Nx, Ny, T, c, ich = x.shape  # (B, Nx, Ny, T, c, k**2)
        # Flatten (c, ich) into one channel axis and move it first for Conv3d.
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(B, Nx, Ny, T, -1)
        x = x.permute(0, 4, 1, 2, 3)
        x = self.conv(x)
        # Back to channels-last so Lo mixes the channel axis at every point.
        x = x.permute(0, 2, 3, 4, 1)
        x = self.Lo(x)
        x = x.view(B, Nx, Ny, T, c, ich)
        return x

    def convBlock(self, ich, och):
        """Return a same-padding 3x3x3 ``Conv3d`` from ``ich`` to ``och`` channels, followed by ReLU."""
        # Previously built Conv3d(och, och, ...), which silently ignored ich.
        net = nn.Sequential(
            nn.Conv3d(ich, och, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        return net
class sparseKernelFT3d(nn.Module):
    """Spectral (Fourier) kernel applied at a single wavelet scale.

    Maps ``(B, Nx, Ny, T, c, ich)`` to the same shape:

    1. Real FFT over ``(Nx, Ny, T)``.
    2. Multiply the retained low-frequency modes by learned complex weights.
       At most ``alpha`` modes are kept per axis. Both the positive and the
       negative frequency blocks are kept along Nx and Ny (four weight
       tensors), and the first block along T.
    3. Inverse FFT, ReLU, then a per-point ``nn.Linear``.

    ``c * ich`` must equal the constructor's ``c * k**2``.
    ``nl``, ``initializer`` and extra keyword arguments are accepted for
    interface compatibility but are not used.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernelFT3d, self).__init__()

        self.modes = alpha

        self.weights1 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(torch.zeros(c*k**2, c*k**2, self.modes, self.modes, self.modes, dtype=torch.cfloat))
        nn.init.xavier_normal_(self.weights1)
        nn.init.xavier_normal_(self.weights2)
        nn.init.xavier_normal_(self.weights3)
        nn.init.xavier_normal_(self.weights4)

        self.Lo = nn.Linear(c*k**2, c*k**2)
        self.k = k

    def forward(self, x):
        B, Nx, Ny, T, c, ich = x.shape  # (B, N, N, T, c, k^2)

        x = x.reshape(B, Nx, Ny, T, -1)
        x = x.permute(0, 4, 1, 2, 3)
        x_fft = torch.fft.rfftn(x, dim = [-3, -2, -1])

        # Number of modes actually available along each axis. The T axis is
        # one-sided after rfftn (T//2 + 1 bins); previously it was not
        # clamped, so T//2 + 1 < modes caused an einsum shape mismatch.
        l1 = min(self.modes, Nx//2+1)
        l2 = min(self.modes, Ny//2+1)
        l3 = min(self.modes, T//2+1)
        out_ft = torch.zeros(B, c*ich, Nx, Ny, T//2 +1, device=x.device, dtype=torch.cfloat)

        # One weight tensor per (+/- kx, +/- ky) quadrant.
        out_ft[:, :, :l1, :l2, :l3] = self.compl_mul3d(
            x_fft[:, :, :l1, :l2, :l3], self.weights1[:, :, :l1, :l2, :l3])
        out_ft[:, :, -l1:, :l2, :l3] = self.compl_mul3d(
            x_fft[:, :, -l1:, :l2, :l3], self.weights2[:, :, :l1, :l2, :l3])
        out_ft[:, :, :l1, -l2:, :l3] = self.compl_mul3d(
            x_fft[:, :, :l1, -l2:, :l3], self.weights3[:, :, :l1, :l2, :l3])
        out_ft[:, :, -l1:, -l2:, :l3] = self.compl_mul3d(
            x_fft[:, :, -l1:, -l2:, :l3], self.weights4[:, :, :l1, :l2, :l3])

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s = (Nx, Ny, T))

        x = x.permute(0, 2, 3, 4, 1)
        x = F.relu(x)
        x = self.Lo(x)
        x = x.view(B, Nx, Ny, T, c, ich)
        return x

    def compl_mul3d(self, input, weights):
        """Contract the channel axis of ``input`` with complex ``weights`` per mode."""
        # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)
class MWT_CZ3d(nn.Module):
    """One multiwavelet-transform layer on fields of shape ``(B, Nx, Ny, T, c, k**2)``.

    The spatial axes ``Nx``/``Ny`` are decomposed level by level with the
    filters from ``get_filter``. Odd sizes are zero-padded at each level.
    At each level, ``A`` (spectral) and ``B`` (local) act on the detail
    and scaling coefficients, and ``C`` acts on the detail. ``T0`` acts on
    the coarsest scale. The field is then reconstructed, and each level's
    padding is removed, so the output has the same shape as the input.
    The ``T`` axis is not decomposed.
    """

    def __init__(self,
                 k = 3, alpha = 5,
                 L = 0, c = 1,
                 base = 'legendre',
                 initializer = None,
                 **kwargs):
        super(MWT_CZ3d, self).__init__()

        self.k = k
        # The decomposition runs ns - L levels (see forward), so L coarsest
        # levels are left undecomposed.
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        # Reconstruction filters: analysis filters times the PHI matrices.
        H0r = H0@PHI0
        G0r = G0@PHI0
        H1r = H1@PHI1
        G1r = G1@PHI1

        H0r[np.abs(H0r)<1e-8]=0
        H1r[np.abs(H1r)<1e-8]=0
        G0r[np.abs(G0r)<1e-8]=0
        G1r[np.abs(G1r)<1e-8]=0

        self.A = sparseKernelFT3d(k, alpha, c)
        self.B = sparseKernel3d(k, c, c)
        self.C = sparseKernel3d(k, c, c)

        self.T0 = nn.Linear(c*k**2, c*k**2)

        if initializer is not None:
            self.reset_parameters(initializer)

        # Analysis matrices. Their row blocks line up with the four sub-grids
        # (ee, eo, oe, oo) that wavelet_transform concatenates on the last axis.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((np.kron(H0, H0).T,
                            np.kron(H0, H1).T,
                            np.kron(H1, H0).T,
                            np.kron(H1, H1).T,
                           ), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((np.kron(G0, G0).T,
                            np.kron(G0, G1).T,
                            np.kron(G1, G0).T,
                            np.kron(G1, G1).T,
                           ), axis=0)))

        # Synthesis matrices: map [scaling | detail] back to one sub-grid each.
        self.register_buffer('rc_ee', torch.Tensor(
            np.concatenate((np.kron(H0r, H0r),
                            np.kron(G0r, G0r),
                           ), axis=0)))
        self.register_buffer('rc_eo', torch.Tensor(
            np.concatenate((np.kron(H0r, H1r),
                            np.kron(G0r, G1r),
                           ), axis=0)))
        self.register_buffer('rc_oe', torch.Tensor(
            np.concatenate((np.kron(H1r, H0r),
                            np.kron(G1r, G0r),
                           ), axis=0)))
        self.register_buffer('rc_oo', torch.Tensor(
            np.concatenate((np.kron(H1r, H1r),
                            np.kron(G1r, G1r),
                           ), axis=0)))

    def forward(self, x):
        B, Nx, Ny, T, c, ich = x.shape  # (B, Nx, Ny, T, c, k**2)
        # Levels needed to shrink the larger spatial side to 1. Padded zeros
        # are considered (wavelet_transform pads odd sizes).
        ns = max(math.ceil(np.log2(Nx)), math.ceil(np.log2(Ny)))
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose
        for i in range(ns-self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]

        # Coarsest scale transform. Use the actual spatial extent: it equals
        # 2**L only when Nx == Ny is a power of two, and hard-coding 2**L
        # broke the reshape otherwise.
        Nxc, Nyc = x.shape[1], x.shape[2]
        x = self.T0(x.reshape(B, Nxc, Nyc, T, -1)).reshape(
            B, Nxc, Nyc, T, c, ich)

        # reconstruct
        for i in range(ns-1-self.L,-1,-1):
            # De-padding: drop the zeros wavelet_transform added at this level.
            x = x[:, :Us[i].shape[1], :Us[i].shape[2], ...]
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)

        # Drop the padding added at the finest level so the output shape
        # matches the input. This is a no-op when Nx and Ny are even.
        x = x[:, :Nx, :Ny, ...]
        return x

    def wavelet_transform(self, x):
        """Run one analysis step. Returns ``(detail, scaling)``, each with halved Nx and Ny."""
        # Padding to even length (F.pad lists pairs from the last axis
        # backwards; only the Ny and Nx axes receive trailing padding).
        x = F.pad(x, [0, 0, 0, 0, 0, 0,
                      0, x.shape[-4]%2, 0, x.shape[-5]%2])
        xa = torch.cat([x[:, ::2 , ::2 , :, :, :],
                        x[:, ::2 , 1::2, :, :, :],
                        x[:, 1::2, ::2 , :, :, :],
                        x[:, 1::2, 1::2, :, :, :]
                       ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Run one synthesis step: ``(B, Nx, Ny, T, c, 2*k**2)`` to ``(B, 2*Nx, 2*Ny, T, c, k**2)``."""
        B, Nx, Ny, T, c, ich = x.shape  # (B, Nx, Ny, T, c, 2*k**2)
        assert ich == 2*self.k**2
        x_ee = torch.matmul(x, self.rc_ee)
        x_eo = torch.matmul(x, self.rc_eo)
        x_oe = torch.matmul(x, self.rc_oe)
        x_oo = torch.matmul(x, self.rc_oo)

        # Match the input dtype; a default float32 buffer silently downcast
        # float64 computations.
        x = torch.zeros(B, Nx*2, Ny*2, T, c, self.k**2,
                        device = x.device, dtype = x.dtype)
        x[:, ::2 , ::2 , :, :, :] = x_ee
        x[:, ::2 , 1::2, :, :, :] = x_eo
        x[:, 1::2, ::2 , :, :, :] = x_oe
        x[:, 1::2, 1::2, :, :, :] = x_oo
        return x

    def reset_parameters(self, initializer):
        """Apply ``initializer`` in place to the coarsest-scale linear weight."""
        initializer(self.T0.weight)