1+ #include < iostream>
2+
3+ #include " network.h"
4+
5+ void Network::addLayer (const NeuronLayer&& layer) {
6+ layers.push_back (layer);
7+ neurons.assign (layer.begin (), layer.end ());
8+ if (render && std::dynamic_pointer_cast<NetworkRenderStrategy>(render) != nullptr ) {
9+ std::dynamic_pointer_cast<NetworkRenderStrategy>(render)->addLayer (layer);
10+ }
11+ }
12+
13+ // void setSize(int N) {
14+ // neurons.resize(N);
15+ // // synapses.resize(N, std::vector<Synapse>(N));
16+ // }
17+
18+ std::vector<float > Network::get_current_voltage_state () const {
19+ std::vector<float > state (neurons.size (), 0.0 );
20+ for (size_t i = 0 ; i < neurons.size (); ++i) {
21+ state[i] = neurons[i]->v ;
22+ }
23+ return state;
24+ }
25+
26+ void Network::step (std::vector<uint8_t > inputs) {
27+ std::cout << " step (" << time << " )\n " ;
28+
29+ for (size_t i = 0 ; i < inputs.size (); i++)
30+ {
31+ if (inputs[i] > 0 ) {
32+ neurons[i]->v += 1 .5f * (inputs[i] / 255 .0f ) * 0 .2f ; // inject current to neuron 0
33+ }
34+ }
35+
36+ std::vector<float > dv (neurons.size (), 0 .0f );
37+
38+ // Synaptic input
39+ for (auto & [loc, syn] : synapses) {
40+ uint32_t pre_idx = loc & 0xffff ;
41+ uint32_t post_idx = (loc >> 8 ) & 0xffff ;
42+
43+ syn.update_pre (dt);
44+ syn.update_post (dt);
45+
46+ // STDP
47+ if (neurons[pre_idx]->spiked ) {
48+ syn.on_pre_spike ();
49+ dv[post_idx] += syn.weight ;
50+ }
51+ }
52+
53+ // Neuron updates
54+ for (size_t i = 0 ; i < neurons.size (); ++i)
55+ neurons[i]->update (dv[i], dt, time);
56+
57+ // STDP
58+ for (auto & [loc, syn] : synapses) {
59+ uint32_t pre_idx = loc & 0xffff ;
60+ uint32_t post_idx = (loc >> 8 ) & 0xffff ;
61+ syn.apply_stdp (neurons[pre_idx]->spiked , neurons[post_idx]->spiked );
62+ if (neurons[post_idx]->spiked ) syn.on_post_spike ();
63+ }
64+
65+ time += dt;
66+ }
67+
68+ void Network::init () {
69+ if (render) render->init ();
70+ }
71+
72+ void Network::draw (float time) const {
73+ if (render) render->draw (time);
74+ }
75+
76+ // void draw() const {
77+ // for (auto& n : neurons) {
78+ // n.draw();
79+ // }
80+ // }
81+
82+ void Network::update (float dt) {
83+ if (render) render->update (dt);
84+ }
85+
86+ void Network::destroy () const {
87+ if (render) render->destroy ();
88+ }
0 commit comments