|
1 | 1 | use std::rc::Rc; |
| 2 | +use std::ops::Add; |
2 | 3 |
|
3 | 4 | use hashbrown::HashMap; |
4 | 5 | use hashbrown::hash_map::Entry; |
@@ -40,6 +41,17 @@ impl Graph { |
40 | 41 | self.gradients.clear(); |
41 | 42 | } |
42 | 43 |
|
| 44 | + pub fn stats(&self, node: &ANode) -> GraphStats { |
| 45 | + let stats = GraphStats::new(1, node.value().len()); |
| 46 | + if let Some(children) = node.get_children() { |
| 47 | + children.iter() |
| 48 | + .map(|cn| self.stats(cn)) |
| 49 | + .fold(stats, |acc, x| acc + x) |
| 50 | + } else { |
| 51 | + stats |
| 52 | + } |
| 53 | + } |
| 54 | + |
43 | 55 | fn get_or_create_grad(&mut self, node: &ANode) -> MPVec { |
44 | 56 | let n_idx = node.get_id(); |
45 | 57 | if self.gradients.contains_key(&n_idx) { |
@@ -162,3 +174,47 @@ impl Node for Run { |
162 | 174 | } |
163 | 175 | } |
164 | 176 |
|
| 177 | +#[derive(Clone,Copy,Debug)] |
| 178 | +pub struct GraphStats { |
| 179 | + ops: usize, |
| 180 | + memory: usize |
| 181 | +} |
| 182 | + |
| 183 | +impl GraphStats { |
| 184 | + fn new(ops: usize, memory: usize) -> Self { |
| 185 | + GraphStats {ops, memory} |
| 186 | + } |
| 187 | + |
| 188 | + fn zero() -> Self { |
| 189 | + GraphStats::new(0, 0) |
| 190 | + } |
| 191 | +} |
| 192 | + |
| 193 | +impl Add for GraphStats { |
| 194 | + type Output = Self; |
| 195 | + |
| 196 | + fn add(self, other: Self) -> Self { |
| 197 | + Self { |
| 198 | + ops: self.ops + other.ops, |
| 199 | + memory: self.memory + other.memory, |
| 200 | + } |
| 201 | + } |
| 202 | +} |
| 203 | + |
| 204 | + |
| 205 | +#[cfg(test)] |
| 206 | +mod graph_tests { |
| 207 | + use super::*; |
| 208 | + use crate::*; |
| 209 | + |
| 210 | + #[test] |
| 211 | + fn test_add() { |
| 212 | + let x = Variable::new(vec![0., 1.]); |
| 213 | + let y = Variable::new(vec![2., 3.]); |
| 214 | + let res = x + y; |
| 215 | + let graph = Graph::new(); |
| 216 | + let stats = graph.stats(&res); |
| 217 | + assert_eq!(stats.ops, 3); |
| 218 | + assert_eq!(stats.memory, 6); |
| 219 | + } |
| 220 | +} |
0 commit comments