Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve Tree interface ergonomics #388

Merged
merged 1 commit into from
Nov 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions src/tree_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,18 @@ impl TreeInterface {

// error if we are not tracking samples,
// Ok(None) if u is out of range
fn left_sample(&self, u: NodeId) -> Option<NodeId> {
fn left_sample<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, ptr, left_sample, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.num_nodes, ptr, left_sample, NodeId)
}

// error if we are not tracking samples,
// Ok(None) if u is out of range
fn right_sample(&self, u: NodeId) -> Option<NodeId> {
fn right_sample<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, ptr, right_sample, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.num_nodes, ptr, right_sample, NodeId)
}

/// Return the `[left, right)` coordinates of the tree.
Expand All @@ -328,46 +328,46 @@ impl TreeInterface {
/// Get the parent of node `u`.
///
/// Returns `None` if `u` is out of range.
pub fn parent(&self, u: NodeId) -> Option<NodeId> {
pub fn parent<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, parent, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, parent, NodeId)
}

/// Get the left child of node `u`.
///
/// Returns `None` if `u` is out of range.
pub fn left_child(&self, u: NodeId) -> Option<NodeId> {
pub fn left_child<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, left_child, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, left_child, NodeId)
}

/// Get the right child of node `u`.
///
/// Returns `None` if `u` is out of range.
pub fn right_child(&self, u: NodeId) -> Option<NodeId> {
pub fn right_child<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, right_child, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, right_child, NodeId)
}

/// Get the left sib of node `u`.
///
/// Returns `None` if `u` is out of range.
pub fn left_sib(&self, u: NodeId) -> Option<NodeId> {
pub fn left_sib<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, left_sib, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, left_sib, NodeId)
}

/// Get the right sib of node `u`.
///
/// Returns `None` if `u` is out of range.
pub fn right_sib(&self, u: NodeId) -> Option<NodeId> {
pub fn right_sib<N: Into<NodeId> + Copy>(&self, u: N) -> Option<NodeId> {
// SAFETY: internal pointer cannot be NULL
let ptr = unsafe { *self.as_ptr() };
unsafe_tsk_column_access!(u.0, 0, self.array_len, ptr, right_sib, NodeId)
unsafe_tsk_column_access!(u.into().0, 0, self.array_len, ptr, right_sib, NodeId)
}

/// Obtain the list of samples for the current tree/tree sequence
Expand Down Expand Up @@ -406,17 +406,17 @@ impl TreeInterface {
///
/// * `Some(iterator)` if `u` is valid
/// * `None` otherwise
pub fn parents(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
ParentsIterator::new(self, u)
pub fn parents<N: Into<NodeId> + Copy>(&self, u: N) -> impl Iterator<Item = NodeId> + '_ {
ParentsIterator::new(self, u.into())
}

/// Return an [`Iterator`] over the children of node `u`.
/// # Returns
///
/// * `Some(iterator)` if `u` is valid
/// * `None` otherwise
pub fn children(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
ChildIterator::new(self, u)
pub fn children<N: Into<NodeId> + Copy>(&self, u: N) -> impl Iterator<Item = NodeId> + '_ {
ChildIterator::new(self, u.into())
}

/// Return an [`Iterator`] over the sample nodes descending from node `u`.
Expand All @@ -430,8 +430,11 @@ impl TreeInterface {
/// * Some(Ok(iterator)) if [`TreeFlags::SAMPLE_LISTS`] is in [`TreeInterface::flags`]
/// * Some(Err(_)) if [`TreeFlags::SAMPLE_LISTS`] is not in [`TreeInterface::flags`]
/// * None if `u` is not valid.
pub fn samples(&self, u: NodeId) -> Result<impl Iterator<Item = NodeId> + '_, TskitError> {
SamplesIterator::new(self, u)
pub fn samples<N: Into<NodeId> + Copy>(
&self,
u: N,
) -> Result<impl Iterator<Item = NodeId> + '_, TskitError> {
SamplesIterator::new(self, u.into())
}

/// Return an [`Iterator`] over the roots of the tree.
Expand Down Expand Up @@ -514,10 +517,14 @@ impl TreeInterface {
/// # Errors
///
/// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
pub fn num_tracked_samples(&self, u: NodeId) -> Result<SizeType, TskitError> {
pub fn num_tracked_samples<N: Into<NodeId> + Copy>(
&self,
u: N,
) -> Result<SizeType, TskitError> {
let mut n = SizeType(tsk_size_t::MAX);
let np: *mut tsk_size_t = &mut n.0;
let code = unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u.0, np) };
let code =
unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u.into().0, np) };
handle_tsk_return_value!(code, n)
}

Expand Down
26 changes: 13 additions & 13 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,14 @@ pub(crate) mod test_trees {
// These nodes are all out of range
for i in 100..110 {
let mut nsteps = 0;
for _ in tree.parents(i.into()) {
for _ in tree.parents(i) {
nsteps += 1;
}
assert_eq!(nsteps, 0);
}

assert_eq!(tree.parents((-1_i32).into()).count(), 0);
assert_eq!(tree.children((-1_i32).into()).count(), 0);
assert_eq!(tree.parents(-1_i32).count(), 0);
assert_eq!(tree.children(-1_i32).count(), 0);

let roots = tree.roots_to_vec();
for r in roots.iter() {
Expand Down Expand Up @@ -665,9 +665,9 @@ pub(crate) mod test_trees {
assert_eq!(treeseq.num_samples(), 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2.into()).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(1.into()).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(0.into()).unwrap(), 2);
assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
}
}

Expand All @@ -678,9 +678,9 @@ pub(crate) mod test_trees {
assert_eq!(treeseq.num_samples(), 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2.into()).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(1.into()).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(0.into()).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
}
}

Expand All @@ -695,22 +695,22 @@ pub(crate) mod test_trees {
assert!(tree.flags().contains(TreeFlags::SAMPLE_LISTS));
let mut s = vec![];

if let Ok(iter) = tree.samples(0.into()) {
if let Ok(iter) = tree.samples(0) {
for i in iter {
s.push(i);
}
}
assert_eq!(s.len(), 2);
assert_eq!(
s.len(),
usize::try_from(tree.num_tracked_samples(0.into()).unwrap()).unwrap()
usize::try_from(tree.num_tracked_samples(0).unwrap()).unwrap()
);
assert_eq!(s[0], 1);
assert_eq!(s[1], 2);

for u in 1..3 {
let mut s = vec![];
if let Ok(iter) = tree.samples(u.into()) {
if let Ok(iter) = tree.samples(u) {
for i in iter {
s.push(i);
}
Expand All @@ -719,7 +719,7 @@ pub(crate) mod test_trees {
assert_eq!(s[0], u);
assert_eq!(
s.len(),
usize::try_from(tree.num_tracked_samples(u.into()).unwrap()).unwrap()
usize::try_from(tree.num_tracked_samples(u).unwrap()).unwrap()
);
}
} else {
Expand Down