Skip to content

Commit 14bf5ad

Browse files
committed
warnings, better defaults for exp
1 parent 2c83a2f commit 14bf5ad

File tree

4 files changed

+125
-29
lines changed

4 files changed

+125
-29
lines changed

src/graph.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,33 @@ impl Graph {
170170
}
171171
}
172172

173+
pub fn print_graph(node: &ANode) {
174+
let mut s = String::new();
175+
Graph::print_graph_level(node, 0, &mut s);
176+
}
173177

178+
fn print_graph_level(node: &ANode, depth: usize, buff: &mut String) {
179+
buff.clear();
180+
for _ in 0..depth {
181+
buff.push(' ');
182+
}
183+
if node.is_leaf() {
184+
eprintln!("{}- {}({:?}", buff, node.op_name(), node.value());
185+
} else {
186+
eprintln!("{}- {}({:?} {{", buff, node.op_name(), node.value());
187+
if let Some(children) = node.get_children() {
188+
for c in children {
189+
Graph::print_graph_level(c, depth + 1, buff);
190+
}
191+
}
192+
let spaces = &buff[..depth];
193+
eprintln!("{}}}", spaces);
194+
}
195+
}
174196
}
175197

198+
199+
176200
pub(crate) struct Run(NodeIdx, Vec<ANode>);
177201

178202
impl Run {
@@ -183,6 +207,8 @@ impl Run {
183207
}
184208

185209
impl Node for Run {
210+
fn op_name(&self) -> &str { "Run" }
211+
186212
fn get_id(&self) -> NodeIdx { self.0.clone() }
187213

188214
fn get_children(&self) -> Option<&[ANode]> {

src/lib.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ impl NodeIdx {
3030
}
3131
}
3232

33-
3433
pub trait Node {
34+
fn op_name(&self) -> &str;
35+
3536
fn get_id(&self) -> NodeIdx;
3637

3738
fn is_leaf(&self) -> bool;
@@ -42,7 +43,6 @@ pub trait Node {
4243

4344
fn requires_grad(&self) -> bool;
4445

45-
//fn compute_grad(&self, _grad: &[DType], _results: &mut [MPVec]) { }
4646
fn compute_grad(&self, _grad: &[DType], _results: &mut [&mut [DType]]) { }
4747

4848
}
@@ -76,11 +76,11 @@ impl ANode {
7676
}
7777

7878
pub fn exp(&self) -> ANode {
79-
Exp::new(self.clone(), true)
79+
Exp::new(self.clone(), false)
8080
}
8181

82-
pub fn exp_exact(&self) -> ANode {
83-
Exp::new(self.clone(), false)
82+
pub fn exp_approx(&self) -> ANode {
83+
Exp::new(self.clone(), true)
8484
}
8585

8686
pub fn sum(&self) -> ANode {
@@ -91,6 +91,10 @@ impl ANode {
9191
Slice::new(self.clone(), start, len)
9292
}
9393

94+
pub fn name(&self, name: String) -> ANode {
95+
Named::new(self.clone(), name)
96+
}
97+
9498
fn require_grad(self) -> ANode {
9599
ANode(Rc::new(RequiresGrad::new(self.0)))
96100
}

0 commit comments

Comments
 (0)