Skip to content

Commit e30c752

Browse files
committed
started the error calculation
1 parent acd9d1d commit e30c752

File tree

6 files changed

+52
-14
lines changed

6 files changed

+52
-14
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,6 @@ $RECYCLE.BIN/
4545
Network Trash Folder
4646
Temporary Items
4747
.apdisk
48+
49+
vs2015
50+
bin

include/network.hh

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ private:
3636
void _push( double value );
3737
void _fire();
3838

39+
void _compute_error(double value);
40+
3941
double m_weight = 0.0;
4042
std::shared_ptr<Node> m_link_to;
4143
std::shared_ptr<Node> m_link_from;
@@ -55,15 +57,32 @@ private:
5557
friend class Network;
5658
friend class Link;
5759

60+
// pushes a value to the node. When all links have been pushed
61+
// the node will fire the values forward.
5862
void _push(double value);
5963
double _sigmoid( double num );
6064
void _fire();
6165

66+
// computes the error. similar to push but backwards.
67+
void _compute_error(double value);
68+
6269
bool m_end_node = false;
70+
71+
// The last value fired forward.
72+
double m_value = 0.0;
73+
// the sum of all the values of fired to this node.
6374
double m_synapse_sum = 0.0;
75+
76+
double m_error_sum = 0.0;
6477
double m_delta = 0.0;
6578
double m_bias = 0.0;
66-
int32_t m_load = 0;
79+
80+
// When the load is equal to the number of back links the node will fire.
81+
int32_t m_forward_load = 0;
82+
// When the load is equal to the number of forward the node will fire back.
83+
// Calculating the error made from the previous iteration.
84+
int32_t m_backward_load = 0;
85+
6786
std::vector<std::shared_ptr<Link>> m_f_links;
6887
std::vector<std::shared_ptr<Link>> m_b_links;
6988

src/link.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,8 @@ void Link::_push( double value ) {
1212

1313
void Link::_fire() {
1414
m_link_to->_fire();
15+
}
16+
17+
void Link::_compute_error(double value){
18+
m_link_from->_compute_error( value * m_weight );
1519
}

src/main.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ int main( int argc, char** argv ) {
1818

1919
std::shared_ptr<Network> network = std::make_shared<Network>( input_layer, output_layer );
2020

21-
auto r = network->activate( { 5,5 } );
22-
2321
// train the network - learn XOR
2422
double learning_rate = 0.3;
2523
for( uint32_t i = 0; i < 10000; ++i ) {

src/network.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,12 @@ void Network::propagate( double learning_rate, std::vector<double> results ) {
3535
assert( results.size() == onodes.size() &&
3636
"RESULTS MUST BE THE SIZE OF THE NODES IN THE OUTPUT LAYER" );
3737

38-
double error = 0.0;
3938

4039
for( size_t i = 0; i < onodes.size(); ++i ){
4140

4241
auto on = onodes[i];
43-
4442
double this_error = results[i] - on->m_synapse_sum;
45-
on->m_delta = (1.0 - on->m_synapse_sum) * results[i] * this_error;
46-
error += (0.5 * this_error * this_error );
47-
43+
on->_compute_error(this_error);
4844
}
4945

5046
}

src/node.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ void Node::add_b_link( std::shared_ptr<Link> link ) {
1010

1111
void Node::_push( double value ) {
1212
m_synapse_sum += value;
13-
++m_load;
14-
if( m_load == m_b_links.size() ){
13+
++m_forward_load;
14+
if( m_forward_load == m_b_links.size() ){
1515
_fire();
1616
}
1717
}
@@ -22,16 +22,34 @@ double Node::_sigmoid( double num ) {
2222
}
2323

2424
void Node::_fire( ){
25-
if( m_end_node ){
26-
return;
27-
}
2825

2926
double v = _sigmoid( m_synapse_sum + m_bias);
27+
m_value = v;
28+
29+
if( m_end_node ) return;
3030

3131
for( auto link : m_f_links ){
3232
link->_push(v);
3333
}
3434

3535
m_synapse_sum = 0;
36-
m_load = 0;
36+
m_forward_load = 0;
37+
}
38+
39+
void Node::_compute_error(double value){
40+
41+
m_error_sum += value;
42+
43+
++m_backward_load;
44+
if( m_backward_load == m_f_links.size() ){
45+
46+
m_delta = ( 1.0 - m_value ) * m_value * value;
47+
48+
for( auto blinks : m_b_links ){
49+
blinks->_compute_error(m_delta);
50+
}
51+
52+
m_error_sum = 0.0;
53+
m_backward_load = 0;
54+
}
3755
}

0 commit comments

Comments
 (0)