forked from dgaddy/silent_speech
-
Notifications
You must be signed in to change notification settings - Fork 4
/
s4.py
1561 lines (1328 loc) · 58.1 KB
/
s4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Standalone version of Structured (Sequence) State Space (S4) model.
HazyResearch/state-spaces is licensed under the
Apache License 2.0
https://github.com/HazyResearch/state-spaces/blob/ede0b53fe4bcfccf185c32b99880463b2a2cd085/src/models/s4/s4.py
"""
import sys
import logging
from functools import partial
import math
import numpy as np
from scipy import special as ss
import torch
import torch.nn as nn
import torch.nn.functional as F
#from pytorch_lightning.utilities import rank_zero_only # this line causes recognition_model.py to crash
from einops import rearrange, repeat
import opt_einsum as oe
contract = oe.contract
contract_expression = oe.contract_expression
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Return the python logger called ``name`` with its level set to ``level``.

    The per-method ``rank_zero_only`` wrapping (meant to avoid duplicated log
    lines across GPU processes) is disabled in this file, so the logger is
    returned unwrapped.
    """
    named_logger = logging.getLogger(name)
    named_logger.setLevel(level)
    return named_logger


# Module-wide logger used by the import probes and the kernel classes below.
log = get_logger(__name__)
""" Cauchy and Vandermonde kernels """
# Probe for the optional CUDA cauchy-multiplication extension.
# `has_cauchy_extension` records whether `cauchy_mult` could be imported.
try:  # Try CUDA extension
    # Guy: looks like import path can't be found unless you add repo to your path (even after install)?
    sys.path.append('/oak/stanford/groups/shenoy/ghwilson/repos/state-spaces/extensions/cauchy/')
    from cauchy import cauchy_mult
    #from extensions.cauchy.cauchy import cauchy_mult
    has_cauchy_extension = True
except Exception:
    # Any failure while loading the extension falls back to the slower kernels.
    # `Exception` (rather than a bare `except`) keeps this best-effort while
    # letting KeyboardInterrupt / SystemExit propagate.
    log.warning(
        "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
    )
    has_cauchy_extension = False
# Probe for pykeops. When it is importable, the Cauchy / Vandermonde kernels are
# built from pykeops `Genred` reductions; otherwise plain torch fallbacks are
# defined under the same names.
try:  # Try pykeops
    import pykeops
    from pykeops.torch import Genred
    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        """Left-pad each tensor's shape with singleton dims so all share the largest rank."""
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors]
        return tensors

    def cauchy_conj(v, z, w):
        """ Pykeops version """
        expr_num = 'z * ComplexReal(v) - Real2Complex(Sum(v * w))'
        expr_denom = 'ComplexMult(z-w, z-Conj(w))'

        cauchy_mult = Genred(
            f'ComplexDivide({expr_num}, {expr_denom})',
            [
                'v = Vj(2)',
                'z = Vi(2)',
                'w = Vj(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        v, z, w = _broadcast_dims(v, z, w)
        # pykeops operates on real views of the complex tensors
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)

        r = 2*cauchy_mult(v, z, w, backend='GPU')
        return _r2c(r)

    def log_vandermonde(v, x, L):
        """Pykeops version: reduce v * exp(x * l) over the Vj axis for l = 0..L-1; returns 2 * real part."""
        expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))'
        vandermonde_mult = Genred(
            expr,
            [
                'v = Vj(2)',
                'x = Vj(2)',
                'l = Vi(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(v, x, l, backend='GPU')
        return 2*_r2c(r).real

    def log_vandermonde_transpose(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(a, L) : (H N L)
        contract_L(V * u * v)
        """
        expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))'
        vandermonde_mult = Genred(
            expr,
            [
                'u = Vj(2)',
                'v = Vi(2)',
                'x = Vi(2)',
                'l = Vj(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(u, v, x, l, backend='GPU')
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cauchy_extension:
        log.warning(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )

    # Always define the naive kernel in this branch: the kernel's forward pass
    # selects it whenever pykeops is missing and the CUDA path is not taken
    # (e.g. keops requested or non-cfloat dtype), even if the CUDA extension
    # itself imported fine.
    def cauchy_naive(v, z, w):
        """
        v, w: (..., N)
        z: (..., L)
        returns: (..., L)
        """
        cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L)
        return torch.sum(cauchy_matrix, dim=-2)

    # Vandermonde functions
    log.warning(
        "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    )

    def log_vandermonde(v, x, L):
        r"""
        v: (..., N)
        x: (..., N)
        returns: (..., L) \sum v x^l

        Computed as 2 * Re(sum_n v_n * exp(x_n * l)) for l = 0..L-1.
        """
        vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L)
        vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L)
        return 2*vandermonde_prod.real

    def log_vandermonde_transpose(u, v, x, L):
        """Torch fallback: contract u (... L) and v (... N) against exp(x * l) over l, giving (... N)."""
        vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L)
        vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... N)
        return vandermonde_prod
def _conj(x):
    """Append the complex conjugate of ``x`` along the last dimension."""
    return torch.cat([x, x.conj()], dim=-1)


# Real <-> complex views (a complex tensor is viewed with a trailing dim of size 2).
_c2r = torch.view_as_real
_r2c = torch.view_as_complex

# From torch 1.10 on, conj() is followed by resolve_conj() so the conjugation
# is materialized; older versions only call conj().
if tuple(int(part) for part in torch.__version__.split('.')[:2]) >= (1, 10):
    def _resolve_conj(x):
        return x.conj().resolve_conj()
else:
    def _resolve_conj(x):
        return x.conj()
""" Simple nn.Module components """
def Activation(activation=None, dim=-1):
    """Build the nn.Module for a named activation.

    activation: one of None / 'id' / 'identity' / 'linear' (identity), 'tanh',
        'relu', 'gelu', 'swish' / 'silu', 'glu', 'sigmoid'
    dim: dimension passed to nn.GLU (only used for 'glu')

    Raises NotImplementedError for any other name.
    """
    factories = (
        ((None, 'id', 'identity', 'linear'), nn.Identity),
        (('tanh',), nn.Tanh),
        (('relu',), nn.ReLU),
        (('gelu',), nn.GELU),
        (('swish', 'silu'), nn.SiLU),
        (('glu',), lambda: nn.GLU(dim=dim)),
        (('sigmoid',), nn.Sigmoid),
    )
    for names, build in factories:
        if activation in names:
            return build()
    raise NotImplementedError("hidden activation '{}' is not implemented".format(activation))
def LinearActivation(
        d_input, d_output, bias=True,
        transposed=False,
        activation=None,
        activate=False,  # Apply activation as part of this module
        **kwargs,
    ):
    """Build a linear layer, optionally followed by its activation.

    transposed: use a kernel-size-1 nn.Conv1d instead of nn.Linear
    activation: activation name; 'glu' doubles the layer's output width
    activate: when True (and an activation is named) return
        nn.Sequential(layer, activation), otherwise the bare layer
    kwargs: forwarded to the layer constructor
    """
    if transposed:
        make_layer = partial(nn.Conv1d, kernel_size=1)
    else:
        make_layer = nn.Linear

    # GLU halves its input along `dim`, so the layer produces twice the width
    if activation == 'glu':
        d_output *= 2
    layer = make_layer(d_input, d_output, bias=bias, **kwargs)

    if not activate or activation is None:
        return layer
    nonlinearity = Activation(activation, dim=-2 if transposed else -1)
    return nn.Sequential(layer, nonlinearity)
class DropoutNd(nn.Module):
    """Dropout whose mask can be shared ("tied") across the trailing dimensions."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        p: drop probability, must satisfy 0 <= p < 1
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        transposed: when False, the input is rearranged before masking and
            rearranged back afterwards
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """ X: (batch, dim, lengths...) """
        # Dropout is only active in training mode
        if not self.training:
            return X

        if not self.transposed:
            X = rearrange(X, 'b d ... -> b ... d')
        # Tied mask: one value per leading (dim 0, dim 1) pair, broadcast over the rest
        if self.tie:
            mask_shape = X.shape[:2] + (1,)*(X.ndim-2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1.-self.p
        # Rescale the kept entries by 1/(1-p)
        X = X * keep * (1.0/(1-self.p))
        if not self.transposed:
            X = rearrange(X, 'b ... d -> b d ...')
        return X
""" Misc functional utilities """
def power(L, A, v=None):
    """ Compute A^L by repeated squaring and, optionally, the scan sum_i A^i v_i

    A: (..., N, N)
    v: (..., N, L)

    Returns A^L when v is None, otherwise the pair (A^L, reduced v).
    """
    acc = torch.eye(A.shape[-1]).to(A)  # running product, starts at the identity
    squarings = [A]                     # squarings[i] = A^(2^i)
    span = 1
    remaining = L
    while True:
        if remaining % 2 == 1:
            acc = squarings[-1] @ acc
        remaining //= 2
        if remaining == 0:
            break
        span *= 2
        squarings.append(squarings[-1] @ squarings[-1])

    if v is None:
        return acc

    # At this point squarings[-1] = A^span, with span the largest power of two <= L.
    # The reduction runs from the largest power downwards ("reverse"
    # divide-and-conquer), popping cached powers as it goes.

    # Fold the part of v beyond the power-of-two prefix back onto the prefix.
    # This is a no-op when v's length is exactly `span`.
    overflow = v.size(-1) - span
    folded = squarings.pop() @ v[..., span:]
    v = v[..., :span]
    v[..., :overflow] = v[..., :overflow] + folded

    # Halve the (power-of-two) length until a single column remains
    while v.size(-1) > 1:
        v = rearrange(v, '... (z l) -> ... z l', z=2)
        v = v[..., 0, :] + squarings.pop() @ v[..., 1, :]
    return acc, v.squeeze(-1)
""" HiPPO utilities """
def transition(measure, N):
    """ A, B transition matrices for different measures

    measure: 'legt', 'legs', 'legsd', 'fourier_diag' / 'foud', or 'fourier' / 'fout'
    N: state size
    Returns numpy arrays A of shape (N, N) and B of shape (N, 1).
    Raises NotImplementedError for any other measure.
    """

    def scaled_legendre():
        # Shared construction for 'legs' and 'legsd'
        idx = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(idx, idx)
        weights = 2 * idx + 1
        lower = -(np.where(row >= col, weights, 0) - np.diag(idx))
        scale = np.sqrt(np.diag(2 * idx + 1))
        A_ = scale @ lower @ np.linalg.inv(scale)
        B_ = np.diag(scale)[:, None]
        B_ = B_.copy()  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
        return A_, B_

    # Legendre (translated)
    if measure == 'legt':
        idx = np.arange(N, dtype=np.float64)
        norms = (2*idx + 1) ** .5
        j, i = np.meshgrid(idx, idx)
        A = norms[:, None] * np.where(i < j, (-1.)**(i-j), 1) * norms[None, :]
        B = norms[:, None]
        A = -A
        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
        return A, B

    # Legendre (scaled)
    if measure == 'legs':
        A, B = scaled_legendre()
        return A, B

    if measure == 'legsd':
        # Essentially equivalent to S4D-LegS
        A, B = scaled_legendre()
        A += .5 * B*B[None, :, 0]
        B = B / 2.0
        return A, B

    if measure in ['fourier_diag', 'foud']:
        # Essentially equivalent to S4D-Lin
        freqs = np.arange(N//2)
        offdiag = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1]
        A = 2*np.pi*(-np.diag(offdiag, 1) + np.diag(offdiag, -1))
        A = A - .5 * np.eye(N)
        B = np.zeros(N)
        B[0::2] = 2**.5
        B[0] = 1
        B = B[:, None]
        return A, B

    if measure in ['fourier', 'fout']:
        freqs = np.arange(N//2)
        offdiag = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi*(-np.diag(offdiag, 1) + np.diag(offdiag, -1))
        B = np.zeros(N)
        B[0::2] = 2**.5
        B[0] = 1
        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
        return A, B

    raise NotImplementedError
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """ Return low-rank matrix L such that A + L is normal

    measure: 'legs', 'legt', 'fourier' / 'fout', 'fourier_diag' / 'foud' or 'legsd'
    N: state size
    rank: number of rows requested; zero rows are appended when the measure's
        own correction has fewer rows than `rank`
    dtype: real dtype of the returned tensor

    Returns P of shape (max(rank, d), N), where d is the measure's own rank
    (2 for 'legt', 1 otherwise).
    Raises NotImplementedError for an unknown measure.
    """
    if measure == 'legs':
        assert rank >= 1
        P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N)
    elif measure == 'legt':
        assert rank >= 2
        P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N)
        P0 = P.clone()
        P0[0::2] = 0.
        P1 = P.clone()
        P1[1::2] = 0.
        P = torch.stack([P0, P1], dim=0) # (2 N)
        P *= 2**(-0.5) # Halve the rank correct just like the original matrix was halved
    elif measure in ['fourier', 'fout']:
        # Build in the requested dtype, like every other branch (previously this
        # silently produced the default float dtype regardless of `dtype`)
        P = torch.zeros(N, dtype=dtype)
        P[0::2] = 2**.5
        P[0] = 1
        P = P.unsqueeze(0)
    elif measure in ['fourier_diag', 'foud', 'legsd']:
        # These measures need no correction
        P = torch.zeros(1, N, dtype=dtype)
    else:
        raise NotImplementedError(f"rank correction for measure '{measure}' is not implemented")

    d = P.size(0)
    if rank > d:
        P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True):
    """ Return w, P, B, V such that
    (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V
    i.e. A = V[w - p q^*]V^*, B = V B

    measure: name passed to transition() and rank_correction()
    N: state size
    rank: rank of the low-rank correction
    dtype: real dtype (torch.float or torch.double); the complex counterpart is used for w and V
    diagonalize_precision: run the eigendecomposition in double precision, then cast back

    Only the first N//2 eigenvalues/eigenvectors (after sorting by imaginary part) are returned:
    w: (N//2,), P: (rank, N//2), B: (N//2,), V: (N, N//2)
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype) # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N)
    # Add the sum of outer products of the rows of P to A
    AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)

    # We require AP to be nearly skew-symmetric
    _A = AP + AP.transpose(-1, -2)
    if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
        print("WARNING: HiPPO matrix not skew symmetric", err)

    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision: AP = AP.to(torch.double)
    w_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N)
    if diagonalize_precision: w_im, V = w_im.to(cdtype), V.to(cdtype)
    w = w_re + 1j * w_im
    # Check: V w V^{-1} = A
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Only keep half of each conjugate pair
    _, idx = torch.sort(w.imag)
    w_sorted = w[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, :N//2]
    w = w_sorted[:N//2]
    assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A"
    if w[-1].abs() < 1e-4:
        # Overwrite the eigenvector of the (near-)zero eigenvalue with a fixed vector
        V[:, -1] = 0.
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    # Sanity check: twice the real part of the half-spectrum reconstruction should match AP
    _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)
    if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5):
        print("Warning: Diagonalization of A matrix not numerically precise - error", err)
    # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2))

    # Change B and P into the eigenbasis
    V_inv = V.conj().transpose(-1, -2)

    B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B
    P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P

    return w, P, B, V
def dplr(scaling, N, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False):
    """ Directly construct a diagonal-plus-low-rank initialization (w, P, B, V)

    scaling: how the imaginary part of w is laid out: 'random', 'real',
        'linear'/'lin', 'inverse'/'inv', 'inverse2'/'inv2', 'quadratic'/'quad',
        or 'legs'/'hippo' (taken from nplr('legsd', N))
    N: state size; tensors carry N//2 entries (half of the conjugate pairs)
    rank: rank of the low-rank term P
    H: number of independent copies
    dtype: real dtype (torch.float or torch.double); outputs use the complex counterpart
    real_scale, imag_scale: multipliers for the real / imaginary parts of w
    random_real, random_imag: draw the base real / imaginary parts uniformly at random
    normalize: rescale B by the norm computed below
    diagonal: zero out P
    random_B: draw B from a normal distribution instead of all ones

    Returns w: (H, N//2), P: (rank, H, N//2), B: (H, N//2), V: (H, N, N//2)
    Raises NotImplementedError for an unknown scaling.
    """
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N//2)
    else:
        real_part = .5 * torch.ones(H, N//2)
    if random_imag:
        imag_part = N//2 * torch.rand(H, N//2)
    else:
        imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H)

    real_part = real_scale * real_part
    if scaling == 'random':
        imag_part = torch.randn(H, N//2)
    elif scaling == 'real':
        imag_part = 0 * imag_part
        real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H)
    elif scaling in ['linear', 'lin']:
        imag_part = pi * imag_part
    elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix
        imag_part = 1/pi * N * (N/(1+2*imag_part)-1)
    elif scaling in ['inverse2', 'inv2']:
        imag_part = 1/pi * N * (N/(1+imag_part)-1)
    elif scaling in ['quadratic', 'quad']:
        imag_part = 1/pi * (1+2*imag_part)**2
    elif scaling in ['legs', 'hippo']:
        w, _, _, _ = nplr('legsd', N)
        imag_part = w.imag
    else: raise NotImplementedError
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N//2, dtype=dtype)
    else:
        B = torch.ones(H, N//2, dtype=dtype)

    if normalize:
        norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector
        B = B / zeta**.5

    P = torch.randn(rank, H, N//2, dtype=dtype)
    if diagonal: P = P * 0.0
    # Keep the first N//2 columns of the identity so V has the same (N, N//2)
    # layout as the V returned by nplr(). The previous `[: :N//2]` was a strided
    # row slice of shape (2, N), which cannot be concatenated with nplr's V.
    V = torch.eye(N, dtype=dtype)[:, :N//2] # Only used in testing
    V = repeat(V, 'n m -> h n m', h=H)

    return w, P, B, V
def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    measure: "dplr", a "diag-<scaling>" string, or a measure name understood by nplr()
    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    ssm_args: forwarded to dplr() / nplr()

    Returns the tuple (w, P, B, V).
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
        return w, P, B, V

    if measure.startswith("diag"):
        # "diag-<scaling>": the part after the first dash selects the scaling
        pieces = measure.split("-")
        assert pieces[0] == "diag" and len(pieces) > 1
        w, P, B, V = dplr(scaling=pieces[1], N=N, rank=R, H=H, diagonal=True, **ssm_args)
        return w, P, B, V

    # nplr() builds a single system; tile it H times along a new copy axis
    w, P, B, V = nplr(measure, N, R, **ssm_args)
    w = repeat(w, 'n -> s n', s=H)
    P = repeat(P, 'r n -> r s n', s=H)
    B = repeat(B, 'n -> s n', s=H)
    V = repeat(V, 'n m -> s n m', s=H)
    return w, P, B, V
# Named groups of measures that combination() expands into several initializations
combinations = {
    'hippo': ['legs', 'fourier'],
    'diag': ['diag-inv', 'diag-lin'],
    'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'],
}


def combination(measures, N, R, S, **ssm_args):
    """Create S SSM initializations split evenly across one or more measures.

    measures: a key of `combinations`, a single measure name, or a list of measure names
    N: state size
    R: rank
    S: total number of SSM copies; must be a multiple of the number of measures
    ssm_args: forwarded to ssm()

    Returns (w, P, B, V) concatenated across measures along the copy axis.
    """
    if isinstance(measures, str):
        measures = combinations.get(measures, [measures])

    assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"
    copies_each = S // len(measures)
    per_measure = [ssm(measure, N, R, copies_each, **ssm_args) for measure in measures]
    w, P, B, V = zip(*per_measure)
    w = torch.cat(w, dim=0) # (S N)
    P = torch.cat(P, dim=1) # (R S N)
    B = torch.cat(B, dim=0) # (S N)
    V = torch.cat(V, dim=0) # concatenated along the copy axis
    return w, P, B, V
class OptimModule(nn.Module):
    """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay

        lr == 0.0 registers `tensor` as a buffer. Any other value registers it as
        an nn.Parameter and attaches an `_optim` dict ({"weight_decay": 0.0},
        plus "lr" when lr is not None) to that parameter.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        hyperparams = {"weight_decay": 0.0}
        if lr is not None:
            hyperparams["lr"] = lr
        setattr(getattr(self, name), "_optim", hyperparams)
class SSKernelNPLR(OptimModule):
""" Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)
"""
@torch.no_grad()
def _setup_C(self, L):
    """ Construct C~ from C

    Two modes are supported: go directly to length L if self.L is 0 (kernel not yet initialized),
    or double the current internal length when L exceeds it. If L does not exceed a nonzero
    self.L, nothing happens.

    Updates self.C in place and advances the internal length self.L.
    """
    if self.L.item() == 0:
        if self.verbose: log.info(f"S4: Initializing kernel to length {L}")
        double_length = False
    elif L > self.L.item(): # 2*int(self.L) == L:
        if self.verbose: log.info(f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}")
        double_length = True
        L = self.L.item() # Convenience for the math below
    else: return

    C = _r2c(self.C)
    dA, _ = self._setup_state()
    dA_L = power(L, dA)
    # Multiply C by I - dA_L
    C_ = _conj(C)
    prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
    if double_length: prod = -prod # Multiply by I + dA_L instead
    C_ = C_ - prod
    C_ = C_[..., :self.N] # Take conjugate pairs again
    self.C.copy_(_c2r(C_))

    # Doubling mode doubles self.L; initialization mode goes from 0 to L
    self.L = 2*self.L if double_length else self.L+L # Preserve type/device
def _omega(self, L, dtype, device, cache=True):
""" Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform
This should be called everytime the internal length self.L changes """
# Use cached if available
if cache and hasattr(self, 'omega') and self.omega.size(-1) == L//2+1:
return self.omega, self.z
omega = torch.tensor(
np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
) # \omega_{2L}
omega = omega ** torch.arange(0, L // 2 + 1, device=device)
z = 2 * (1 - omega) / (1 + omega)
# Cache if necessary
if cache:
self.omega = omega
self.z = z
return omega, z
def __init__(
    self,
    w, P, B, C, log_dt,
    L=None, # starting/maximum length of kernel
    lr=None,
    verbose=False,
    keops=False,
    real_type='exp', # ['none' | 'exp' | 'relu' | 'sigmoid' | 'softplus']
    real_tolerance=1e-3,
    bandlimit=None,
):
    """
    L: Maximum length; this module computes an SSM kernel of length L
    A is represented by diag(w) - PP^*
    w: (S, N) diagonal part
    P: (R, S, N) low-rank part

    B: (S, N)
    C: (C, H, N)
    dt: (H) timescale per feature
    lr: [dict | float | None] hook to set lr of special parameters (A, B, dt)
    verbose: log when the internal kernel length is initialized or doubled
    keops: skip the CUDA cauchy extension in forward() even when it is available
    real_type: parameterization of the real part of w (see _w_init / _w)
    real_tolerance: the real part of w is clamped to at most -real_tolerance at init
    bandlimit: when not None, forward() masks C where the frequency exceeds bandlimit/2

    Dimensions:
    N (or d_state): state size
    H (or d_model): total SSM copies
    S (or n_ssm): number of trainable copies of (A, B, dt); must divide H
    R (or rank): rank of low-rank part
    C (or channels): system is 1-dim to C-dim

    The forward pass of this Module returns a tensor of shape (C, H, L)

    Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
    """

    super().__init__()
    self.verbose = verbose
    self.keops = keops
    self.bandlimit = bandlimit
    self.real_type = real_type
    self.real_tolerance = real_tolerance

    # Rank of low-rank correction
    self.rank = P.shape[-3]
    assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
    self.H = log_dt.size(-1)
    self.N = w.size(-1)

    # Check different SSM inits
    assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm
    assert self.H % w.size(0) == 0
    self.n_ssm = w.size(0)
    self.repeat = self.H // w.size(0) # Each trainable SSM needs to be duplicated this many times

    # Broadcast everything to correct shapes
    C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N)
    B = B.unsqueeze(0) # (1, S, N) given B of shape (S, N)

    # Register parameters
    # Complex tensors are stored as real views (trailing dim of size 2)
    self.C = nn.Parameter(_c2r(_resolve_conj(C)))
    # A float (or None) lr applies to every special parameter; a dict gives per-key rates
    if lr is None or isinstance(lr, float): lr_dict = {}
    else: lr_dict, lr = lr, None
    self.register("log_dt", log_dt, lr_dict.get('dt', lr))
    self.register("B", _c2r(B), lr_dict.get('B', lr))
    self.register("P", _c2r(P), lr_dict.get('A', lr))
    self.register("inv_w_real", self._w_init(w.real), lr_dict.get('A', lr))
    self.register("w_imag", w.imag, lr_dict.get('A', lr))

    self.l_max = L
    self.register_buffer('L', torch.tensor(0)) # Internal length
def _w_init(self, w_real):
w_real = torch.clamp(w_real, max=-self.real_tolerance)
if self.real_type == 'none':
return -w_real
elif self.real_type == 'exp':
return torch.log(-w_real) # Some of the HiPPO methods have real part 0
elif self.real_type == 'relu':
return -w_real
elif self.real_type == 'sigmoid':
return torch.logit(-w_real)
elif self.real_type == 'softplus':
return torch.log(torch.exp(-w_real)-1)
else: raise NotImplementedError
def _w(self):
# Get the internal w (diagonal) parameter
if self.real_type == 'none':
w_real = -self.inv_w_real
elif self.real_type == 'exp':
w_real = -torch.exp(self.inv_w_real)
elif self.real_type == 'relu':
w_real = -F.relu(self.inv_w_real)
elif self.real_type == 'sigmoid':
w_real = -F.sigmoid(self.inv_w_real)
elif self.real_type == 'softplus':
w_real = -F.softplus(self.inv_w_real)
else: raise NotImplementedError
w = w_real + 1j * self.w_imag
return w
def forward(self, state=None, rate=1.0, L=None):
    """
    Compute the convolution kernel in the frequency domain and transform it back with an inverse real FFT.

    state: (B, H, N) initial state
    rate: sampling rate factor
    L: target length

    returns:
    (C, H, L) convolution kernel (generally C=1)
    (B, H, L) output from initial state (None when state is None)
    """

    # Initialize C~ if necessary (done in forward pass so it's on the correct device)
    if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
        self._setup_C(self.l_max)

    # Handle sampling rate logic
    # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate
    if L is None:
        L = round(self.L.item() / rate)

    # Increase the internal length if needed
    continuous_L = round(rate*L)
    while continuous_L > self.L.item():
        self._setup_C(continuous_L)
    discrete_L = round(self.L.item()/rate)

    dt = torch.exp(self.log_dt) * rate
    # Parameters are stored as real views; convert back to complex
    B = _r2c(self.B)
    C = _r2c(self.C)
    P = _r2c(self.P)
    Q = P.conj()
    w = self._w() # (n_ssm, N)

    # Address bandlimiting
    if self.bandlimit is not None:
        freqs = w.imag.abs() / (2*math.pi) # (H, N)
        freqs = dt[:, None] / rate * freqs # (H, N)
        # Zero out the entries of C whose frequency is at or above bandlimit/2
        mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
        C = C * mask

    # Get FFT nodes of right length
    # (only cached when rate == 1.0)
    omega, z = self._omega(discrete_L, dtype=w.dtype, device=w.device, cache=(rate==1.0))

    # Broadcast parameters to same hidden features H
    B = repeat(B, '1 t n -> 1 (v t) n', v=self.repeat)
    P = repeat(P, 'r t n -> r (v t) n', v=self.repeat)
    Q = repeat(Q, 'r t n -> r (v t) n', v=self.repeat)
    w = repeat(w, 't n -> (v t) n', v=self.repeat)

    # Augment B
    if state is not None:
        # Have to "unbilinear" the state to put it into the same "type" as B
        # Compute 1/dt * (I + dt/2 A) @ state

        # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
        s = _conj(state) if state.size(-1) == self.N else state # (B H N)
        sA = (
            s * _conj(w) # (B H N)
            - contract('bhm, rhm, rhn -> bhn', s, _conj(Q), _conj(P))
        )
        s = s / dt.unsqueeze(-1) + sA / 2
        s = s[..., :self.N]

        B = torch.cat([s, B], dim=-3) # (B+1, H, N)

    # Incorporate dt into A
    w = w * dt.unsqueeze(-1) # (H N)

    # Stack B and p, C and q for convenient batching
    B = torch.cat([B, P], dim=-3) # (B+1+R, H, N)
    C = torch.cat([C, Q], dim=-3) # (C+R, H, N)

    # Incorporate B and C batch dimensions
    v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N)

    # Calculate resolvent at omega
    # Backend choice: CUDA extension (cfloat only, unless keops is requested), then pykeops, then the naive kernel
    if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops:
        r = cauchy_mult(v, z, w, symmetric=True)
    elif has_pykeops:
        r = cauchy_conj(v, z, w)
    else:
        r = cauchy_naive(v, z, w)
    r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L)

    # Low-rank Woodbury correction
    # The last `rank` rows/columns of r hold the terms involving P and Q
    if self.rank == 1:
        k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :])
    elif self.rank == 2:
        # Closed-form 2x2 inverse
        r00 = r[: -self.rank, : -self.rank, :, :]
        r01 = r[: -self.rank, -self.rank :, :, :]
        r10 = r[-self.rank :, : -self.rank, :, :]
        r11 = r[-self.rank :, -self.rank :, :, :]
        det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :]
        s = (
            r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
            + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
            - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
            - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
        )
        s = s / det
        k_f = r00 - s
    else:
        # General rank: invert (I + r11) per (h, n) entry
        r00 = r[:-self.rank, :-self.rank, :, :]
        r01 = r[:-self.rank, -self.rank:, :, :]
        r10 = r[-self.rank:, :-self.rank, :, :]
        r11 = r[-self.rank:, -self.rank:, :, :]
        r11 = rearrange(r11, "a b h n -> h n a b")
        r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
        r11 = rearrange(r11, "h n a b -> a b h n")
        k_f = r00 - torch.einsum("i j h n, j k h n, k l h n -> i l h n", r01, r11, r10)

    # Final correction for the bilinear transform
    k_f = k_f * 2 / (1 + omega)

    # Move from frequency to coefficients
    k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L)

    # # Truncate to target length
    k = k[..., :L]

    # Leading rows (if any) come from the initial state; the last row comes from B
    if state is not None:
        k_state = k[:-1, :, :, :] # (B, C, H, L)
    else:
        k_state = None
    k_B = k[-1, :, :, :] # (C H L)

    return k_B, k_state
@torch.no_grad()
def _setup_linear(self):
    """ Create parameters that allow fast linear stepping of state

    Stores the dict self.step_params with keys "D", "R", "P", "Q", "B", "E",
    which _step_state_linear() reads.
    """
    w = self._w()
    B = _r2c(self.B) # (H N)
    P = _r2c(self.P)
    Q = P.conj()

    # Repeat w shape properly
    B = repeat(B, '1 t n -> 1 (v t) n', v=self.repeat)
    P = repeat(P, 'r t n -> r (v t) n', v=self.repeat)
    Q = repeat(Q, 'r t n -> r (v t) n', v=self.repeat)
    w = repeat(w, 't n -> (v t) n', v=self.repeat)

    # Prepare Linear stepping
    dt = torch.exp(self.log_dt)
    D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N)
    R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H R R)
    Q_D = rearrange(Q*D, 'r h n -> h r n')
    try:
        R = torch.linalg.solve(R, Q_D) # (H R N)
    except Exception:
        # Best-effort fallback to numpy's solver on CPU when the torch solve fails.
        # `Exception` (not a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
        R = torch.tensor(np.linalg.solve(R.to(Q_D).contiguous().detach().cpu(), Q_D.contiguous().detach().cpu())).to(Q_D)
    R = rearrange(R, 'h r n -> r h n')

    self.step_params = {
        "D": D, # (H N)
        "R": R, # (R H N)
        "P": P, # (R H N)
        "Q": Q, # (R H N)
        "B": B, # (1 H N)
        "E": 2.0 / dt.unsqueeze(-1) + w, # (H N)
    }
def _step_state_linear(self, u=None, state=None):
    """
    Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.

    Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster

    Requires self.step_params, which _setup_linear() creates.

    u: (H) input
    state: (H, N/2) state with conjugate pairs
    Optionally, the state can have last dimension N
    Returns: same shape as state
    """
    C = _r2c(self.C) # View used for dtype/device

    # Missing arguments default to zeros
    if u is None: # Special case used to find dA
        u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
    if state is None: # Special case used to find dB
        state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

    step_params = self.step_params.copy()
    if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default
        # There should be a slightly faster way using conjugate symmetry
        # Expand every operand with its conjugates, contract, then keep the first self.N entries
        contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', _conj(p), _conj(x), _conj(y))[..., :self.N] # inner outer product
    else:
        assert state.size(-1) == 2*self.N
        # Full-size state: expand the cached step parameters instead
        step_params = {k: _conj(v) for k, v in step_params.items()}
        # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
        contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product
    D = step_params["D"] # (H N)
    E = step_params["E"] # (H N)
    R = step_params["R"] # (R H N)
    P = step_params["P"] # (R H N)
    Q = step_params["Q"] # (R H N)
    B = step_params["B"] # (1 H N)

    new_state = E * state - contract_fn(P, Q, state) # (B H N)
    new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N)
    new_state = D * (new_state - contract_fn(P, R, new_state))

    return new_state
def _setup_state(self):
    """ Construct dA and dB for discretized state equation

    Returns (dA, dB). Calls _setup_linear() first, so self.step_params is refreshed as a side effect.
    """

    # Construct dA and dB by using the stepping
    self._setup_linear()
    C = _r2c(self.C) # Just returns a view that we use for finding dtype/device

    # Step every row of the identity (with zero input) to read off dA
    state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (2N 1 2N)
    dA = self._step_state_linear(state=state)
    dA = rearrange(dA, "n h m -> h m n")

    # Step the zero state with an all-ones input to read off dB
    u = C.new_ones(self.H)
    dB = self._step_state_linear(u=u)
    dB = _conj(dB)
    dB = rearrange(dB, '1 h n -> h n') # (H N)
    return dA, dB
def _step_state(self, u, state):
""" Must be called after self.default_state() is used to construct an initial state! """
next_state = self.state_contraction(self.dA, state) + self.input_contraction(self.dB, u)
return next_state
def _setup_step(self, mode='dense'):
    """ Set up dA, dB, dC discretized parameters for stepping

    mode: 'dense', 'linear' or 'diagonal'; stored as self._step_mode.
    Raises NotImplementedError for any other mode (after dA, dB, dC have already been set).
    """
    self.dA, self.dB = self._setup_state()

    # Calculate original C
    C = _conj(_r2c(self.C)) # (H C N)
    if self.L.item() == 0:
        # Internal length not initialized yet, so C is used as is
        dC = C
    else:
        # self.C represents C_tilde
        dA_L = power(self.L.item(), self.dA)
        I = torch.eye(self.dA.size(-1)).to(dA_L)

        # Solve (I - dA_L^T) dC = C for dC
        dC = torch.linalg.solve(
            I - dA_L.transpose(-1, -2),
            C.unsqueeze(-1),
        ).squeeze(-1)
    self.dC = dC

    # Do special preprocessing for different step modes

    self._step_mode = mode
    if mode == 'linear':
        # Linear case: special step function for the state, we need to handle output
        # use conjugate symmetry by default, which affects the output projection
        self.dC = 2*self.dC[:, :, :self.N]
    elif mode == 'diagonal':
        # Eigendecomposition of the A matrix
        L, V = torch.linalg.eig(self.dA)
        V_inv = torch.linalg.inv(V)
        # Check that the eigendedecomposition is correct
        if self.verbose:
            print("Diagonalization error:", torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA))

        # Change the parameterization to diagonalize
        self.dA = L
        self.dB = contract('h n m, h m -> h n', V_inv, self.dB)
        self.dC = contract('h n m, c h n -> c h m', V, self.dC)

    elif mode == 'dense':
        pass
    else: raise NotImplementedError("NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}")
def default_state(self, *batch_shape):
C = _r2c(self.C)
N = C.size(-1)
H = C.size(-2)
# Cache the tensor contractions we will later do, for efficiency
# These are put in this function because they depend on the batch size
step_mode = getattr(self, "_step_mode", "dense") # Used in default_state, which is called without _setup_step() in forward_state()
if step_mode != 'linear':
N *= 2
if step_mode == 'diagonal':
self.state_contraction = contract_expression(
"h n, ... h n -> ... h n",
(H, N),
batch_shape + (H, N),
)
else:
# Dense (quadratic) case: expand all terms
self.state_contraction = contract_expression(
"h m n, ... h n -> ... h m",
(H, N, N),
batch_shape + (H, N),
)
self.input_contraction = contract_expression(
"h n, ... h -> ... h n",
(H, N), # self.dB.shape