Skip to content

Spiking neural networks #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 131 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "neat-gru"
version = "1.1.0"
version = "1.1.1"
authors = ["sakex <alexandre@senges.ch>"]
edition = "2018"
description = "NEAT algorithm with GRU gates"
Expand All @@ -9,6 +9,9 @@ repository = "https://github.com/sakex/neat-gru-rust"
categories = ["science", "wasm"]
keywords = ["neat", "ai", "machine-learning", "genetic", "algorithm"]

[features]
default = []
snn = ["tokio", "futures"]

[lib]
crate-type = ["cdylib", "rlib"]
Expand All @@ -23,6 +26,8 @@ numeric_literals = "0.2.0"
rayon = "1.5.1"
itertools = "0.10.1"
async-trait = "0.1.51"
futures = { version = "0.3", optional = true }
tokio = { version = "1.0", optional = true, features = ["sync", "rt", "macros", "time"] }

[dev-dependencies]
criterion = "0.3.5"
Expand Down
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@
## Examples
[XOR](examples/example.rs)

[Snake](examples/snake-cli)


Right now this is the only working example. You can run it via:
```
```bash
cargo run --example example
```

[Snake](examples/snake-cli)

```bash
cargo run --example snake-cli
```

## How to use
In `Cargo.toml`:
```
[dependencies]
neat-gru = 1.0.0"
neat-gru = 1.1.0"
```
Create a struct that implements the `Game` trait
```rust
Expand Down
3 changes: 2 additions & 1 deletion benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use neat_gru::neural_network::nn::NeuralNetwork;
use neat_gru::topology::Topology;
use std::fs::File;
use std::io::Read;
use neat_gru::neural_network::nn_trait::NN;

fn benchmark(c: &mut Criterion) {
let mut file = File::open("snakes_benchmark.json").expect("Can't open snakes_benchmark.json");
let file_string = &mut "".to_string();
file.read_to_string(file_string).unwrap();
let topology = Topology::from_string(file_string);
let mut network = unsafe { NeuralNetwork::new(&topology) };
let mut network = unsafe { NeuralNetwork::from_topology(&topology) };
c.bench_function("nn::compute", |b| {
b.iter(|| network.compute(black_box(&[0.0, 0.0])))
});
Expand Down
5 changes: 4 additions & 1 deletion src/neural_network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod connection_relu;
mod connection_sigmoid;
mod functions;
mod neuron;
mod nn;
pub mod nn;
pub mod nn_trait;

pub use nn::*;
#[cfg(feature = "snn")]
pub mod spiking;
21 changes: 8 additions & 13 deletions src/neural_network/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use num::Float;
use std::fmt::Display;

use super::connection_relu::ConnectionRelu;
use super::nn_trait::NN;

#[derive(Debug)]
pub struct NeuralNetwork<T>
Expand All @@ -22,17 +23,11 @@ where
unsafe impl<T> Send for NeuralNetwork<T> where T: Float + std::ops::AddAssign + Display + Send {}
unsafe impl<T> Sync for NeuralNetwork<T> where T: Float + std::ops::AddAssign + Display + Send {}

impl<T> NeuralNetwork<T>
impl<T> NN<T> for NeuralNetwork<T>
where
T: Float + std::ops::AddAssign + Display + Send,
{
/// Instantiates a new Neural Network from a `Topology`
///
/// # Safety
///
/// If the Topology is ill-formed, it will result in pointer overflow.
/// Topologies generated by this crate are guaranteed to be safe.
pub unsafe fn new(topology: &Topology<T>) -> NeuralNetwork<T> {
unsafe fn from_topology(topology: &Topology<T>) -> NeuralNetwork<T> {
let layer_count = topology.layers_sizes.len();
let sizes = &topology.layers_sizes;
let mut layer_addresses = vec![0; layer_count];
Expand Down Expand Up @@ -113,7 +108,12 @@ where
net.reset_neurons_value();
net
}
}

impl<T> NeuralNetwork<T>
where
T: Float + std::ops::AddAssign + Display + Send,
{
#[inline]
fn reset_neurons_value(&mut self) {
for (neuron, bias) in self.neurons.iter_mut().zip(self.biases.iter()) {
Expand Down Expand Up @@ -151,11 +151,6 @@ where
neuron.reset_state();
}
}

pub fn from_string(serialized: &str) -> NeuralNetwork<T> {
let top = Topology::from_string(serialized);
unsafe { NeuralNetwork::new(&top) }
}
}

impl<T> PartialEq for NeuralNetwork<T>
Expand Down
24 changes: 24 additions & 0 deletions src/neural_network/nn_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::fmt::Display;

use num::Float;

use crate::topology::Topology;

pub trait NN<T>: Sized
where
T: Float + std::ops::AddAssign + Display + Send,
{
/// Instantiates a new Neural Network from a `Topology`
///
/// # Safety
///
/// If the Topology is ill-formed, it will result in pointer overflow.
/// Topologies generated by this crate are guaranteed to be safe.
unsafe fn from_topology(topology: &Topology<T>) -> Self;

/// Deserializes a serde serialized Topolgy into a neural network
fn from_string(serialized: &str) -> Self {
let top = Topology::from_string(serialized);
unsafe { Self::from_topology(&top) }
}
}
4 changes: 4 additions & 0 deletions src/neural_network/spiking/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod spiking_neuron;
mod spiking_nn;

pub use spiking_nn::*;
Loading