Skip to content

Commit 8aad839

Browse files
committed
stats
1 parent 4379eb0 commit 8aad839

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

src/graph.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::rc::Rc;
2+
use std::ops::Add;
23

34
use hashbrown::HashMap;
45
use hashbrown::hash_map::Entry;
@@ -40,6 +41,17 @@ impl Graph {
4041
self.gradients.clear();
4142
}
4243

44+
pub fn stats(&self, node: &ANode) -> GraphStats {
45+
let stats = GraphStats::new(1, node.value().len());
46+
if let Some(children) = node.get_children() {
47+
children.iter()
48+
.map(|cn| self.stats(cn))
49+
.fold(stats, |acc, x| acc + x)
50+
} else {
51+
stats
52+
}
53+
}
54+
4355
fn get_or_create_grad(&mut self, node: &ANode) -> MPVec {
4456
let n_idx = node.get_id();
4557
if self.gradients.contains_key(&n_idx) {
@@ -162,3 +174,47 @@ impl Node for Run {
162174
}
163175
}
164176

177+
#[derive(Clone,Copy,Debug)]
178+
pub struct GraphStats {
179+
ops: usize,
180+
memory: usize
181+
}
182+
183+
impl GraphStats {
184+
fn new(ops: usize, memory: usize) -> Self {
185+
GraphStats {ops, memory}
186+
}
187+
188+
fn zero() -> Self {
189+
GraphStats::new(0, 0)
190+
}
191+
}
192+
193+
impl Add for GraphStats {
194+
type Output = Self;
195+
196+
fn add(self, other: Self) -> Self {
197+
Self {
198+
ops: self.ops + other.ops,
199+
memory: self.memory + other.memory,
200+
}
201+
}
202+
}
203+
204+
205+
#[cfg(test)]
206+
mod graph_tests {
207+
use super::*;
208+
use crate::*;
209+
210+
#[test]
211+
fn test_add() {
212+
let x = Variable::new(vec![0., 1.]);
213+
let y = Variable::new(vec![2., 3.]);
214+
let res = x + y;
215+
let graph = Graph::new();
216+
let stats = graph.stats(&res);
217+
assert_eq!(stats.ops, 3);
218+
assert_eq!(stats.memory, 6);
219+
}
220+
}

0 commit comments

Comments
 (0)