Skip to content

Refactor Tree::new #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ use std::ptr::NonNull;
///
/// Wrapper around `tsk_tree_t`.
pub struct Tree {
pub(crate) inner: ll_bindings::tsk_tree_t,
pub(crate) inner: mbox::MBox<ll_bindings::tsk_tree_t>,
api: TreeInterface,
current_tree: i32,
advanced: bool,
}

impl Drop for Tree {
fn drop(&mut self) {
let rv = unsafe { tsk_tree_free(&mut self.inner) };
// SAFETY: Mbox<_> cannot hold a NULL ptr
let rv = unsafe { tsk_tree_free(self.inner.as_mut()) };
assert_eq!(rv, 0);
}
}
Expand All @@ -58,29 +59,38 @@ impl DerefMut for Tree {
impl Tree {
fn new<F: Into<TreeFlags>>(ts: &TreeSequence, flags: F) -> Result<Self, TskitError> {
let flags = flags.into();
let mut tree = MaybeUninit::<ll_bindings::tsk_tree_t>::uninit();

// SAFETY: this is the type we want :)
let temp = unsafe {
libc::malloc(std::mem::size_of::<ll_bindings::tsk_tree_t>())
as *mut ll_bindings::tsk_tree_t
};

// Get our pointer into MBox ASAP
let nonnull = NonNull::<ll_bindings::tsk_tree_t>::new(temp)
.ok_or_else(|| TskitError::LibraryError("failed to malloc tsk_tree_t".to_string()))?;

// SAFETY: if temp is NULL, we have returned Err already.
let mut tree = unsafe { mbox::MBox::from_non_null_raw(nonnull) };
let mut rv =
unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits()) };
unsafe { ll_bindings::tsk_tree_init(tree.as_mut(), ts.as_ptr(), flags.bits()) };
if rv < 0 {
return Err(TskitError::ErrorCode { code: rv });
}
// Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
// SAFETY: nobody is null here.
rv = unsafe {
ll_bindings::tsk_tree_set_tracked_samples(
tree.as_mut_ptr(),
tree.as_mut(),
ts.num_samples().into(),
(*tree.as_ptr()).samples,
(*tree.as_mut()).samples,
)
};
}

let mut tree = unsafe { tree.assume_init() };
let ptr = &mut tree as *mut ll_bindings::tsk_tree_t;

let num_nodes = unsafe { (*(*ts.as_ptr()).tables).nodes.num_rows };
let non_owned_pointer = unsafe { NonNull::new_unchecked(ptr) };
let api = TreeInterface::new(non_owned_pointer, num_nodes, num_nodes + 1, flags);
let api = TreeInterface::new(nonnull, num_nodes, num_nodes + 1, flags);
handle_tsk_return_value!(
rv,
Tree {
Expand Down