-
Notifications
You must be signed in to change notification settings - Fork 58
/
Copy pathSOINN.cpp
366 lines (310 loc) · 10.4 KB
/
SOINN.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
#include "SOINN.h"
SOINN::SOINN(int node_age, double edge_age)
: is_first_learning(true)
, node_erase_age(node_age)
, edge_erase_age(edge_age)
{
srand((unsigned)time(NULL));
}
void SOINN::learn(vector<Node> inputs)
{
input_nodes = inputs;
initialize();
// 入力パターンを学習
for(int i = 0; i < (int)input_nodes.size(); ++i)
{
Node input = input_nodes.at(i);
// 勝者ノードを探す
vector<Node> winner_nodes = getWinnerNodes(input);
Node first_winner = winner_nodes.at(0); // 勝者ノード
Node second_winner = winner_nodes.at(1); // 第二勝者ノード
double first_simular = getSimularThreshold(first_winner);
double second_simular = getSimularThreshold(second_winner);
double input_first_dist = getNodeDistance(input, first_winner);
double input_second_dist = getNodeDistance(input, second_winner);
// 勝者ノードもしくは第二勝者ノードの類似度閾値の外側なら
if(input_first_dist > first_simular ||
input_second_dist > second_simular)
{
// 新規ノードとして追加
soinn.addNode(input.position);
}
// 勝者・第二勝者ノードいずれかの類似度閾値の内側の場合
if(input_first_dist <= first_simular &&
input_second_dist <= second_simular)
{
// 勝者ノードと第二勝者ノードの間にエッジが無いなら
if(!soinn.hasEdge(first_winner.id, second_winner.id))
{
// エッジを追加
soinn.addEdge(first_winner.id, second_winner.id);
}
// エッジの年齢を0にリセット
setEdgeAge(first_winner.id, second_winner.id, 0);
// 勝者ノードにつながる全エッジの年齢をインクリメント
incAllEdgeAge(first_winner.id);
// 位置ベクトルの更新
updatePosition(first_winner, input, 1.0, first_winner.win_times);
// 勝者ノードの関連ノードの位置ベクトルの更新
updatePositionRelated(first_winner, input, 1.0 / 100.0);
// 年老いたエッジを削除
eraseOldEdges();
// 消されたエッジでつながれていたノードを調べる
for(int j = 0; j < (int)erased_edges.size(); ++j)
{
Edge e = erased_edges.at(j);
// 関連ノードがなければ削除
eraseIndependentNode(e.node_ids.first);
eraseIndependentNode(e.node_ids.second);
}
}
// 入力パターン数がnode_erase_ageの倍数
if((i + 1) % node_erase_age == 0 ||
(i + 1) == (int)input_nodes.size())
{
eraseNoizyNode();
}
}
drawGraph();
}
vector<Node> SOINN::getLearnedNode()
{
return soinn.getAllNodes();
}
void SOINN::initialize()
{
// 初回の学習である
if(is_first_learning)
{
// 学習データから2つ適当に選んで追加
for(int i = 0; i < 2; ++i)
{
int random_num = rand() % input_nodes.size();
soinn.addNode(input_nodes.at(random_num).position);
cout << "random" << i << " : " << random_num << endl;
}
is_first_learning = false;
}
}
vector<Node> SOINN::getWinnerNodes(const Node &input_node)
{
// 距離でソートされたノード番号のセット
vector<Node> nodes = soinn.getAllNodes();
multimap<double, int> node_distances = getSortedDistancesNodeNumbers(input_node, nodes);
// 勝者ノードと第二勝者ノードを返す
vector<Node> winners;
multimap<double, int>::iterator it = node_distances.begin();
// 勝者ノード
Node win1 = soinn.getNode(it->second);
++win1.win_times;
soinn.setNode(win1);
winners.push_back(win1);
++it;
// 第二勝者ノード
Node win2 = soinn.getNode((*it).second);
winners.push_back(win2);
return winners;
}
double SOINN::getSimularThreshold(const Node &target_node)
{
// target_nodeの類似度閾値を計算
// 連結ノードがある場合は
if(soinn.getEdgeCount(target_node.id) != 0)
{
// 連結ノード集合
vector<Node> related_nodes = soinn.getRealtedNodes(target_node.id);
// 連結ノード集合のうち最も遠いノードまでの距離を返す
if(related_nodes.size() == 1)
{
return getNodeDistance(related_nodes.at(0), target_node);
}
else
{
multimap<double, int> node_distances = getSortedDistancesNodeNumbers(target_node, related_nodes);
return (*node_distances.end()).first;
}
}
// 連結ノードがない場合は
else
{
// target_node以外のノードの集合
vector<Node> nodes = soinn.getAllNodes();
vector<Node> except_nodes;
for(int i = 0; i < (int)nodes.size(); ++i)
{
Node n = nodes.at(i);
if(n.id != target_node.id)
{
except_nodes.push_back(n);
}
}
// target_node以外のノードのうち最も近いノードまでの距離を返す
multimap<double, int> node_distances = getSortedDistancesNodeNumbers(target_node, except_nodes);
return (*node_distances.begin()).first;
}
}
double SOINN::getNodeDistance(const Node &first_node, const Node &second_node)
{
return sqrt(pow(first_node.position[0] - second_node.position[0], 2.0) +
pow(first_node.position[1] - second_node.position[1], 2.0));
}
multimap<double, int> SOINN::getSortedDistancesNodeNumbers(const Node &input_node, vector<Node> &compared_nodes)
{
// 距離でソートされたノード番号のセット
multimap<double, int> node_distances;
vector<Node>::iterator it = compared_nodes.begin();
while(it != compared_nodes.end())
{
double d = getNodeDistance((*it), input_node);
node_distances.insert(pair<double, int>(d, it->id));
++it;
}
return node_distances;
}
void SOINN::setEdgeAge(int first, int second, int age)
{
Edge e = soinn.getEdgeFromTo(first, second);
e.age = age;
soinn.setEdge(e);
}
void SOINN::incAllEdgeAge(int node_id)
{
// node_idにつながる全エッジの年齢に1を足す
vector<Edge> from = soinn.getEdgeFrom(node_id);
vector<Edge> to = soinn.getEdgeTo(node_id);
for(int i = 0; i < (int)from.size(); ++i)
{
Edge e = from.at(i);
++e.age;
soinn.setEdge(e);
}
for(int i = 0; i < (int)to.size(); ++i)
{
Edge e = to.at(i);
++e.age;
soinn.setEdge(e);
}
}
void SOINN::updatePosition(Node &node, const Node &input, double weight, int win_times)
{
// nodeの位置ベクトル更新
double epsilon = 1.0 / (double)win_times;
VectorXd vec = input.position - node.position;
node.position += epsilon * weight * vec;
soinn.setNode(node);
}
void SOINN::updatePositionRelated(const Node &node, const Node &input, double weight)
{
// nodeに関連する全てのノード
vector<Node> related = soinn.getRealtedNodes(node.id);
for(int i = 0; i < (int)related.size(); ++i)
{
Node n = related.at(i);
updatePosition(n, input, weight, node.win_times);
}
}
void SOINN::eraseOldEdges()
{
erased_edges.clear();
vector<Edge> edges = soinn.getAllEdges();
vector<Edge>::iterator it = edges.begin();
while(it != edges.end())
{
if((*it).age > edge_erase_age)
{
erased_edges.push_back((*it));
it = edges.erase(it);
}
else
{
++it;
}
}
soinn.setAllEdges(edges);
}
void SOINN::eraseIndependentNode(int node_id)
{
// 関連ノードを持っていなければ消す
if(soinn.getEdgeCount(node_id) == 0)
{
vector<Node> nodes = soinn.getAllNodes();
vector<Node>::iterator it = nodes.begin();
it = nodes.erase(it + node_id);
soinn.setAllNodes(nodes);
// そのノードが持っていたエッジを消す
vector<Edge> edges = soinn.getAllEdges();
vector<Edge>::iterator eit = edges.begin();
while(eit != edges.end())
{
if((*eit).node_ids.first == node_id ||
(*eit).node_ids.second == node_id)
{
eit = edges.erase(eit);
}
else
{
++eit;
}
}
soinn.setAllEdges(edges);
}
}
void SOINN::eraseNoizyNode()
{
vector<Node> nodes = soinn.getAllNodes();
vector<Node>::iterator node_it = nodes.begin();
while(node_it != nodes.end())
{
// 関連ノード数が1未満なら
if(soinn.getEdgeCount(node_it->id) <= 1)
{
// そのエッジを消す
vector<Edge> edges = soinn.getAllEdges();
vector<Edge>::iterator edge_it = edges.begin();
while(edge_it != edges.end())
{
// 現在のノードが端点にあるなら
if(edge_it->node_ids.first == node_it->id ||
edge_it->node_ids.second == node_it->id)
{
edge_it = edges.erase(edge_it);
}
else
{
++edge_it;
}
}
soinn.setAllEdges(edges);
// ノードを消す
node_it = nodes.erase(node_it);
}
else
{
++node_it;
}
}
soinn.setAllNodes(nodes);
}
void SOINN::drawGraph()
{
Mat graph(300, 300, CV_8UC3);
graph = Mat::zeros(graph.size(), graph.type());
vector<Node> nodes = soinn.getAllNodes();
for(int i = 0; i < (int)nodes.size(); ++i)
{
Node n = nodes.at(i);
circle(graph, Point((int)n.position[1], (int)n.position[0]), 1, Scalar(255, 255, 255), -1);
}
vector<Edge> edges = soinn.getAllEdges();
for(int i = 0; i < (int)edges.size(); ++i)
{
Edge e = edges.at(i);
Node sn = soinn.getNode(e.node_ids.first);
Node en = soinn.getNode(e.node_ids.second);
line(graph,
Point((int)sn.position[1], (int)sn.position[0]),
Point((int)en.position[1], (int)en.position[0]),
Scalar(180, 180, i % 5 * 50), 1, CV_AA);
}
imwrite("graph.png", graph);
}