Skip to content

Commit d15da76

Browse files
committed
New cpp files
1 parent a552108 commit d15da76

File tree

11 files changed

+418
-388
lines changed

11 files changed

+418
-388
lines changed

graphics/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ if (MSVC)
189189
endif ()
190190

191191
add_executable(neurons_cli
192-
neurons/network.hpp
192+
neurons/network.cpp
193193
neurons/neurons_cli.cpp
194194
)
195195

@@ -225,10 +225,10 @@ add_executable(neurons_gui
225225
thirdparty/imgui/backends/imgui_impl_opengl3.cpp
226226
thirdparty/imgui/backends/imgui_impl_sdl2.cpp
227227
neurons/renderer.cpp
228-
neurons/network.hpp
229-
neurons/render/neuron_render.hpp
228+
neurons/network.cpp
229+
neurons/mnist_data_processor.cpp
230230
neurons/render/neuron_render.cpp
231-
neurons/render/network_render.hpp
231+
neurons/render/network_render.cpp
232232
# neurons/camera.cpp
233233
# neurons/mouse.cpp
234234
# neurons/settings.cpp

graphics/neurons/mnist_data_processor.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
#include "mnist_data_processor.h"
22
#include "render/neuron_render.hpp"
33

4-
5-
const size_t input_size = 784;
6-
const uint32_t width = 28, height = 28;
7-
8-
struct MnistImage {
9-
uint8_t label;
10-
uint8_t pixels[input_size];
11-
};
12-
134
std::vector<MnistImage> load_mnist_bin(const std::string& path) {
145
std::ifstream in(path, std::ios::binary);
156
std::vector<MnistImage> images;
@@ -89,15 +80,18 @@ std::vector<uint8_t> MnistDataProcessor::convert_current_to_inputs() {
8980
}
9081

9182
Network::NeuronLayer MnistDataProcessor::prepare_neurons() const {
92-
std::vector<Neuron> neurons(input_size);
83+
Network::NeuronLayer neurons(input_size);
9384
bx::Vec3 area_size = get_area_size();
9485

9586
for (size_t i = 0; i < input_size; ++i) {
87+
neurons[i] = std::make_shared<Neuron>();
9688
auto ctx = std::make_shared<NeuronVisualContext>(neurons[i]);
9789
ctx->position = {
9890
float(i % int32_t(area_size.x)),
9991
i / area_size.x,
10092
0.0f };
101-
neurons[i].render = std::make_shared<NeuronRenderStrategy>(ctx);
93+
neurons[i]->render = std::make_shared<NeuronRenderStrategy>(ctx);
10294
}
95+
96+
return neurons;
10397
}

graphics/neurons/mnist_data_processor.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
#include "neuron.hpp"
1515
#include "network.hpp"
1616

17-
struct MnistImage;
17+
const size_t input_size = 784;
18+
const uint32_t width = 28, height = 28;
19+
20+
struct MnistImage {
21+
uint8_t label;
22+
uint8_t pixels[input_size];
23+
};
1824

1925
class MnistDataProcessor : public DataProcessor {
2026
private:

graphics/neurons/network.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include <iostream>
2+
3+
#include "network.h"
4+
5+
void Network::addLayer(const NeuronLayer&& layer) {
6+
layers.push_back(layer);
7+
neurons.assign(layer.begin(), layer.end());
8+
if (render && std::dynamic_pointer_cast<NetworkRenderStrategy>(render) != nullptr) {
9+
std::dynamic_pointer_cast<NetworkRenderStrategy>(render)->addLayer(layer);
10+
}
11+
}
12+
13+
// void setSize(int N) {
14+
// neurons.resize(N);
15+
// // synapses.resize(N, std::vector<Synapse>(N));
16+
// }
17+
18+
std::vector<float> Network::get_current_voltage_state() const {
19+
std::vector<float> state(neurons.size(), 0.0);
20+
for (size_t i = 0; i < neurons.size(); ++i) {
21+
state[i] = neurons[i]->v;
22+
}
23+
return state;
24+
}
25+
26+
void Network::step(std::vector<uint8_t> inputs) {
27+
std::cout << "step (" << time << ")\n";
28+
29+
for (size_t i = 0; i < inputs.size(); i++)
30+
{
31+
if (inputs[i] > 0) {
32+
neurons[i]->v += 1.5f * (inputs[i] / 255.0f) * 0.2f; // inject current to neuron 0
33+
}
34+
}
35+
36+
std::vector<float> dv(neurons.size(), 0.0f);
37+
38+
// Synaptic input
39+
for (auto& [loc, syn] : synapses) {
40+
uint32_t pre_idx = loc & 0xffff;
41+
uint32_t post_idx = (loc >> 8) & 0xffff;
42+
43+
syn.update_pre(dt);
44+
syn.update_post(dt);
45+
46+
// STDP
47+
if (neurons[pre_idx]->spiked) {
48+
syn.on_pre_spike();
49+
dv[post_idx] += syn.weight;
50+
}
51+
}
52+
53+
// Neuron updates
54+
for (size_t i = 0; i < neurons.size(); ++i)
55+
neurons[i]->update(dv[i], dt, time);
56+
57+
// STDP
58+
for (auto& [loc, syn] : synapses) {
59+
uint32_t pre_idx = loc & 0xffff;
60+
uint32_t post_idx = (loc >> 8) & 0xffff;
61+
syn.apply_stdp(neurons[pre_idx]->spiked, neurons[post_idx]->spiked);
62+
if (neurons[post_idx]->spiked) syn.on_post_spike();
63+
}
64+
65+
time += dt;
66+
}
67+
68+
void Network::init() {
69+
if (render) render->init();
70+
}
71+
72+
void Network::draw(float time) const {
73+
if (render) render->draw(time);
74+
}
75+
76+
// void draw() const {
77+
// for (auto& n : neurons) {
78+
// n.draw();
79+
// }
80+
// }
81+
82+
void Network::update(float dt) {
83+
if (render) render->update(dt);
84+
}
85+
86+
void Network::destroy() const {
87+
if (render) render->destroy();
88+
}

graphics/neurons/network.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#ifndef NETWORK_H
2+
#define NETWORK_H
3+
4+
#include <memory>
5+
#include <vector>
6+
#include <unordered_map>
7+
8+
#include "neuron.hpp"
9+
#include "synapse.hpp"
10+
#include "render_strategy.h"
11+
#include "render/network_render.h"
12+
13+
using NeuronLayer = std::vector<std::shared_ptr<Neuron>>;
14+
15+
struct Network {
16+
std::vector<std::shared_ptr<Neuron>> neurons;
17+
std::vector<NeuronLayer> layers;
18+
19+
using Location = uint64_t;
20+
std::unordered_map<Location, Synapse> synapses;
21+
// std::vector<std::vector<Synapse>> synapses;
22+
float dt = 1.0f;
23+
float time = 0.0f;
24+
25+
std::shared_ptr<RenderStrategy> render;
26+
27+
explicit Network() = default;
28+
29+
void addLayer(const NeuronLayer&& layer);
30+
std::vector<float> get_current_voltage_state() const;
31+
void step(std::vector<uint8_t> inputs);
32+
void init();
33+
void draw(float time) const;
34+
void update(float dt);
35+
void destroy() const;
36+
};
37+
38+
#endif

graphics/neurons/network.hpp

Lines changed: 0 additions & 112 deletions
This file was deleted.

graphics/neurons/neurons_gui.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "neuron.hpp"
1919
#include "network.hpp"
2020
#include "render/neuron_render.hpp"
21-
#include "render/network_render.hpp"
21+
#include "render/network_render.h"
2222
#include "simulation_clock.hpp"
2323
#include "mnist_data_processor.h"
2424

@@ -153,11 +153,11 @@ namespace
153153
showExampleDialog(this);
154154

155155
ImGui::SetNextWindowPos(
156-
ImVec2(m_width - m_width / 5.0f - 10.0f, 10.0f)
156+
ImVec2(m_width - m_width / 3.0f - 10.0f, 10.0f)
157157
, ImGuiCond_FirstUseEver
158158
);
159159
ImGui::SetNextWindowSize(
160-
ImVec2(m_width / 5.0f, m_height / 2.0f)
160+
ImVec2(m_width / 3.0f, m_height / 1.2f)
161161
, ImGuiCond_FirstUseEver
162162
);
163163
ImGui::Begin("Settings"

0 commit comments

Comments
 (0)