Skip to content

Commit b27d686

Browse files
authored
feature: Tree::total_branch_length now works. (#429)
* Replace unimplemented! macro with working code. * Remove Tree::node_table. This is not a breaking change due to use of unimplemented!.
1 parent 59054b6 commit b27d686

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

examples/haploid_wright_fisher.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,25 @@ proptest! {
114114
#[test]
115115
fn test_simulate_proptest(seed in any::<u64>(),
116116
num_generations in 50..100i32,
117-
simplify_interval in 1..100i32 ) {
118-
let _ = simulate(seed, 100, num_generations, simplify_interval).unwrap();
117+
simplify_interval in 1..100i32) {
118+
let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap();
119+
120+
// stress test the branch length fn b/c it is not a trivial
121+
// wrapper around the C API.
122+
{
123+
use streaming_iterator::StreamingIterator;
124+
let mut x = f64::NAN;
125+
if let Ok(mut tree_iter) = ts.tree_iterator(0) {
126+
// We will only do the first tree to save time.
127+
if let Some(tree) = tree_iter.next() {
128+
let b = tree.total_branch_length(false).unwrap();
129+
let b2 = unsafe {
130+
tskit::bindings::tsk_tree_get_total_branch_length(tree.as_ptr(), -1, &mut x)
131+
};
132+
assert!(b2 >= 0, "{}", b2);
133+
assert!(f64::from(b) - x <= 1e-8);
134+
}
135+
}
136+
}
119137
}
120138
}

src/tree_interface.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ impl TreeInterface {
340340
/// Get the parent of node `u`.
341341
///
342342
/// Returns `None` if `u` is out of range.
343-
pub fn parent<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
343+
pub fn parent<N: Into<NodeId> + Copy + std::fmt::Debug>(&self, u: N) -> Option<NodeId> {
344344
sys::tsk_column_access::<NodeId, _, _, _>(u.into(), self.as_ref().parent, self.array_len)
345345
}
346346

@@ -489,12 +489,9 @@ impl TreeInterface {
489489
/// (and the tree sequence from which it came).
490490
///
491491
/// This is a convenience function for accessing node times, etc..
492-
pub fn node_table(&self) -> crate::NodeTable {
493-
unimplemented!("this needs to return &NodeTable");
494-
// crate::NodeTable::new_from_table(unsafe {
495-
// &(*(*(*self.as_ptr()).tree_sequence).tables).nodes
496-
// })
497-
}
492+
// fn node_table(&self) -> &crate::NodeTable {
493+
// &self.nodes
494+
// }
498495

499496
/// Calculate the total length of the tree via a preorder traversal.
500497
///
@@ -506,13 +503,19 @@ impl TreeInterface {
506503
///
507504
/// [`TskitError`] may be returned if a node index is out of range.
508505
pub fn total_branch_length(&self, by_span: bool) -> Result<Time, TskitError> {
509-
let nt = self.node_table();
506+
let time: &[Time] = sys::generate_slice(
507+
unsafe {
508+
(*(*(*self.non_owned_pointer.as_ptr()).tree_sequence).tables)
509+
.nodes
510+
.time
511+
},
512+
self.num_nodes,
513+
);
510514
let mut b = Time::from(0.);
511515
for n in self.traverse_nodes(NodeTraversalOrder::Preorder) {
512516
let p = self.parent(n).ok_or(TskitError::IndexError {})?;
513517
if p != NodeId::NULL {
514-
b += nt.time(p).ok_or(TskitError::IndexError {})?
515-
- nt.time(n).ok_or(TskitError::IndexError {})?;
518+
b += time[p.as_usize()] - time[n.as_usize()]
516519
}
517520
}
518521

tests/book_trees.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,29 @@ fn initialize_from_table_collection() {
143143
}
144144
// ANCHOR_END: iterate_node_siblings_via_array_getters
145145

146+
let mut tree_iterator = treeseq.tree_iterator(TreeFlags::default()).unwrap();
147+
let mut total_branch_lengths = vec![];
148+
while let Some(tree) = tree_iterator.next() {
149+
total_branch_lengths.push(tree.total_branch_length(false).unwrap());
150+
}
151+
152+
let mut tree_iterator = treeseq.tree_iterator(TreeFlags::default()).unwrap();
153+
let mut total_branch_lengths_ll = vec![];
154+
let mut x = 0.0;
155+
while let Some(tree) = tree_iterator.next() {
156+
let l =
157+
unsafe { tskit::bindings::tsk_tree_get_total_branch_length(tree.as_ptr(), -1, &mut x) };
158+
assert!(l >= 0);
159+
total_branch_lengths_ll.push(x);
160+
}
161+
162+
for (i, j) in total_branch_lengths
163+
.iter()
164+
.zip(total_branch_lengths_ll.iter())
165+
{
166+
assert_eq!(i, j, "{} {}", i, j);
167+
}
168+
146169
// ANCHOR: iterate_edge_differences
147170
if let Ok(mut edge_diff_iterator) = treeseq.edge_differences_iter() {
148171
while let Some(diffs) = edge_diff_iterator.next() {

0 commit comments

Comments
 (0)