Skip to content

Commit 611f693

Browse files
authored
refactor: improve Tree interface ergonomics (#388)
* pub fns taking NodeId now take N: Into<NodeId> + Copy BREAKING CHANGE: code using, e.g., 1_i32.into() must be rewritten as 1_i32
1 parent 937ce0e commit 611f693

File tree

2 files changed

+42
-35
lines changed

2 files changed

+42
-35
lines changed

src/tree_interface.rs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,18 @@ impl TreeInterface {
295295

296296
// error if we are not tracking samples,
297297
// Ok(None) if u is out of range
298-
fn left_sample(&self, u: NodeId) -> Option<NodeId> {
298+
fn left_sample<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
299299
// SAFETY: internal pointer cannot be NULL
300300
let ptr = unsafe { *self.as_ptr() };
301-
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, ptr, left_sample, NodeId)
301+
unsafe_tsk_column_access!(u.into().0, 0, self.num_nodes, ptr, left_sample, NodeId)
302302
}
303303

304304
// error if we are not tracking samples,
305305
// Ok(None) if u is out of range
306-
fn right_sample(&self, u: NodeId) -> Option<NodeId> {
306+
fn right_sample<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
307307
// SAFETY: internal pointer cannot be NULL
308308
let ptr = unsafe { *self.as_ptr() };
309-
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, ptr, right_sample, NodeId)
309+
unsafe_tsk_column_access!(u.into().0, 0, self.num_nodes, ptr, right_sample, NodeId)
310310
}
311311

312312
/// Return the `[left, right)` coordinates of the tree.
@@ -328,46 +328,46 @@ impl TreeInterface {
328328
/// Get the parent of node `u`.
329329
///
330330
/// Returns `None` if `u` is out of range.
331-
pub fn parent(&self, u: NodeId) -> Option<NodeId> {
331+
pub fn parent<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
332332
// SAFETY: internal pointer cannot be NULL
333333
let ptr = unsafe { *self.as_ptr() };
334-
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, parent, NodeId)
334+
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, parent, NodeId)
335335
}
336336

337337
/// Get the left child of node `u`.
338338
///
339339
/// Returns `None` if `u` is out of range.
340-
pub fn left_child(&self, u: NodeId) -> Option<NodeId> {
340+
pub fn left_child<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
341341
// SAFETY: internal pointer cannot be NULL
342342
let ptr = unsafe { *self.as_ptr() };
343-
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, left_child, NodeId)
343+
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, left_child, NodeId)
344344
}
345345

346346
/// Get the right child of node `u`.
347347
///
348348
/// Returns `None` if `u` is out of range.
349-
pub fn right_child(&self, u: NodeId) -> Option<NodeId> {
349+
pub fn right_child<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
350350
// SAFETY: internal pointer cannot be NULL
351351
let ptr = unsafe { *self.as_ptr() };
352-
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, right_child, NodeId)
352+
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, right_child, NodeId)
353353
}
354354

355355
/// Get the left sib of node `u`.
356356
///
357357
/// Returns `None` if `u` is out of range.
358-
pub fn left_sib(&self, u: NodeId) -> Option<NodeId> {
358+
pub fn left_sib<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
359359
// SAFETY: internal pointer cannot be NULL
360360
let ptr = unsafe { *self.as_ptr() };
361-
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, left_sib, NodeId)
361+
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, left_sib, NodeId)
362362
}
363363

364364
/// Get the right sib of node `u`.
365365
///
366366
/// Returns `None` if `u` is out of range.
367-
pub fn right_sib(&self, u: NodeId) -> Option<NodeId> {
367+
pub fn right_sib<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
368368
// SAFETY: internal pointer cannot be NULL
369369
let ptr = unsafe { *self.as_ptr() };
370-
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, right_sib, NodeId)
370+
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, right_sib, NodeId)
371371
}
372372

373373
/// Obtain the list of samples for the current tree/tree sequence
@@ -406,17 +406,17 @@ impl TreeInterface {
406406
///
407407
/// * `Some(iterator)` if `u` is valid
408408
/// * `None` otherwise
409-
pub fn parents(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
410-
ParentsIterator::new(self, u)
409+
pub fn parents<N: Into<NodeId> + Copy>(&self, u: N) -> impl Iterator<Item = NodeId> + '_ {
410+
ParentsIterator::new(self, u.into())
411411
}
412412

413413
/// Return an [`Iterator`] over the children of node `u`.
414414
/// # Returns
415415
///
416416
/// * `Some(iterator)` if `u` is valid
417417
/// * `None` otherwise
418-
pub fn children(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
419-
ChildIterator::new(self, u)
418+
pub fn children<N: Into<NodeId> + Copy>(&self, u: N) -> impl Iterator<Item = NodeId> + '_ {
419+
ChildIterator::new(self, u.into())
420420
}
421421

422422
/// Return an [`Iterator`] over the sample nodes descending from node `u`.
@@ -430,8 +430,11 @@ impl TreeInterface {
430430
/// * Some(Ok(iterator)) if [`TreeFlags::SAMPLE_LISTS`] is in [`TreeInterface::flags`]
431431
/// * Some(Err(_)) if [`TreeFlags::SAMPLE_LISTS`] is not in [`TreeInterface::flags`]
432432
/// * None if `u` is not valid.
433-
pub fn samples(&self, u: NodeId) -> Result<impl Iterator<Item = NodeId> + '_, TskitError> {
434-
SamplesIterator::new(self, u)
433+
pub fn samples<N: Into<NodeId> + Copy>(
434+
&self,
435+
u: N,
436+
) -> Result<impl Iterator<Item = NodeId> + '_, TskitError> {
437+
SamplesIterator::new(self, u.into())
435438
}
436439

437440
/// Return an [`Iterator`] over the roots of the tree.
@@ -514,10 +517,14 @@ impl TreeInterface {
514517
/// # Errors
515518
///
516519
/// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
517-
pub fn num_tracked_samples(&self, u: NodeId) -> Result<SizeType, TskitError> {
520+
pub fn num_tracked_samples<N: Into<NodeId> + Copy>(
521+
&self,
522+
u: N,
523+
) -> Result<SizeType, TskitError> {
518524
let mut n = SizeType(tsk_size_t::MAX);
519525
let np: *mut tsk_size_t = &mut n.0;
520-
let code = unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u.0, np) };
526+
let code =
527+
unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u.into().0, np) };
521528
handle_tsk_return_value!(code, n)
522529
}
523530

src/trees.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -607,14 +607,14 @@ pub(crate) mod test_trees {
607607
// These nodes are all out of range
608608
for i in 100..110 {
609609
let mut nsteps = 0;
610-
for _ in tree.parents(i.into()) {
610+
for _ in tree.parents(i) {
611611
nsteps += 1;
612612
}
613613
assert_eq!(nsteps, 0);
614614
}
615615

616-
assert_eq!(tree.parents((-1_i32).into()).count(), 0);
617-
assert_eq!(tree.children((-1_i32).into()).count(), 0);
616+
assert_eq!(tree.parents(-1_i32).count(), 0);
617+
assert_eq!(tree.children(-1_i32).count(), 0);
618618

619619
let roots = tree.roots_to_vec();
620620
for r in roots.iter() {
@@ -665,9 +665,9 @@ pub(crate) mod test_trees {
665665
assert_eq!(treeseq.num_samples(), 2);
666666
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
667667
if let Some(tree) = tree_iter.next() {
668-
assert_eq!(tree.num_tracked_samples(2.into()).unwrap(), 1);
669-
assert_eq!(tree.num_tracked_samples(1.into()).unwrap(), 1);
670-
assert_eq!(tree.num_tracked_samples(0.into()).unwrap(), 2);
668+
assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
669+
assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
670+
assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
671671
}
672672
}
673673

@@ -678,9 +678,9 @@ pub(crate) mod test_trees {
678678
assert_eq!(treeseq.num_samples(), 2);
679679
let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
680680
if let Some(tree) = tree_iter.next() {
681-
assert_eq!(tree.num_tracked_samples(2.into()).unwrap(), 0);
682-
assert_eq!(tree.num_tracked_samples(1.into()).unwrap(), 0);
683-
assert_eq!(tree.num_tracked_samples(0.into()).unwrap(), 0);
681+
assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
682+
assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
683+
assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
684684
}
685685
}
686686

@@ -695,22 +695,22 @@ pub(crate) mod test_trees {
695695
assert!(tree.flags().contains(TreeFlags::SAMPLE_LISTS));
696696
let mut s = vec![];
697697

698-
if let Ok(iter) = tree.samples(0.into()) {
698+
if let Ok(iter) = tree.samples(0) {
699699
for i in iter {
700700
s.push(i);
701701
}
702702
}
703703
assert_eq!(s.len(), 2);
704704
assert_eq!(
705705
s.len(),
706-
usize::try_from(tree.num_tracked_samples(0.into()).unwrap()).unwrap()
706+
usize::try_from(tree.num_tracked_samples(0).unwrap()).unwrap()
707707
);
708708
assert_eq!(s[0], 1);
709709
assert_eq!(s[1], 2);
710710

711711
for u in 1..3 {
712712
let mut s = vec![];
713-
if let Ok(iter) = tree.samples(u.into()) {
713+
if let Ok(iter) = tree.samples(u) {
714714
for i in iter {
715715
s.push(i);
716716
}
@@ -719,7 +719,7 @@ pub(crate) mod test_trees {
719719
assert_eq!(s[0], u);
720720
assert_eq!(
721721
s.len(),
722-
usize::try_from(tree.num_tracked_samples(u.into()).unwrap()).unwrap()
722+
usize::try_from(tree.num_tracked_samples(u).unwrap()).unwrap()
723723
);
724724
}
725725
} else {

0 commit comments

Comments
 (0)