Skip to content

Commit

Permalink
Add SiteId and MutationId (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen authored Jul 21, 2021
1 parent de006a1 commit 7713ef6
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 50 deletions.
6 changes: 5 additions & 1 deletion examples/mutation_metadata_bincode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ pub fn run() {
// 1. The first is to handle errors.
// 2. The second is b/c metadata are optional,
// so a row may return None.
let decoded = tables.mutations().metadata::<Mutation>(0).unwrap().unwrap();
let decoded = tables
.mutations()
.metadata::<Mutation>(0.into())
.unwrap()
.unwrap();

// Check that we've made the round trip:
assert_eq!(decoded.origin_time, 1);
Expand Down
6 changes: 5 additions & 1 deletion examples/mutation_metadata_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ pub fn run() {
// 1. The first is to handle errors.
// 2. The second is b/c metadata are optional,
// so a row may return None.
let decoded = tables.mutations().metadata::<Mutation>(0).unwrap().unwrap();
let decoded = tables
.mutations()
.metadata::<Mutation>(0.into())
.unwrap()
.unwrap();

// Check that we've made the round trip:
assert_eq!(decoded.origin_time, 1);
Expand Down
20 changes: 20 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,29 @@ pub struct IndividualId(tsk_id_t);
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct PopulationId(tsk_id_t);

/// A site ID
///
/// This is an integer referring to a row of an [``SiteTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct SiteId(tsk_id_t);

/// A mutation ID
///
/// This is an integer referring to a row of an [``MutationTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct MutationId(tsk_id_t);

impl_id_traits!(NodeId);
impl_id_traits!(IndividualId);
impl_id_traits!(PopulationId);
impl_id_traits!(SiteId);
impl_id_traits!(MutationId);

// tskit defines this via a type cast
// in a macro. bindgen thus misses it.
Expand Down
7 changes: 6 additions & 1 deletion src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ use thiserror::Error;
/// // The two unwraps are:
/// // 1. Handle Errors vs Option.
/// // 2. Handle the option for the case of no error.
/// let decoded = tables.mutations().metadata::<MyMutation>(0).unwrap().unwrap();
/// //
/// // The .into() reflects the fact that metadata fetching
/// // functions only take a strong ID type, and tskit-rust
/// // adds Into<strong ID type> for i32 for all strong ID types.
///
/// let decoded = tables.mutations().metadata::<MyMutation>(0.into()).unwrap().unwrap();
/// assert_eq!(mutation.origin_time, decoded.origin_time);
/// match decoded.effect_size.partial_cmp(&mutation.effect_size) {
/// Some(std::cmp::Ordering::Greater) => assert!(false),
Expand Down
47 changes: 28 additions & 19 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::NodeId;
use crate::{tsk_id_t, tsk_size_t, TskitError};
use crate::{MutationId, NodeId, SiteId};

/// Row of a [`MutationTable`]
pub struct MutationTableRow {
pub id: tsk_id_t,
pub site: tsk_id_t,
pub id: MutationId,
pub site: SiteId,
pub node: NodeId,
pub parent: tsk_id_t,
pub parent: MutationId,
pub time: f64,
pub derived_state: Option<Vec<u8>>,
pub metadata: Option<Vec<u8>>,
Expand All @@ -29,7 +29,7 @@ impl PartialEq for MutationTableRow {
fn make_mutation_table_row(table: &MutationTable, pos: tsk_id_t) -> Option<MutationTableRow> {
if pos < table.num_rows() as tsk_id_t {
let rv = MutationTableRow {
id: pos,
id: pos.into(),
site: table.site(pos).unwrap(),
node: table.node(pos).unwrap(),
parent: table.parent(pos).unwrap(),
Expand Down Expand Up @@ -90,8 +90,8 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn site(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.site)
pub fn site<M: Into<MutationId> + Copy>(&'a self, row: M) -> Result<SiteId, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.site, SiteId)
}

/// Return the ``node`` value from row ``row`` of the table.
Expand All @@ -100,8 +100,8 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn node(&'a self, row: tsk_id_t) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node, NodeId)
pub fn node<M: Into<MutationId> + Copy>(&'a self, row: M) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.node, NodeId)
}

/// Return the ``parent`` value from row ``row`` of the table.
Expand All @@ -110,8 +110,14 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn parent(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.parent)
pub fn parent<M: Into<MutationId> + Copy>(&'a self, row: M) -> Result<MutationId, TskitError> {
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_.parent,
MutationId
)
}

/// Return the ``time`` value from row ``row`` of the table.
Expand All @@ -120,8 +126,8 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
pub fn time<M: Into<MutationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
}

/// Get the ``derived_state`` value from row ``row`` of the table.
Expand All @@ -134,21 +140,24 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn derived_state(&'a self, row: tsk_id_t) -> Result<Option<Vec<u8>>, TskitError> {
pub fn derived_state<M: Into<MutationId>>(
&'a self,
row: M,
) -> Result<Option<Vec<u8>>, TskitError> {
metadata::char_column_to_vector(
self.table_.derived_state,
self.table_.derived_state_offset,
row,
row.into().0,
self.table_.num_rows,
self.table_.derived_state_length,
)
}

pub fn metadata<T: metadata::MetadataRoundtrip>(
&'a self,
row: tsk_id_t,
row: MutationId,
) -> Result<Option<T>, TskitError> {
let buffer = metadata_to_vector!(self, row)?;
let buffer = metadata_to_vector!(self, row.0)?;
decode_metadata_row!(T, buffer)
}

Expand All @@ -167,7 +176,7 @@ impl<'a> MutationTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&self, r: tsk_id_t) -> Result<MutationTableRow, TskitError> {
table_row_access!(r, self, make_mutation_table_row)
pub fn row<M: Into<MutationId> + Copy>(&self, r: M) -> Result<MutationTableRow, TskitError> {
table_row_access!(r.into().0, self, make_mutation_table_row)
}
}
24 changes: 14 additions & 10 deletions src/site_table.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::SiteId;
use crate::TskitError;
use crate::{tsk_id_t, tsk_size_t};

/// Row of a [`SiteTable`]
pub struct SiteTableRow {
pub id: tsk_id_t,
pub id: SiteId,
pub position: f64,
pub ancestral_state: Option<Vec<u8>>,
pub metadata: Option<Vec<u8>>,
Expand All @@ -23,7 +24,7 @@ impl PartialEq for SiteTableRow {
fn make_site_table_row(table: &SiteTable, pos: tsk_id_t) -> Option<SiteTableRow> {
if pos < table.num_rows() as tsk_id_t {
let rv = SiteTableRow {
id: pos,
id: pos.into(),
position: table.position(pos).unwrap(),
ancestral_state: table.ancestral_state(pos).unwrap(),
metadata: table_row_decode_metadata!(table, pos),
Expand Down Expand Up @@ -82,8 +83,8 @@ impl<'a> SiteTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn position(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.position)
pub fn position<S: Into<SiteId> + Copy>(&'a self, row: S) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.position)
}

/// Get the ``ancestral_state`` value from row ``row`` of the table.
Expand All @@ -96,21 +97,24 @@ impl<'a> SiteTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn ancestral_state(&'a self, row: tsk_id_t) -> Result<Option<Vec<u8>>, TskitError> {
pub fn ancestral_state<S: Into<SiteId>>(
&'a self,
row: S,
) -> Result<Option<Vec<u8>>, TskitError> {
crate::metadata::char_column_to_vector(
self.table_.ancestral_state,
self.table_.ancestral_state_offset,
row,
row.into().0,
self.table_.num_rows,
self.table_.ancestral_state_length,
)
}

pub fn metadata<T: metadata::MetadataRoundtrip>(
&'a self,
row: tsk_id_t,
row: SiteId,
) -> Result<Option<T>, TskitError> {
let buffer = metadata_to_vector!(self, row)?;
let buffer = metadata_to_vector!(self, row.0)?;
decode_metadata_row!(T, buffer)
}

Expand All @@ -129,7 +133,7 @@ impl<'a> SiteTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&self, r: tsk_id_t) -> Result<SiteTableRow, TskitError> {
table_row_access!(r, self, make_site_table_row)
pub fn row<S: Into<SiteId> + Copy>(&self, r: S) -> Result<SiteTableRow, TskitError> {
table_row_access!(r.into().0, self, make_site_table_row)
}
}
44 changes: 26 additions & 18 deletions src/table_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::TreeSequenceFlags;
use crate::TskReturnValue;
use crate::TskitTypeAccess;
use crate::{tsk_flags_t, tsk_id_t, tsk_size_t, TSK_NULL};
use crate::{IndividualId, NodeId, PopulationId};
use crate::{IndividualId, MutationId, NodeId, PopulationId, SiteId};
use ll_bindings::tsk_table_collection_free;

/// A table collection.
Expand Down Expand Up @@ -323,7 +323,11 @@ impl TableCollection {
}

/// Add a row to the site table
pub fn add_site(&mut self, position: f64, ancestral_state: Option<&[u8]>) -> TskReturnValue {
pub fn add_site(
&mut self,
position: f64,
ancestral_state: Option<&[u8]>,
) -> Result<SiteId, TskitError> {
self.add_site_with_metadata(position, ancestral_state, None)
}

Expand All @@ -333,7 +337,7 @@ impl TableCollection {
position: f64,
ancestral_state: Option<&[u8]>,
metadata: Option<&dyn MetadataRoundtrip>,
) -> TskReturnValue {
) -> Result<SiteId, TskitError> {
let astate = process_state_input!(ancestral_state);
let md = EncodedMetadata::new(metadata)?;

Expand All @@ -348,40 +352,40 @@ impl TableCollection {
)
};

handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, SiteId::from(rv))
}

/// Add a row to the mutation table.
pub fn add_mutation<N: Into<NodeId>>(
pub fn add_mutation<N: Into<NodeId>, M: Into<MutationId>, S: Into<SiteId>>(
&mut self,
site: tsk_id_t,
site: S,
node: N,
parent: tsk_id_t,
parent: M,
time: f64,
derived_state: Option<&[u8]>,
) -> TskReturnValue {
) -> Result<MutationId, TskitError> {
self.add_mutation_with_metadata(site, node, parent, time, derived_state, None)
}

/// Add a row with metadata to the mutation table.
pub fn add_mutation_with_metadata<N: Into<NodeId>>(
pub fn add_mutation_with_metadata<N: Into<NodeId>, M: Into<MutationId>, S: Into<SiteId>>(
&mut self,
site: tsk_id_t,
site: S,
node: N,
parent: tsk_id_t,
parent: M,
time: f64,
derived_state: Option<&[u8]>,
metadata: Option<&dyn MetadataRoundtrip>,
) -> TskReturnValue {
) -> Result<MutationId, TskitError> {
let dstate = process_state_input!(derived_state);
let md = EncodedMetadata::new(metadata)?;

let rv = unsafe {
ll_bindings::tsk_mutation_table_add_row(
&mut (*self.as_mut_ptr()).mutations,
site,
site.into().0,
node.into().0,
parent,
parent.into().0,
time,
dstate.0,
dstate.1,
Expand All @@ -390,7 +394,7 @@ impl TableCollection {
)
};

handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, MutationId::from(rv))
}

/// Add a row to the population_table
Expand Down Expand Up @@ -983,7 +987,7 @@ mod test {
.unwrap();
// The double unwrap is to first check for error
// and then to process the Option.
let md = tables.mutations().metadata::<F>(0).unwrap().unwrap();
let md = tables.mutations().metadata::<F>(0.into()).unwrap().unwrap();
assert_eq!(md.x, -3);
assert_eq!(md.y, 666);

Expand Down Expand Up @@ -1016,7 +1020,11 @@ mod test {
let mut num_with_metadata = 0;
let mut num_without_metadata = 0;
for i in 0..tables.mutations().num_rows() {
match tables.mutations().metadata::<F>(i as tsk_id_t).unwrap() {
match tables
.mutations()
.metadata::<F>((i as tsk_id_t).into())
.unwrap()
{
Some(x) => {
num_with_metadata += 1;
assert_eq!(x.x, -3);
Expand Down Expand Up @@ -1166,7 +1174,7 @@ mod test_bad_metadata {
tables
.add_mutation_with_metadata(0, 0, crate::TSK_NULL, 0.0, None, Some(&md))
.unwrap();
if tables.mutations().metadata::<Ff>(0).is_ok() {
if tables.mutations().metadata::<Ff>(0.into()).is_ok() {
panic!("expected an error!!");
}
}
Expand Down

0 comments on commit 7713ef6

Please sign in to comment.