Skip to content

Commit 4f3d3d5

Browse files
committed
adds pool on/off
1 parent babaa58 commit 4f3d3d5

File tree

4 files changed

+55
-10
lines changed

4 files changed

+55
-10
lines changed

src/graph.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,23 @@ use crate::vecops::iadd;
77
#[derive(Debug)]
88
pub struct Graph {
99
gradients: HashMap<NodeIdx, Vec<DType>>,
10-
freelist: HashMap<usize, Vec<Vec<DType>>>
10+
freelist: HashMap<usize, Vec<Vec<DType>>>,
11+
nan_check: bool
1112
}
1213

1314
impl Graph {
1415
pub fn new() -> Self {
1516
Graph {
1617
gradients: HashMap::new(),
17-
freelist: HashMap::new()
18+
freelist: HashMap::new(),
19+
nan_check: false
1820
}
1921
}
2022

23+
pub fn debug_nan(&mut self, check: bool) {
24+
self.nan_check = check;
25+
}
26+
2127
pub fn backward(&mut self, end_node: &ANode) {
2228
let out = Run::new(end_node);
2329
// dz/dz of course is 1
@@ -86,6 +92,17 @@ impl Graph {
8692

8793
node.compute_grad(&node_grad, &mut temp_grads);
8894

95+
if self.nan_check {
96+
for (i, grad) in temp_grads.iter().enumerate() {
97+
for gi in grad.iter() {
98+
if gi.is_nan() {
99+
eprintln!("Nan detected with id {:?}, child {}", node.get_id(), i);
100+
panic!()
101+
}
102+
}
103+
}
104+
}
105+
89106
// Update grads
90107
grads.iter_mut().zip(temp_grads.into_iter()).for_each(|(g, tg)| {
91108
iadd(g, &tg);

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod pool;
99

1010
pub use graph::Graph;
1111
pub use ops::{Variable,Constant};
12-
pub use pool::clear_pool;
12+
pub use pool::{clear_pool, use_shared_pool};
1313

1414
use std::sync::atomic::{AtomicUsize, Ordering};
1515
use std::sync::Arc;

src/ops.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl Computation {
2020
}
2121

2222
fn shared(value: Arc<Vec<DType>>) -> Self {
23-
Computation { value: Data::Shared(value) }
23+
Computation { value: Data::Shared(value) }
2424
}
2525

2626
fn pooled(value: MPVec) -> Self {
@@ -427,6 +427,10 @@ impl Node for Power {
427427
grad.iter().zip(lx.zip(ly)).for_each(|(gi, (xi, yi))| {
428428
out.add(*gi * *yi * xi.powf(*yi - 1f32));
429429
});
430+
println!("x: {:?}", x);
431+
println!("y: {:?}", y);
432+
println!("grad: {:?}", grad);
433+
println!("dX: {:?}", child_grads[0]);
430434

431435
// df(x,y)/dy = ln(y) * x ^ y
432436
let (lx, ly) = Broadcast::from_pair(x, y);
@@ -1147,6 +1151,18 @@ mod tests {
11471151
assert_eq!(Some(&vec![0f32, 2f32, 4f32]), x_grad);
11481152
}
11491153

1154+
fn euclidean_distance(x: &ANode, y: &ANode) -> ANode {
1155+
let minus = x - y;
1156+
println!("{:?}", minus.get_id());
1157+
let pow = minus.pow(2f32);
1158+
println!("{:?}", pow.get_id());
1159+
let sum = pow.sum();
1160+
println!("{:?}", sum.get_id());
1161+
let sqrt = sum.pow(0.5);
1162+
println!("{:?}", sqrt.get_id());
1163+
sqrt
1164+
}
1165+
11501166
#[test]
11511167
fn test_backward_pass_complicated() {
11521168
// (x+2) ^ 2

src/pool.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ use lazy_static::lazy_static;
33
use std::convert::{AsRef, AsMut};
44
use std::collections::HashMap;
55
use std::sync::Mutex;
6+
use std::sync::atomic::{AtomicBool,Ordering};
67
use std::ops::{Drop,Deref,DerefMut};
78

89
use crate::DType;
910

11+
static USE_POOL: AtomicBool = AtomicBool::new(true);
1012
lazy_static! {
1113
static ref POOL: Mutex<MemoryPool> = {
1214
let m = MemoryPool::new();
@@ -43,10 +45,18 @@ impl MemoryPool {
4345
}
4446
}
4547

48+
pub fn use_shared_pool(use_pool: bool) {
49+
USE_POOL.store(use_pool, Ordering::SeqCst);
50+
}
51+
4652
pub fn allocate_vec(size: usize) -> MPVec {
47-
let mut pool = POOL.lock()
48-
.expect("Error accessing memory pool!");
49-
pool.get(size)
53+
if USE_POOL.load(Ordering::Relaxed) {
54+
let mut pool = POOL.lock()
55+
.expect("Error accessing memory pool!");
56+
pool.get(size)
57+
} else {
58+
MPVec(vec![0.; size])
59+
}
5060
}
5161

5262
pub fn clear_pool() {
@@ -56,9 +66,11 @@ pub fn clear_pool() {
5666
}
5767

5868
fn return_vec(v: Vec<DType>) {
59-
let mut pool = POOL.lock()
60-
.expect("Error accessing memory pool!");
61-
pool.ret(v);
69+
if USE_POOL.load(Ordering::Relaxed) {
70+
let mut pool = POOL.lock()
71+
.expect("Error accessing memory pool!");
72+
pool.ret(v);
73+
}
6274
}
6375

6476
pub struct MPVec(Vec<DType>);

0 commit comments

Comments
 (0)