Skip to content

Using ReLU as an Option instead of Sigmoid #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
2b43c07
Implemented GUI for snake
Nereuxofficial Aug 8, 2021
d115927
Merge remote-tracking branch 'origin/main' into snake
Nereuxofficial Aug 15, 2021
7de0a29
Removed GUI & Increased Generations & the number of snakes
Nereuxofficial Aug 18, 2021
73a0039
Snakes can no longer rotate forever
Nereuxofficial Aug 18, 2021
1e652d4
Removed unnecessary dependency
Nereuxofficial Aug 18, 2021
0d4e38e
Simplified Snake example
Nereuxofficial Aug 19, 2021
8332522
Added Benchmark library
Nereuxofficial Aug 20, 2021
86f38b3
Remove unused imports
Nereuxofficial Aug 20, 2021
00f0316
Cleanups
Nereuxofficial Aug 20, 2021
ee9861b
Derive Hash instead of implementing it
Nereuxofficial Aug 20, 2021
04a9a22
Cleanups
Nereuxofficial Aug 20, 2021
e559d99
Further Cleanups & Refactoring
Nereuxofficial Aug 21, 2021
5ceaf25
Update README.md
Nereuxofficial Aug 21, 2021
70d905c
Refactoring & Docs
Nereuxofficial Aug 21, 2021
81ef41c
Better function name
Nereuxofficial Aug 21, 2021
39c9bcd
Cleanups
Nereuxofficial Aug 22, 2021
a3f2428
Added comments
Nereuxofficial Aug 22, 2021
0e2c83d
Reworded comments
Nereuxofficial Aug 22, 2021
459c4a3
Removed unnecessary .gitignore
Nereuxofficial Aug 22, 2021
9f5555a
Fancier Badges
Nereuxofficial Aug 23, 2021
2b68ac6
Fixed accidental Ctrl+V
Nereuxofficial Aug 23, 2021
144bb77
Fixed accidental Paste
Nereuxofficial Aug 23, 2021
7208831
Split Github Workflows
Aug 25, 2021
b30052e
Removed unnecessary code
Nereuxofficial Aug 25, 2021
b7541d4
WIP: Fixing serialization
Nereuxofficial Aug 28, 2021
4946362
Revert changes to Topology::to_serde_string
Nereuxofficial Aug 28, 2021
f9bf105
Fixed Serialization and Benchmarks
Nereuxofficial Aug 28, 2021
d756682
Fixed Benchmark and included snakes.json
Nereuxofficial Aug 28, 2021
47f6733
Topology::to_string(&self) now uses to_string_pretty
Nereuxofficial Aug 28, 2021
b0f54b0
Removed duplicate functions & Created separate json for benchmarking
Nereuxofficial Aug 28, 2021
f4a9a2e
Restructured some Snake code
Nereuxofficial Aug 28, 2021
157395e
Snake Example simplified and with par_iter
Nereuxofficial Aug 29, 2021
b041c17
Updated dependencies
Nereuxofficial Aug 29, 2021
7b8e69a
Refactoring
Nereuxofficial Aug 29, 2021
72252be
Refactoring
Nereuxofficial Aug 29, 2021
64d8cf8
Added wasm32 build to Tests
Nereuxofficial Aug 29, 2021
0ef3e2c
Added wasm32 build to Tests
Nereuxofficial Aug 29, 2021
89114e9
Fix Github Actions wasm32 build
Nereuxofficial Aug 29, 2021
f834f57
Added benchmarks for math functions.
Sep 11, 2021
9d44b1a
Merge remote-tracking branch 'sakex/main'
Nereuxofficial Sep 12, 2021
cf0ab0c
WIP: Trying out the relu function
Nereuxofficial Oct 3, 2021
db104fd
Extended benchmarks
Nereuxofficial Oct 3, 2021
8c4f130
Merge remote-tracking branch 'upstream/main'
Nereuxofficial Jan 9, 2022
2a362e9
Merge branch 'sakex:main' into main
Nereuxofficial Nov 22, 2022
bc712bd
Merge branch 'sakex:main' into main
Nereuxofficial Mar 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ criterion = "0.3.5"
name = "benchmark"
harness = false

[[bench]]
name = "math_functions"
harness = false

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.1", features = ["js"] }

Expand Down
1 change: 1 addition & 0 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ fn benchmark(c: &mut Criterion) {
b.iter(|| network.compute(black_box(&[0.0, 0.0])))
});
}

criterion_group!(benches, benchmark);
criterion_main!(benches);
105 changes: 105 additions & 0 deletions benches/math_functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//! Contains benchmarks of the functions stored in neural_network::functions.
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use neat_gru::neural_network::functions::*;

extern crate neat_gru;

/// Benchmarks `fast_sigmoid` at zero and at growing multiples of a base input.
fn bench_sigmoid(c: &mut Criterion) {
    let base: f32 = 0.3518392;
    let mut group = c.benchmark_group("Sigmoid Function");
    // Same sample points as before: base * {0, 1, 2, 4, 6, 8, 10, 12, 14}.
    let inputs: Vec<f32> = [0.0f32, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
        .iter()
        .map(|factor| base * factor)
        .collect();
    for input in &inputs {
        group.bench_with_input(BenchmarkId::from_parameter(input), input, |b, &x| {
            b.iter(|| fast_sigmoid(x))
        });
    }
    group.finish();
}

/// Benchmarks `fast_tanh` at zero and at growing multiples of a base input.
fn bench_tanh(c: &mut Criterion) {
    let base: f32 = 0.3518392;
    let mut group = c.benchmark_group("tanh Function");
    // Same sample points as before: base * {0, 1, 2, 4, 6, 8, 10, 12, 14}.
    let inputs: Vec<f32> = [0.0f32, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
        .iter()
        .map(|factor| base * factor)
        .collect();
    for input in &inputs {
        group.bench_with_input(BenchmarkId::from_parameter(input), input, |b, &x| {
            b.iter(|| fast_tanh(x))
        });
    }
    group.finish();
}

/// Benchmarks `re_lu` at zero and at growing multiples of a base input.
fn bench_relu(c: &mut Criterion) {
    let base: f32 = 0.3518392;
    let mut group = c.benchmark_group("relu Function");
    // Same sample points as before: base * {0, 1, 2, 4, 6, 8, 10, 12, 14}.
    let inputs: Vec<f32> = [0.0f32, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
        .iter()
        .map(|factor| base * factor)
        .collect();
    for input in &inputs {
        group.bench_with_input(BenchmarkId::from_parameter(input), input, |b, &x| {
            b.iter(|| re_lu(x))
        });
    }
    group.finish();
}

fn comparison(c: &mut Criterion) {
let size: f32 = 0.3518392;
let mut group = c.benchmark_group("relu vs sigmoid");
for size in [
size * 0.0,
size,
size * 2.0,
size * 4.0,
size * 6.0,
size * 8.0,
size * 10.0,
size * 12.0,
size * 14.0,
]
.iter()
{
group.bench_with_input(BenchmarkId::new("Sigmoid", size), size,
|b, size| b.iter(|| fast_sigmoid(*size)));
group.bench_with_input(BenchmarkId::new("Relu", size), size,
|b, size| b.iter(|| fast_sigmoid(*size)));
}
group.finish();
}

// Register every math-function benchmark. The compact macro form is the
// documented equivalent of the expanded form with `config = Criterion::default()`.
criterion_group!(benches, bench_tanh, bench_sigmoid, bench_relu, comparison);
criterion_main!(benches);
5 changes: 5 additions & 0 deletions src/neural_network/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ pub fn fast_tanh<T: Float>(x: T) -> T {
let b = 135135 + x2 * (62370 + x2 * (3150 + x2 * 28));
a / b
}

/// Rectified linear unit: returns `x` when it is positive, `T::zero()` otherwise.
#[inline]
pub fn re_lu<T: Float>(x: T) -> T {
    x.max(T::zero())
}
2 changes: 1 addition & 1 deletion src/neural_network/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod connection_gru;
mod connection_relu;
mod connection_sigmoid;
mod functions;
pub mod functions;
mod neuron;
mod nn;

Expand Down
10 changes: 5 additions & 5 deletions src/neural_network/neuron.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::neural_network::connection_gru::ConnectionGru;
use crate::neural_network::connection_relu::ConnectionRelu;
use crate::neural_network::connection_sigmoid::ConnectionSigmoid;
// Single import of the activation helpers; keeping both the old and new diff
// lines would be an E0252 duplicate-import error.
use crate::neural_network::functions::{fast_sigmoid, fast_tanh, re_lu};
use crate::topology::bias::Bias;
use crate::utils::floats_almost_equal;
use num::Float;
Expand Down Expand Up @@ -95,8 +95,8 @@ where
#[replace_numeric_literals(T::from(literal).unwrap())]
#[inline]
pub fn get_value(&mut self) -> T {
let update_gate = fast_sigmoid(self.update);
let reset_gate = fast_sigmoid(self.reset);
let update_gate = re_lu(self.update);
let reset_gate = re_lu(self.reset);
let current_memory = fast_tanh(self.input + self.memory * reset_gate);
let value = update_gate * self.memory + (1 - update_gate) * current_memory;

Expand All @@ -107,8 +107,8 @@ where
#[replace_numeric_literals(T::from(literal).unwrap())]
#[inline]
pub fn feed_forward(&mut self) {
let update_gate = fast_sigmoid(self.update);
let reset_gate = fast_sigmoid(self.reset);
let update_gate = re_lu(self.update);
let reset_gate = re_lu(self.reset);
let current_memory = self.input + self.memory * reset_gate;
let value = update_gate * self.memory + (1 - update_gate) * current_memory;
for connection in self.connections_gru.iter_mut() {
Expand Down