Skip to content

Commit a8375a7

Browse files
committed
add module
1 parent 0461f54 commit a8375a7

File tree

11 files changed

+317
-0
lines changed

11 files changed

+317
-0
lines changed

module/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from ._lvq import (
2+
Model as LVQ
3+
)
4+
from ._som import (
5+
Model as SOM
6+
)
7+
from ._wta import (
8+
Model as WTA
9+
)

module/_lvq.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import numpy as np
2+
from numpy import ndarray
3+
from .util.rand import PARAMETER_RAND_NORM
4+
from .util.distance import EUCLIDEAN
5+
from .util.neighborhood import GAUSSIAN
6+
7+
8+
class Node(object):
    """A single labelled node of the LVQ map."""

    def __init__(self, label, weight: ndarray):
        """
        Create a map node.

        :param label: label attached to this node
        :param weight: weight vector of this node
        """
        self.weight = weight
        # ``np.float`` was a plain alias of the builtin ``float`` and has been
        # removed from NumPy (1.24+); ``np.float64`` is the same dtype and
        # works on every NumPy version.
        self.position = np.array([0, 0], dtype=np.float64)
        self.label = label
18+
19+
20+
class Model(object):
    """Learning Vector Quantization (LVQ) model on a grid of labelled nodes."""

    def __init__(self, depth: int, width: int, height: int, labels: list,
                 param_init=PARAMETER_RAND_NORM,
                 distance=EUCLIDEAN,
                 neighborhood=GAUSSIAN):
        """
        Build a ``width`` x ``height`` grid of labelled nodes.

        :param depth: passed to ``param_init`` to create each node's weight
        :param width: grid width
        :param height: grid height
        :param labels: initial node labels, indexed by node number; must hold
            at least ``width * height`` entries
        :param param_init: weight initialiser, called as ``param_init(depth)``
        :param distance: distance function, called as ``distance(weight, x)``
        :param neighborhood: neighbourhood function, called as
            ``neighborhood(alpha, radius, node_position, winner_position)``
        :raises IndexError: if ``labels`` is shorter than the node count
        """
        self.width = width
        self.height = height
        self.length = width * height
        # Draw every weight first, then pair it with its label, so the
        # initialiser is consumed in the same order as node numbering.
        weights = [param_init(depth) for _ in range(self.length)]
        self.nodes = [Node(labels[i], weight) for i, weight in enumerate(weights)]
        for i, node in enumerate(self.nodes):
            # Row-major layout: x is the column, y is the row.  Integer
            # divmod avoids the float round-trip of ``int(i / width)``.
            row, column = divmod(i, width)
            # ``np.float`` no longer exists in NumPy >= 1.24; use np.float64.
            node.position = np.array([column, row], dtype=np.float64)
        self.distance = distance
        self.neighborhood = neighborhood

    def __len__(self):
        """Return the number of nodes (``width * height``)."""
        return self.length

    def winner(self, x: ndarray) -> Node:
        """
        Find the node whose weight is closest to ``x``.

        :param x: a single input
        :return: the winning node; the lowest-numbered node wins ties
        :raises IndexError: if the model has no nodes
        """
        centre = self.nodes[0]
        min_d = self.distance(centre.weight, x)
        # Node 0 is already measured, so only the remaining nodes are scanned.
        for node in self.nodes[1:]:
            d = self.distance(node.weight, x)
            if d < min_d:
                centre = node
                min_d = d
        return centre

    def train(self, X: list, y, alpha: float, radius: float) -> None:
        """
        Train on a batch of inputs that all share the label ``y``.

        For every input the winner is located and all node weights are moved
        toward the input, scaled by the neighbourhood function.  Afterwards
        the node that won most often is relabelled ``y``.  An empty batch
        leaves the model untouched.

        :param X: iterable of pairs; only the second item of each pair is
            used as the input (presumably ``(label, input)`` pairs as in
            ``validate`` -- confirm against callers)
        :param y: label shared by the whole batch
        :param alpha: learning rate in [0, 1]
        :param radius: neighbourhood radius
        :return: None
        """
        # node index -> number of wins; insertion order is first-win order,
        # so ``max`` below breaks ties in favour of the earliest winner.
        win_counts = {}
        for _, x in X:
            centre = self.winner(x)
            index = self.nodes.index(centre)
            win_counts[index] = win_counts.get(index, 0) + 1
            # Pull every node toward x; the neighbourhood function decides
            # how strongly based on the grid distance to the winner.
            for node in self.nodes:
                rate = self.neighborhood(alpha, radius, node.position, centre.position)
                node.weight = node.weight + rate * (x - node.weight)
        if not win_counts:
            # Nothing was presented, so there is no node to relabel.
            return
        most_frequent = max(win_counts, key=win_counts.get)
        self.nodes[most_frequent].label = y

    def validate(self, test_sets: list) -> float:
        """
        Measure classification accuracy.

        :param test_sets: sequence of ``(label, input)`` pairs; labels may differ
        :return: fraction of inputs whose winning node carries the expected
            label, in [0, 1]; ``0.0`` when ``test_sets`` is empty
        """
        total = len(test_sets)
        if total == 0:
            # Avoid dividing by zero on an empty evaluation set.
            return 0.0
        hits = sum(1 for label, test in test_sets if self.winner(test).label == label)
        return hits / total

module/_som.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
from numpy import ndarray
3+
from .util.rand import PARAMETER_RAND_NORM
4+
from .util.distance import EUCLIDEAN
5+
from .util.neighborhood import GAUSSIAN
6+
7+
8+
class Node(object):
    """One cell of the self-organising map."""

    def __init__(self, weight: ndarray):
        """
        Create a map node.

        :param weight: weight vector of the node
        """
        # The position starts at the origin as a float64 pair.
        self.position = np.zeros(2, dtype=np.float64)
        self.weight = weight
16+
17+
18+
class Model(object):
    """Self-organising map (SOM) on a ``width`` x ``height`` grid of nodes."""

    def __init__(self, depth: int, width: int, height: int,
                 param_init=PARAMETER_RAND_NORM,
                 distance=EUCLIDEAN,
                 neighborhood=GAUSSIAN
                 ):
        """
        Build the node grid.

        :param depth: passed to ``param_init`` to create each node's weight
        :param width: grid width
        :param height: grid height
        :param param_init: weight initialiser, called as ``param_init(depth)``
        :param distance: distance function, called as ``distance(weight, x)``
        :param neighborhood: neighbourhood function, called as
            ``neighborhood(alpha, radius, node_position, winner_position)``
        """
        self.distance = distance
        self.neighborhood = neighborhood
        self.width = width
        self.height = height
        self.length = width * height
        # All weights are drawn before any node is created.
        initial_weights = [param_init(depth) for _ in range(self.length)]
        self.nodes = [Node(w) for w in initial_weights]
        # Row-major layout: x is the column index, y is the row index.
        for index, cell in enumerate(self.nodes):
            cell.position = np.array([index % width, index // width], np.float64)

    def __len__(self):
        """Return the number of nodes (``width * height``)."""
        return self.length

    def winner(self, x) -> Node:
        """
        Find the node whose weight is closest to ``x``.

        :param x: a single input
        :return: the winning node; the lowest-numbered node wins ties
        """
        best = self.nodes[0]
        best_distance = self.distance(best.weight, x)
        for candidate in self.nodes:
            candidate_distance = self.distance(candidate.weight, x)
            # Strict comparison keeps the earliest node on ties.
            if candidate_distance < best_distance:
                best, best_distance = candidate, candidate_distance
        return best

    def train(self, x: ndarray, alpha: float, radius: float):
        """
        Train the map on one input.

        :param x: a single input
        :param alpha: learning rate in [0, 1]
        :param radius: neighbourhood radius
        :return: None
        """
        # Locate the winning node for this input.
        centre = self.winner(x)
        # Move every node toward x, scaled by the neighbourhood function.
        for cell in self.nodes:
            rate = self.neighborhood(alpha, radius, cell.position, centre.position)
            cell.weight = cell.weight + rate * (x - cell.weight)

module/_wta.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import numpy as np
2+
from numpy import ndarray
3+
from .util.rand import PARAMETER_RAND_NORM
4+
from .util.distance import EUCLIDEAN
5+
from .util.neighborhood import GAUSSIAN
6+
7+
8+
class Node(object):
    """A competing unit of the winner-take-all model; it only holds a weight."""

    def __init__(self, weight: ndarray):
        """
        Create a node.

        :param weight: weight vector of the node
        """
        self.weight = weight
15+
16+
17+
class Model(object):
    """Winner-take-all (WTA) competitive-learning model."""

    def __init__(self, depth: int, length: int,
                 param_init=PARAMETER_RAND_NORM,
                 distance=EUCLIDEAN,
                 neighborhood=GAUSSIAN
                 ):
        """
        Build ``length`` independent nodes.

        :param depth: passed to ``param_init`` to create each node's weight
        :param length: number of nodes
        :param param_init: weight initialiser, called as ``param_init(depth)``
        :param distance: distance function, called as ``distance(weight, x)``
        :param neighborhood: accepted so the signature matches the other
            models; this model has no node positions and does not use it
        """
        self.length = length
        self.nodes = [Node(param_init(depth)) for _ in range(length)]
        self.distance = distance

    def __len__(self):
        """Return the number of nodes."""
        return self.length

    def winner(self, x) -> Node:
        """
        Find the node whose weight is closest to ``x``.

        :param x: a single input
        :return: the winning node; the lowest-numbered node wins ties
        :raises IndexError: if the model has no nodes
        """
        centre = self.nodes[0]
        min_d = self.distance(centre.weight, x)
        # Node 0 is already measured, so only the remaining nodes are scanned.
        for node in self.nodes[1:]:
            d = self.distance(node.weight, x)
            if d < min_d:
                centre = node
                min_d = d
        return centre

    def train(self, x: ndarray, alpha: float):
        """
        Train the model on one input.

        Only the winning node learns: its weight is moved toward ``x`` by a
        fraction ``alpha``.  (Previously the winner was computed but ignored
        and every node was moved, which makes all nodes collapse onto the
        same point.)

        :param x: a single input
        :param alpha: learning rate in [0, 1]
        :return: None
        """
        # Locate the winning node for this input.
        centre = self.winner(x)
        # Winner takes all: the losing nodes keep their weights.
        centre.weight = centre.weight + alpha * (x - centre.weight)

module/util/__init__.py

Whitespace-only changes.

module/util/distance/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._distance import (
2+
euclidean as EUCLIDEAN,
3+
pearson as PEARSON,
4+
fast_dtw as FAST_DTW,
5+
dtw as DTW
6+
)

module/util/distance/_distance.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
3+
4+
def euclidean(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Compute the Euclidean distance between two vectors.

    Plain sequences (lists, tuples) are accepted as well as arrays.

    :param v1: first vector
    :param v2: second vector
    :return: the norm of ``v1 - v2`` as a builtin ``float``
    """
    # Coerce first so that ``-`` also works for non-array sequences.
    difference = np.asarray(v1) - np.asarray(v2)
    # ``norm`` yields a NumPy scalar; convert to match the annotation.
    return float(np.linalg.norm(difference))
10+
11+
12+
def pearson(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Compute the reciprocal of the Pearson correlation coefficient.

    :param v1: first vector
    :param v2: second vector
    :return: ``1 / r`` where ``r`` is the correlation coefficient of the inputs
    """
    # SciPy is imported at call time rather than at module import.
    from scipy.stats import pearsonr
    coefficient = pearsonr(v1, v2)[0]
    # NOTE(review): 1 / r is negative when the inputs are anti-correlated and
    # is undefined at r == 0 -- confirm this is the intended "distance".
    return 1 / coefficient
19+
20+
21+
def fast_dtw(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Compute the distance between two sequences with ``fastdtw.fastdtw``.

    SciPy's ``euclidean`` is used as the point-to-point metric.

    :param v1: first sequence
    :param v2: second sequence
    :return: the first element of the ``fastdtw.fastdtw`` result (the distance)
    """
    # Both dependencies are imported at call time rather than at module import.
    from scipy.spatial.distance import euclidean
    import fastdtw
    # noinspection PyTypeChecker,PyUnresolvedReferences
    result = fastdtw.fastdtw(v1, v2, dist=euclidean)
    return result[0]
30+
31+
32+
def dtw(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Compute the distance between two sequences with ``fastdtw.dtw``.

    SciPy's ``euclidean`` is used as the point-to-point metric.

    :param v1: first sequence
    :param v2: second sequence
    :return: the first element of the ``fastdtw.dtw`` result (the distance)
    """
    # Both dependencies are imported at call time rather than at module import.
    from scipy.spatial.distance import euclidean
    import fastdtw
    # noinspection PyTypeChecker,PyUnresolvedReferences
    result = fastdtw.dtw(v1, v2, dist=euclidean)
    return result[0]

module/util/neighborhood/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._gaussian import (
2+
gaussian as GAUSSIAN
3+
)

module/util/neighborhood/_gaussian.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import numpy as np
2+
from numpy import ndarray
3+
4+
5+
def gaussian(alpha: float, radius: float, i: ndarray, c: ndarray) -> float:
    """
    Simplified Gaussian neighbourhood function (hard cut-off).

    :param alpha: learning rate
    :param radius: neighbourhood radius
    :param i: position of the node being updated
    :param c: position of the centre node
    :return: ``alpha`` when ``i`` is within ``radius`` of ``c`` (boundary
        included), otherwise ``0.0``
    """
    within_radius = np.linalg.norm(i - c) <= radius
    return alpha if within_radius else 0.0

module/util/rand/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._parameter import (
2+
rand as PARAMETER_RAND_NORM
3+
)

0 commit comments

Comments
 (0)