11#include < iostream>
22
33#include " network.h"
4+ #include " render/neuron_render.hpp"
5+ #include " render/network_render.h"
46
5- void Network::addLayer (const NeuronLayer&& layer) {
7+
8+ void Network::addLayer (const NeuronLayer&& layer, const std::vector<size_t >& area_size) {
69 layers.push_back (layer);
7- neurons.assign (layer.begin (), layer.end ());
10+
11+ uint32_t last_id = neurons.back ()->idx ;
12+ for (auto it = layer.begin (); it != layer.end (); ++it)
13+ {
14+ (*it)->idx = ++last_id;
15+ neurons.emplace_back (std::move (*it));
16+ }
17+
18+ // neurons.assign(layer.begin(), layer.end());
819 if (render && std::dynamic_pointer_cast<NetworkRenderStrategy>(render) != nullptr ) {
9- std::dynamic_pointer_cast<NetworkRenderStrategy>(render)->addLayer (layer);
20+ std::dynamic_pointer_cast<NetworkRenderStrategy>(render)->addLayer (layer, area_size );
1021 }
1122}
1223
13- // void setSize(int N) {
14- // neurons.resize(N);
15- // // synapses.resize(N, std::vector<Synapse>(N));
16- // }
24+ void Network::addNeuron (const std::vector<size_t >& pos, const NeuronLayer& connected_to) {
25+ uint32_t layer = pos[2 ];
26+ if (layer >= layers.size ()) {
27+ std::cerr << " wrong layer position\n " ;
28+ return ;
29+ }
30+
31+ std::shared_ptr<Neuron> neuron = std::make_shared<Neuron>();
32+ neuron->idx = create_id (nullptr );
33+ auto ctx = std::make_shared<NeuronVisualContext>(neuron);
34+ ctx->position = { float (pos[0 ]), float (pos[1 ]), float (pos[2 ]) };
35+ neuron->render = std::make_shared<NeuronRenderStrategy>(ctx);
36+
37+ layers[layer].push_back (neuron);
38+ neurons.push_back (neuron);
39+ }
40+
41+ void Network::addConnection (const std::shared_ptr<Neuron>& n1, const std::shared_ptr<Neuron>& n2) {
42+ synapses.emplace (idsToLocation (n1->idx , n2->idx ), std::make_shared<Synapse>());
43+ }
44+
45+ std::pair<uint32_t , uint32_t > Network::locationToIds (uint64_t loc) const {
46+ uint32_t pre_idx = loc & 0xffff ;
47+ uint32_t post_idx = (loc >> 8 ) & 0xffff ;
48+ return {pre_idx, post_idx};
49+ }
50+
51+ uint64_t Network::idsToLocation (uint32_t pre , uint32_t post ) const {
52+ uint64_t loc = 0 ;
53+ loc = pre | (post << 8 );
54+ return loc;
55+ }
1756
1857std::vector<float > Network::get_current_voltage_state () const {
1958 std::vector<float > state (neurons.size (), 0.0 );
@@ -37,16 +76,15 @@ void Network::step(std::vector<uint8_t> inputs) {
3776
3877 // Synaptic input
3978 for (auto & [loc, syn] : synapses) {
40- uint32_t pre_idx = loc & 0xffff ;
41- uint32_t post_idx = (loc >> 8 ) & 0xffff ;
79+ auto [pre_idx, post_idx] = locationToIds (loc);
4280
43- syn. update_pre (dt);
44- syn. update_post (dt);
81+ syn-> update_pre (dt);
82+ syn-> update_post (dt);
4583
4684 // STDP
4785 if (neurons[pre_idx]->spiked ) {
48- syn. on_pre_spike ();
49- dv[post_idx] += syn. weight ;
86+ syn-> on_pre_spike ();
87+ dv[post_idx] += syn-> weight ;
5088 }
5189 }
5290
@@ -56,10 +94,9 @@ void Network::step(std::vector<uint8_t> inputs) {
5694
5795 // STDP
5896 for (auto & [loc, syn] : synapses) {
59- uint32_t pre_idx = loc & 0xffff ;
60- uint32_t post_idx = (loc >> 8 ) & 0xffff ;
61- syn.apply_stdp (neurons[pre_idx]->spiked , neurons[post_idx]->spiked );
62- if (neurons[post_idx]->spiked ) syn.on_post_spike ();
97+ auto [pre_idx, post_idx] = locationToIds (loc);
98+ syn->apply_stdp (neurons[pre_idx]->spiked , neurons[post_idx]->spiked );
99+ if (neurons[post_idx]->spiked ) syn->on_post_spike ();
63100 }
64101
65102 time += dt;
@@ -85,4 +122,8 @@ void Network::update(float dt) {
85122
86123void Network::destroy () const {
87124 if (render) render->destroy ();
125+ }
126+
127+ uint32_t Network::create_id (std::shared_ptr<Neuron> neuron) const {
128+ return neurons.size ();
88129}
0 commit comments