Skip to content

Refactor TreeSequence use of unsafe #430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/sys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use bindings::tsk_mutation_table_t;
use bindings::tsk_node_table_t;
use bindings::tsk_population_table_t;
use bindings::tsk_site_table_t;
use std::ffi::CString;
use std::ptr::NonNull;

#[cfg(feature = "provenance")]
Expand Down Expand Up @@ -47,6 +48,94 @@ basic_lltableref_impl!(LLIndividualTableRef, tsk_individual_table_t);
#[cfg(feature = "provenance")]
basic_lltableref_impl!(LLProvenanceTableRef, tsk_provenance_table_t);

#[repr(transparent)]
pub struct LLTreeSeq(bindings::tsk_treeseq_t);

impl LLTreeSeq {
pub fn new(
tables: *mut bindings::tsk_table_collection_t,
flags: bindings::tsk_flags_t,
) -> Result<Self, TskitError> {
let mut inner = std::mem::MaybeUninit::<bindings::tsk_treeseq_t>::uninit();
let mut flags = flags;
flags |= bindings::TSK_TAKE_OWNERSHIP;
let rv = unsafe { bindings::tsk_treeseq_init(inner.as_mut_ptr(), tables, flags) };
handle_tsk_return_value!(rv, Self(unsafe { inner.assume_init() }))
}

pub fn as_ref(&self) -> &bindings::tsk_treeseq_t {
&self.0
}

pub fn as_ptr(&self) -> *const bindings::tsk_treeseq_t {
&self.0
}

pub fn as_mut_ptr(&mut self) -> *mut bindings::tsk_treeseq_t {
&mut self.0
}

pub fn simplify(
&self,
samples: &[bindings::tsk_id_t],
options: bindings::tsk_flags_t,
idmap: *mut bindings::tsk_id_t,
) -> Result<Self, TskitError> {
// The output is an UNINITIALIZED treeseq,
// else we leak memory.
let mut ts = std::mem::MaybeUninit::<bindings::tsk_treeseq_t>::uninit();
// SAFETY: samples is not null, idmap is allowed to be.
// self.as_ptr() is not null
let rv = unsafe {
bindings::tsk_treeseq_simplify(
self.as_ptr(),
samples.as_ptr(),
samples.len() as bindings::tsk_size_t,
options,
ts.as_mut_ptr(),
idmap,
)
};
let init = unsafe { ts.assume_init() };
if rv < 0 {
// SAFETY: the ptr is not null
// and tsk_treeseq_free uses safe methods
// to clean up.
unsafe { bindings::tsk_treeseq_free(ts.as_mut_ptr()) };
}
handle_tsk_return_value!(rv, Self(init))
}

pub fn dump(
&self,
filename: CString,
options: bindings::tsk_flags_t,
) -> Result<i32, TskitError> {
// SAFETY: self pointer is not null
let rv = unsafe { bindings::tsk_treeseq_dump(self.as_ptr(), filename.as_ptr(), options) };
handle_tsk_return_value!(rv)
}

pub fn num_trees(&self) -> bindings::tsk_size_t {
// SAFETY: self pointer is not null
unsafe { bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }
}

pub fn kc_distance(&self, other: &Self, lambda: f64) -> Result<f64, TskitError> {
let mut kc: f64 = f64::NAN;
let kcp: *mut f64 = &mut kc;
// SAFETY: self pointer is not null
let code = unsafe {
bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
};
handle_tsk_return_value!(code, kc)
}

pub fn num_samples(&self) -> bindings::tsk_size_t {
unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }
}
}

fn tsk_column_access_detail<R: Into<bindings::tsk_id_t>, L: Into<bindings::tsk_size_t>, T: Copy>(
row: R,
column: *const T,
Expand Down
97 changes: 34 additions & 63 deletions src/trees.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::ops::DerefMut;

Expand All @@ -13,7 +12,7 @@ use crate::TreeFlags;
use crate::TreeInterface;
use crate::TreeSequenceFlags;
use crate::TskReturnValue;
use crate::{tsk_id_t, tsk_size_t, TableCollection};
use crate::{tsk_id_t, TableCollection};
use ll_bindings::tsk_tree_free;
use std::ptr::NonNull;

Expand Down Expand Up @@ -185,7 +184,7 @@ impl streaming_iterator::DoubleEndedStreamingIterator for Tree {
/// assert_eq!(treeseq.nodes_mut().num_rows(), 3);
/// ```
pub struct TreeSequence {
pub(crate) inner: ll_bindings::tsk_treeseq_t,
pub(crate) inner: sys::LLTreeSeq,
views: crate::table_views::TableViews,
}

Expand All @@ -194,7 +193,7 @@ unsafe impl Sync for TreeSequence {}

impl Drop for TreeSequence {
fn drop(&mut self) {
let rv = unsafe { ll_bindings::tsk_treeseq_free(&mut self.inner) };
let rv = unsafe { ll_bindings::tsk_treeseq_free(self.as_mut_ptr()) };
assert_eq!(rv, 0);
}
}
Expand Down Expand Up @@ -247,31 +246,24 @@ impl TreeSequence {
tables: TableCollection,
flags: F,
) -> Result<Self, TskitError> {
let mut inner = std::mem::MaybeUninit::<ll_bindings::tsk_treeseq_t>::uninit();
let mut flags: u32 = flags.into().bits();
flags |= ll_bindings::TSK_TAKE_OWNERSHIP;
let raw_tables_ptr = tables.into_raw()?;
let rv =
unsafe { ll_bindings::tsk_treeseq_init(inner.as_mut_ptr(), raw_tables_ptr, flags) };
let mut inner = sys::LLTreeSeq::new(raw_tables_ptr, flags.into().bits())?;
let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
handle_tsk_return_value!(rv, {
let inner = unsafe { inner.assume_init() };
Self { inner, views }
})
Ok(Self { inner, views })
}

fn as_ref(&self) -> &ll_bindings::tsk_treeseq_t {
&self.inner
self.inner.as_ref()
}

/// Pointer to the low-level C type.
pub fn as_ptr(&self) -> *const ll_bindings::tsk_treeseq_t {
&self.inner
self.inner.as_ptr()
}

/// Mutable pointer to the low-level C type.
pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_treeseq_t {
&mut self.inner
self.inner.as_mut_ptr()
}

/// Dump the tree sequence to file.
Expand All @@ -290,11 +282,7 @@ impl TreeSequence {
let c_str = std::ffi::CString::new(filename).map_err(|_| {
TskitError::LibraryError("call to ffi::Cstring::new failed".to_string())
})?;
let rv = unsafe {
ll_bindings::tsk_treeseq_dump(self.as_ptr(), c_str.as_ptr(), options.into().bits())
};

handle_tsk_return_value!(rv)
self.inner.dump(c_str, options.into().bits())
}

/// Load from a file.
Expand Down Expand Up @@ -401,7 +389,7 @@ impl TreeSequence {

/// Get the number of trees.
pub fn num_trees(&self) -> SizeType {
unsafe { ll_bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }.into()
self.inner.num_trees().into()
}

/// Calculate the average Kendall-Colijn (`K-C`) distance between
Expand All @@ -416,17 +404,12 @@ impl TreeSequence {
/// * `lambda` specifies the relative weight of topology and branch length.
/// See [`TreeInterface::kc_distance`] for more details.
pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
let mut kc: f64 = f64::NAN;
let kcp: *mut f64 = &mut kc;
let code = unsafe {
ll_bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
};
handle_tsk_return_value!(code, kc)
self.inner.kc_distance(&other.inner, lambda)
}

// FIXME: document
pub fn num_samples(&self) -> SizeType {
unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }.into()
self.inner.num_samples().into()
}

/// Simplify tables and return a new tree sequence.
Expand All @@ -448,42 +431,29 @@ impl TreeSequence {
options: O,
idmap: bool,
) -> Result<(Self, Option<Vec<NodeId>>), TskitError> {
// The output is an UNINITIALIZED treeseq,
// else we leak memory.
let mut ts = MaybeUninit::<ll_bindings::tsk_treeseq_t>::uninit();
let mut output_node_map: Vec<NodeId> = vec![];
if idmap {
output_node_map.resize(usize::try_from(self.nodes().num_rows())?, NodeId::NULL);
}
let rv = unsafe {
ll_bindings::tsk_treeseq_simplify(
self.as_ptr(),
// NOTE: casting away const-ness:
samples.as_ptr().cast::<tsk_id_t>(),
samples.len() as tsk_size_t,
options.into().bits(),
ts.as_mut_ptr(),
match idmap {
true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
false => std::ptr::null_mut(),
},
)
let llsamples = unsafe {
std::slice::from_raw_parts(samples.as_ptr().cast::<tsk_id_t>(), samples.len())
};
// TODO: is it possible that this can leak somehow?
handle_tsk_return_value!(
rv,
(
{
let mut inner = unsafe { ts.assume_init() };
let views = crate::table_views::TableViews::new_from_tree_sequence(&mut inner)?;
Self { inner, views }
},
match idmap {
true => Some(output_node_map),
false => None,
}
)
)
let mut inner = self.inner.simplify(
llsamples,
options.into().bits(),
match idmap {
true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
false => std::ptr::null_mut(),
},
)?;
let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
Ok((
Self { inner, views },
match idmap {
true => Some(output_node_map),
false => None,
},
))
}

#[cfg(feature = "provenance")]
Expand Down Expand Up @@ -532,11 +502,11 @@ impl TreeSequence {
let timestamp = humantime::format_rfc3339(std::time::SystemTime::now()).to_string();
let rv = unsafe {
ll_bindings::tsk_provenance_table_add_row(
&mut (*self.inner.tables).provenances,
&mut (*self.inner.as_ref().tables).provenances,
timestamp.as_ptr() as *mut i8,
timestamp.len() as tsk_size_t,
timestamp.len() as ll_bindings::tsk_size_t,
record.as_ptr() as *mut i8,
record.len() as tsk_size_t,
record.len() as ll_bindings::tsk_size_t,
)
};
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
Expand Down Expand Up @@ -739,6 +709,7 @@ pub(crate) mod test_trees {

#[test]
fn test_iterate_samples_two_trees() {
use super::ll_bindings::tsk_size_t;
let treeseq = treeseq_from_small_table_collection_two_trees();
assert_eq!(treeseq.num_trees(), 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
Expand Down