-
Notifications
You must be signed in to change notification settings - Fork 5
/
hydra.py
119 lines (75 loc) · 3.28 KB
/
hydra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Angus Dempster, Daniel F Schmidt, Geoffrey I Webb
# HYDRA: Competing Convolutional Kernels for Fast and Accurate Time Series Classification
# https://arxiv.org/abs/2203.13652
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
class Hydra(nn.Module):
    """HYDRA random convolutional transform for univariate time series.

    Kernels are organised into ``g`` groups of ``k`` length-9 kernels and
    applied at a set of dilations 1, 2, 4, ...  When ``g > 1`` half of the
    groups are applied to the raw series and half to its first difference.
    At every time step, within each group, the kernel with the largest
    response accumulates that response and the kernel with the smallest
    response accumulates a count of one.

    Args:
        input_length: length of the input series; only used to choose the
            dilations (the widest dilated kernel, 8 * d + 1 samples, is kept
            no longer than the series; at least dilation 1 is always used).
        k: number of kernels per group.
        g: number of groups.
        seed: if given, passed to ``torch.manual_seed`` before the kernels
            are drawn.  NOTE: this reseeds the *global* torch RNG.
    """

    def __init__(self, input_length, k = 8, g = 64, seed = None):
        super().__init__()
        if seed is not None:
            # Global reseed (not a private generator); kept for compatibility.
            torch.manual_seed(seed)
        self.k = k  # num kernels per group
        self.g = g  # num groups
        # Kernel length = 9, so dilation d spans 8 * d + 1 samples.  Use
        # dilations 2**0 .. 2**floor(log2((input_length - 1) / 8)).  The
        # numerator is clamped at 8 so short series still get dilation 1
        # (otherwise input_length <= 5 yields no dilations and forward()
        # fails, and input_length == 1 makes int(-inf) raise).
        max_exponent = np.log2(max(input_length - 1, 9 - 1) / (9 - 1))
        self.dilations = 2 ** torch.arange(int(max_exponent) + 1)
        self.num_dilations = len(self.dilations)
        # padding = 4 * d keeps the output length equal to the input length
        self.paddings = torch.div((9 - 1) * self.dilations, 2, rounding_mode = "floor").int()
        # With g > 1 the groups are split evenly between X and diff(X).
        self.divisor = min(2, self.g)
        self.h = self.g // self.divisor  # groups per input representation
        W = torch.randn(self.num_dilations, self.divisor, self.k * self.h, 1, 9)
        W = W - W.mean(-1, keepdim = True)       # zero-mean kernels
        W = W / W.abs().sum(-1, keepdim = True)  # unit L1 norm per kernel
        # Non-persistent buffer: follows .to()/.cuda()/.double() with the
        # module, but stays out of state_dict() exactly as the old attribute did.
        self.register_buffer("W", W, persistent = False)

    # transform in batches of *batch_size*
    def batch(self, X, batch_size = 256):
        """Apply ``forward`` to chunks of at most ``batch_size`` examples.

        Returns the chunk outputs concatenated along the example axis, i.e.
        the same values as ``self(X)`` computed piecewise.
        """
        if X.shape[0] <= batch_size:
            return self(X)
        return torch.cat([self(chunk) for chunk in X.split(batch_size)])

    def forward(self, X):
        """Transform a batch of series.

        Args:
            X: tensor of shape (num_examples, 1, length); a single input
               channel is required by the (k * h, 1, 9) kernel shape.

        Returns:
            Tensor of shape (num_examples, num_dilations * divisor * 2 * h * k)
            on the same device / dtype as the convolution output.
        """
        num_examples = X.shape[0]
        if self.divisor > 1:
            diff_X = torch.diff(X)  # first difference along the last (time) axis
        Z = []
        for dilation_index in range(self.num_dilations):
            d = self.dilations[dilation_index].item()
            p = self.paddings[dilation_index].item()
            for diff_index in range(self.divisor):
                _X = X if diff_index == 0 else diff_X
                # (N, h * k, T) -> (N, h, k, T): k competing responses per group
                _Z = (
                    F.conv1d(_X, self.W[dilation_index, diff_index], dilation = d, padding = p)
                    .view(num_examples, self.h, self.k, -1)
                )
                max_values, max_indices = _Z.max(2)
                min_values, min_indices = _Z.min(2)
                # Accumulators share _Z's device and dtype; torch.zeros would
                # always be CPU / default dtype and break GPU or float64 input.
                count_max = _Z.new_zeros(num_examples, self.h, self.k)
                count_min = _Z.new_zeros(num_examples, self.h, self.k)
                # Per time step, the winning kernel adds its response (max) ...
                count_max.scatter_add_(-1, max_indices, max_values)
                # ... or adds one (min).
                count_min.scatter_add_(-1, min_indices, torch.ones_like(min_values))
                Z.append(count_max)
                Z.append(count_min)
        return torch.cat(Z, 1).view(num_examples, -1)
class SparseScaler():
    """Column-wise standardiser for sqrt-transformed, non-negative features.

    Inputs are clamped at 0 and square-rooted.  Each column (axis 0 is the
    example axis) is then centred by its fitted mean and divided by its
    fitted ``torch.std`` plus ``zero_fraction ** exponent + 1e-8``, where
    ``zero_fraction`` is the share of zeros in that column at fit time.

    Args:
        mask: if True, entries that are zero after clamping map to zero.
        exponent: power applied to the per-column zero fraction.
    """

    def __init__(self, mask = True, exponent = 4):
        self.fitted = False
        self.exponent = exponent
        self.mask = mask

    @staticmethod
    def _root(X):
        # Negative entries are clamped to 0 before taking the square root.
        return X.clamp(0).sqrt()

    def fit(self, X):
        """Record per-column epsilon, mean and scale from ``X``; may be called once."""
        assert not self.fitted, "Already fitted."
        root = self._root(X)
        zero_fraction = (root == 0).float().mean(0)
        self.epsilon = zero_fraction ** self.exponent + 1e-8
        self.mu = root.mean(0)
        self.sigma = root.std(0) + self.epsilon
        self.fitted = True

    def transform(self, X):
        """Return the standardised sqrt of ``X`` using the fitted statistics."""
        assert self.fitted, "Not fitted."
        root = self._root(X)
        centred = root - self.mu
        if self.mask:
            # Zero out entries whose (clamped, rooted) value is exactly 0.
            centred = centred * (root != 0)
        return centred / self.sigma

    def fit_transform(self, X):
        """Fit on ``X`` and return its transform."""
        self.fit(X)
        return self.transform(X)