-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathensemble.py
68 lines (53 loc) · 2.6 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from ts_lsh.multiple import MultipleLSH, MultipleRandomSampledLSH
from ts_lsh.common import owa, get_owa_weights
def _init_aggregation(lsh, kwargs):
    """Configure how *lsh* collapses its per-component hashes.

    Shared by ``EnsembleLSH`` and ``RandomSampleEnsembleLSH``. Called after the
    base-class ``__init__``, which presumably sets ``lsh.num_components``
    (read below) — TODO confirm.

    Attributes set on *lsh*:
        output_length: always 1.
        aggregation: ``kwargs["aggregation"]``, default ``"srp"``.
        aggregation_weights: ``kwargs["aggregation_weights"]``, default None.
        scale, dist: only when ``aggregation == "srp"``; defaults ``1.0`` and
            ``"normal"``.

    For ``"srp"`` without explicit ``aggregation_weights``, one random weight
    per component is drawn: standard normal times ``scale`` for
    ``dist="normal"``, or uniform on ``[-scale, scale)`` for ``dist="unif"``.

    Raises:
        ValueError: ``"srp"`` aggregation with no supplied weights and a
            ``dist`` other than ``"normal"`` or ``"unif"``.
    """
    lsh.output_length = 1
    lsh.aggregation = kwargs.get("aggregation", "srp")
    lsh.aggregation_weights = kwargs.get("aggregation_weights", None)
    if lsh.aggregation != "srp":
        return
    lsh.scale = kwargs.get("scale", 1.0)
    lsh.dist = kwargs.get("dist", "normal")
    if lsh.aggregation_weights is not None:
        # Caller supplied explicit projection weights; do not discard them.
        return
    if lsh.dist == "normal":
        lsh.aggregation_weights = np.random.randn(lsh.num_components) * lsh.scale
    elif lsh.dist == "unif":
        # rand() is in [0, 1), so this maps to [-scale, scale).
        lsh.aggregation_weights = (
            np.random.rand(lsh.num_components) * 2 * lsh.scale
        ) - lsh.scale
    else:
        # Fail fast: previously this left the weights as None and only broke
        # later inside np.dot with an unhelpful TypeError.
        raise ValueError(
            f"Unknown dist {lsh.dist!r} for 'srp' aggregation; "
            "expected 'normal' or 'unif'"
        )


def _aggregate_hashes(lsh, hashes, series):
    """Combine the component *hashes* of *series* according to ``lsh.aggregation``.

    - ``"srp"``: ``np.dot(lsh.aggregation_weights, hashes)``.
    - ``"owa"``: ``owa(hashes, lsh.aggregation_weights)``; the weights must
      have been supplied at construction.
    - anything else: OWA with weights from
      ``get_owa_weights(lsh.aggregation, series)``, computed on the first call
      and cached on *lsh* for all later calls.

    Raises:
        ValueError: ``"owa"`` aggregation without ``aggregation_weights``.
    """
    if lsh.aggregation == "srp":
        return np.dot(lsh.aggregation_weights, hashes)
    if lsh.aggregation_weights is None:
        if lsh.aggregation == "owa":
            raise ValueError("Ordered Weight Aggregation weights are not set!")
        # NOTE(review): weights are derived from the FIRST input seen and then
        # reused for every later input — confirm get_owa_weights only depends
        # on properties shared by all inputs (e.g. length).
        lsh.aggregation_weights = get_owa_weights(lsh.aggregation, series)
    return owa(hashes, lsh.aggregation_weights)


class EnsembleLSH(MultipleLSH):
    """``MultipleLSH`` whose component hashes are aggregated into one output.

    Extra keyword arguments (all others go to ``MultipleLSH``):
        aggregation: ``"srp"`` (default; random-weight dot product), ``"owa"``
            (ordered weight aggregation with supplied weights), or any other
            value, which is passed to ``get_owa_weights`` to derive OWA weights.
        aggregation_weights: explicit weights; required for ``"owa"``,
            optional for ``"srp"``.
        scale: ``"srp"`` only; multiplier for the random weights (default 1.0).
        dist: ``"srp"`` only; ``"normal"`` (default) or ``"unif"``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        _init_aggregation(self, kwargs)

    def _hashfunction(self, input: np.ndarray, **kwargs):
        """Hash *input* with every component, then aggregate the results.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        hashes = super()._hashfunction(input)
        return _aggregate_hashes(self, hashes, input)


class RandomSampleEnsembleLSH(MultipleRandomSampledLSH):
    """``MultipleRandomSampledLSH`` whose component hashes are aggregated.

    Accepts the same extra keyword arguments as ``EnsembleLSH``
    (``aggregation``, ``aggregation_weights``, ``scale``, ``dist``); all
    others go to ``MultipleRandomSampledLSH``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        _init_aggregation(self, kwargs)

    def _hashfunction(self, input: np.ndarray, **kwargs):
        """Hash *input* with every component, then aggregate the results.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        hashes = super()._hashfunction(input)
        return _aggregate_hashes(self, hashes, input)