Skip to content

Commit f3139cb

Browse files
committed
better, faster, stronger
1 parent 071292b commit f3139cb

File tree

4 files changed

+90
-18
lines changed

4 files changed

+90
-18
lines changed

Cargo.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,16 @@ edition = "2021"
88
[dependencies]
99
lazy_static = "1.4.0"
1010
hashbrown = "0.12"
11+
12+
[[bench]]
13+
name = "bench_algos"
14+
harness = false
15+
16+
[profile.bench]
17+
debug = true
18+
19+
[profile.release]
20+
debug = true
21+
22+
[dev-dependencies]
23+
criterion = "0.3"

benches/bench_algos.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
2+
use simple_grad::*;
3+
4+
fn vec_ops(c: &mut Criterion) {
5+
let dims = 100;
6+
let mut embeddings = Vec::new();
7+
8+
for i in 0..1000 {
9+
let mut row = Vec::with_capacity(dims);
10+
for dim in 0..dims {
11+
row.push((i*dim) as f32);
12+
}
13+
embeddings.push(Variable::new(row));
14+
}
15+
//c.bench_function("bench vecs", |b| b.iter(|| embeddings.clone().sum_all()));
16+
17+
let results = embeddings.clone().sum_all();
18+
c.bench_function("bench backward", |b| b.iter(|| {
19+
let mut graph = Graph::new();
20+
graph.backward(&results);
21+
}));
22+
}
23+
24+
criterion_group!(benches, vec_ops);
25+
criterion_main!(benches);

src/graph.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::rc::Rc;
22

33
use hashbrown::HashMap;
4+
use hashbrown::hash_map::Entry;
45
use crate::{DType,ANode,NodeIdx,Node};
56
use crate::vecops::iadd;
67
use crate::pool::{allocate_vec,MPVec};
@@ -19,19 +20,22 @@ impl Graph {
1920
}
2021
}
2122

23+
#[inline]
2224
pub fn debug_nan(&mut self, check: bool) {
2325
self.nan_check = check;
2426
}
25-
2627

28+
#[inline]
2729
pub fn get_grad(&self, node: &ANode) -> Option<&Vec<DType>> {
2830
self.gradients.get(&node.get_id()).map(|v| v.as_ref())
2931
}
3032

33+
#[inline]
3134
pub fn zero_grads(&mut self) {
3235
self.gradients.clear();
3336
}
3437

38+
#[inline]
3539
pub fn clear_memory(&mut self) {
3640
self.gradients.clear();
3741
}
@@ -45,39 +49,50 @@ impl Graph {
4549
}
4650
}
4751

52+
#[inline]
4853
fn get_temp_space(&mut self, size: usize) -> MPVec {
4954
allocate_vec(size)
5055
}
5156

57+
#[inline]
5258
fn add_grad(&mut self, node: &ANode, grad: MPVec) {
5359
self.gradients.insert(node.get_id(), grad);
5460
}
61+
62+
#[inline]
63+
fn add_or_update_grad(&mut self, node: &ANode, grad: MPVec) {
64+
match self.gradients.entry(node.get_id()) {
65+
Entry::Occupied(mut entry) => {
66+
iadd(entry.get_mut(), grad.as_slice());
67+
},
68+
Entry::Vacant(mut entry) => {
69+
entry.insert(grad);
70+
}
71+
}
72+
}
73+
5574

56-
pub fn backward(&mut self, end_node: &ANode) {
75+
pub fn backward(&mut self, end_node: &ANode) {
5776
let out = Run::new(end_node);
5877
// dz/dz of course is 1
5978
let mut z_grad = self.get_or_create_grad(&out);
6079
z_grad.fill(1f32);
6180

6281
// Allocate once
63-
let mut grads = Vec::new();
6482
let mut temp_grads = Vec::new();
6583
self.add_grad(&out, z_grad);
66-
self.recurse(&out, &mut grads, &mut temp_grads);
84+
self.recurse(&out, &mut temp_grads);
6785
}
6886

69-
fn recurse(&mut self, node: &ANode, grads: &mut Vec<MPVec>, temp_grads: &mut Vec<MPVec>) {
87+
fn recurse(&mut self, node: &ANode, temp_grads: &mut Vec<MPVec>) {
7088
if !node.is_leaf() {
7189
let node_grad = self.get_or_create_grad(node);
7290
if let Some(children) = node.get_children() {
73-
grads.clear();
7491
temp_grads.clear();
7592
// Grab gradients
76-
grads.extend(children.iter()
77-
.map(|c| self.get_or_create_grad(c)));
7893

79-
temp_grads.extend(grads.iter()
80-
.map(|g| self.get_temp_space(g.len())));
94+
temp_grads.extend(children.iter()
95+
.map(|c| self.get_temp_space(c.value().len())));
8196

8297
node.compute_grad(&node_grad, temp_grads);
8398

@@ -93,13 +108,10 @@ impl Graph {
93108
}
94109

95110
// Update grads
96-
grads.iter_mut().zip(temp_grads.into_iter()).for_each(|(g, tg)| {
97-
iadd(g, &tg);
98-
});
99111

100112
// Re-add gradients
101-
children.iter().zip(grads.drain(..)).for_each(|(c, g)| {
102-
self.add_grad(c, g);
113+
children.iter().zip(temp_grads.drain(..)).for_each(|(c, g)| {
114+
self.add_or_update_grad(c, g);
103115
});
104116

105117
if node.requires_grad() {
@@ -108,7 +120,7 @@ impl Graph {
108120

109121
// Run children
110122
for child in children.iter() {
111-
self.recurse(child, grads, temp_grads);
123+
self.recurse(child, temp_grads);
112124
}
113125

114126
} else {

src/ops.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ impl Variable {
5353
ANode::new(Rc::new(v))
5454
}
5555

56+
pub fn pooled(value: &[DType]) -> ANode {
57+
let mut mpv = allocate_vec(value.len());
58+
mpv.clone_from_slice(value);
59+
let v = Variable(NodeIdx::new(), Computation::pooled(mpv));
60+
ANode::new(Rc::new(v))
61+
}
62+
5663
}
5764

5865
impl Node for Variable {
@@ -230,11 +237,11 @@ impl Node for AddN {
230237
fn requires_grad(&self) -> bool { false }
231238

232239
fn compute_grad(&self, grad: &[DType], child_grads: &mut [MPVec]) {
233-
// f(x,y) = x - y
240+
// f(x,y) = x + y
234241
// df(x,y)/dx = 1
235242
// df(x,y)/dy = 1
236243
for out in child_grads.iter_mut() {
237-
let it = Broadcast::sized(grad, out.len());
244+
let it = Broadcast::sized(grad, out.len());
238245
let mut agg = Updater::new(out, grad.len());
239246
it.for_each(|gi| agg.add(*gi));
240247
}
@@ -749,6 +756,7 @@ impl Node for BulkSum {
749756

750757
fn is_leaf(&self) -> bool { false }
751758

759+
#[inline]
752760
fn value(&self) -> &[DType] {
753761
&self.2.get()
754762
}
@@ -890,6 +898,20 @@ mod tests {
890898
assert_eq!(res.value(), &[2., 4.]);
891899
}
892900

901+
#[test]
902+
fn test_add_simple() {
903+
let x = Variable::new(vec![0., 1.]);
904+
let res = AddN::new(x.clone(), x.clone()).sum();
905+
assert_eq!(res.value(), &[2.]);
906+
907+
908+
let mut graph = Graph::new();
909+
graph.backward(&res);
910+
911+
let res = graph.get_grad(&x).unwrap();
912+
assert_eq!(res, &[2., 2.]);
913+
}
914+
893915
#[test]
894916
fn test_add_scalar() {
895917
let x = Variable::new(vec![0., 1.]);

0 commit comments

Comments
 (0)