Skip to content

Add strong ID type infrastructure + NodeID #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/forward_simulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ impl SimParams {

#[derive(Copy, Clone)]
struct Diploid {
node0: tskit::tsk_id_t,
node1: tskit::tsk_id_t,
node0: tskit::NodeId,
node1: tskit::NodeId,
}

struct Parents {
Expand Down Expand Up @@ -201,7 +201,7 @@ fn death_and_parents(
}
}

fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) {
fn mendel(pnodes: &mut (tskit::NodeId, tskit::NodeId), rng: &mut StdRng) {
let x: f64 = rng.gen();
match x.partial_cmp(&0.5) {
Some(std::cmp::Ordering::Less) => {
Expand All @@ -214,7 +214,7 @@ fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) {

fn crossover_and_record_edges_details(
parent: Diploid,
offspring_node: tskit::tsk_id_t,
offspring_node: tskit::NodeId,
params: &SimParams,
tables: &mut tskit::TableCollection,
rng: &mut StdRng,
Expand Down Expand Up @@ -269,7 +269,7 @@ fn crossover_and_record_edges_details(

fn crossover_and_record_edges(
parents: &Parents,
offspring_nodes: (tskit::tsk_id_t, tskit::tsk_id_t),
offspring_nodes: (tskit::NodeId, tskit::NodeId),
params: &SimParams,
tables: &mut tskit::TableCollection,
rng: &mut StdRng,
Expand Down Expand Up @@ -332,9 +332,9 @@ fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) {
Ok(x) => match x {
Some(idmap) => {
for a in alive.iter_mut() {
a.node0 = idmap[a.node0 as usize];
a.node0 = idmap[usize::from(a.node0)];
assert!(a.node0 != tskit::TSK_NULL);
a.node1 = idmap[a.node1 as usize];
a.node1 = idmap[usize::from(a.node1)];
assert!(a.node1 != tskit::TSK_NULL);
}
}
Expand Down
59 changes: 59 additions & 0 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ macro_rules! unsafe_tsk_column_access {
Ok(unsafe { *$array.offset($i as isize) })
}
}};
($i: expr, $lo: expr, $hi: expr, $array: expr, $output_id_type: expr) => {{
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
Err($crate::error::TskitError::IndexError {})
} else {
Ok($output_id_type(unsafe { *$array.offset($i as isize) }))
}
}};
}

macro_rules! unsafe_tsk_ragged_column_access {
Expand Down Expand Up @@ -220,6 +227,58 @@ macro_rules! tree_array_slice {
};
}

macro_rules! impl_id_traits {
($idtype: ty) => {
impl From<$crate::tsk_id_t> for $idtype {
fn from(value: $crate::tsk_id_t) -> Self {
Self(value)
}
}

impl $crate::IdIsNull for $idtype {
fn is_null(&self) -> bool {
self.0 == $crate::TSK_NULL
}
}

impl From<$idtype> for usize {
fn from(value: $idtype) -> Self {
value.0 as usize
}
}

impl From<$idtype> for $crate::tsk_id_t {
fn from(value: $idtype) -> Self {
value.0
}
}

impl PartialEq<$crate::tsk_id_t> for $idtype {
fn eq(&self, other: &$crate::tsk_id_t) -> bool {
self.0 == *other
}
}

impl PartialEq<$idtype> for $crate::tsk_id_t {
fn eq(&self, other: &$idtype) -> bool {
*self == other.0
}
}

impl PartialOrd<$crate::tsk_id_t> for $idtype {
fn partial_cmp(&self, other: &$crate::tsk_id_t) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(other)
}
}

impl PartialOrd<$idtype> for $crate::tsk_id_t {
fn partial_cmp(&self, other: &$idtype) -> Option<std::cmp::Ordering> {
self.partial_cmp(&other.0)
}
}
};
}

/// Convenience macro to handle implementing
/// [`crate::metadata::MetadataRoundtrip`]
#[macro_export]
Expand Down
51 changes: 51 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,56 @@ pub use bindings::tsk_flags_t;
pub use bindings::tsk_id_t;
pub use bindings::tsk_size_t;

/// A node ID
///
/// This is an integer referring to a row of a [``NodeTable``].
/// The underlying type is [``tsk_id_t``].
///
/// # Examples
///
/// These examples illustrate using this type as something "integer-like".
///
/// ```
/// use tskit::NodeId;
/// use tskit::tsk_id_t;
///
/// let x: tsk_id_t = 1;
/// let y: NodeId = NodeId::from(x);
/// assert_eq!(x, y);
/// assert_eq!(y, x);
///
/// assert!(y < x + 1);
/// assert!(y <= x);
/// assert!(x + 1 > y);
/// assert!(x + 1 >= y);
///
/// let z: NodeId = NodeId::from(x);
/// assert_eq!(y, z);
/// ```
///
/// It is also possible to write functions accepting both the `NodeId`
/// and `tsk_id_t`:
///
/// ```
/// use tskit::NodeId;
/// use tskit::tsk_id_t;
///
/// fn interesting<N: Into<NodeId>>(x: N) -> NodeId {
/// x.into()
/// }
///
/// let x: tsk_id_t = 0;
/// assert_eq!(interesting(x), x);
/// let x: NodeId = NodeId::from(0);
/// assert_eq!(interesting(x), x);
/// ```
///
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct NodeId(tsk_id_t);

impl_id_traits!(NodeId);

// tskit defines this via a type cast
// in a macro. bindgen thus misses it.
// See bindgen issue 316.
Expand All @@ -110,6 +160,7 @@ pub use node_table::{NodeTable, NodeTableRow};
pub use population_table::{PopulationTable, PopulationTableRow};
pub use site_table::{SiteTable, SiteTableRow};
pub use table_collection::TableCollection;
pub use traits::IdIsNull;
pub use traits::NodeListGenerator;
pub use traits::TableAccess;
pub use traits::TskitTypeAccess;
Expand Down
7 changes: 4 additions & 3 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::NodeId;
use crate::{tsk_id_t, tsk_size_t, TskitError};

/// Row of a [`MutationTable`]
pub struct MutationTableRow {
pub id: tsk_id_t,
pub site: tsk_id_t,
pub node: tsk_id_t,
pub node: NodeId,
pub parent: tsk_id_t,
pub time: f64,
pub derived_state: Option<Vec<u8>>,
Expand Down Expand Up @@ -99,8 +100,8 @@ impl<'a> MutationTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn node(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node)
pub fn node(&'a self, row: tsk_id_t) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node, NodeId)
}

/// Return the ``parent`` value from row ``row`` of the table.
Expand Down
35 changes: 18 additions & 17 deletions src/node_table.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::NodeId;
use crate::{tsk_flags_t, tsk_id_t, TskitError};

/// Row of a [`NodeTable`]
pub struct NodeTableRow {
pub id: tsk_id_t,
pub id: NodeId,
pub time: f64,
pub flags: tsk_flags_t,
pub population: tsk_id_t,
Expand All @@ -26,7 +27,7 @@ impl PartialEq for NodeTableRow {
fn make_node_table_row(table: &NodeTable, pos: tsk_id_t) -> Option<NodeTableRow> {
if pos < table.num_rows() as tsk_id_t {
Some(NodeTableRow {
id: pos,
id: pos.into(),
time: table.time(pos).unwrap(),
flags: table.flags(pos).unwrap(),
population: table.population(pos).unwrap(),
Expand Down Expand Up @@ -86,8 +87,8 @@ impl<'a> NodeTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
pub fn time<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
}

/// Return the ``flags`` value from row ``row`` of the table.
Expand All @@ -96,8 +97,8 @@ impl<'a> NodeTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn flags(&'a self, row: tsk_id_t) -> Result<tsk_flags_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.flags)
pub fn flags<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_flags_t, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.flags)
}

/// Mutable access to node flags.
Expand All @@ -116,8 +117,8 @@ impl<'a> NodeTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn population(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.population)
pub fn population<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.population)
}

/// Return the ``population`` value from row ``row`` of the table.
Expand All @@ -126,7 +127,7 @@ impl<'a> NodeTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn deme(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
pub fn deme<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
self.population(row)
}

Expand All @@ -136,8 +137,8 @@ impl<'a> NodeTable<'a> {
///
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn individual(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.individual)
pub fn individual<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.individual)
}

pub fn metadata<T: metadata::MetadataRoundtrip>(
Expand All @@ -163,15 +164,15 @@ impl<'a> NodeTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&self, r: tsk_id_t) -> Result<NodeTableRow, TskitError> {
table_row_access!(r, self, make_node_table_row)
pub fn row<N: Into<NodeId> + Copy>(&self, r: N) -> Result<NodeTableRow, TskitError> {
table_row_access!(r.into().0, self, make_node_table_row)
}

/// Obtain a vector containing the indexes ("ids")
/// of all nodes for which [`crate::TSK_NODE_IS_SAMPLE`]
/// is `true`.
pub fn samples_as_vector(&self) -> Vec<tsk_id_t> {
let mut samples: Vec<tsk_id_t> = vec![];
pub fn samples_as_vector(&self) -> Vec<NodeId> {
let mut samples: Vec<NodeId> = vec![];
for row in self.iter() {
if row.flags & crate::TSK_NODE_IS_SAMPLE > 0 {
samples.push(row.id);
Expand All @@ -185,8 +186,8 @@ impl<'a> NodeTable<'a> {
pub fn create_node_id_vector(
&self,
mut f: impl FnMut(&crate::NodeTableRow) -> bool,
) -> Vec<tsk_id_t> {
let mut samples: Vec<tsk_id_t> = vec![];
) -> Vec<NodeId> {
let mut samples: Vec<NodeId> = vec![];
for row in self.iter() {
if f(&row) {
samples.push(row.id);
Expand Down
Loading