Skip to content

add migration provenance ids #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,21 @@ pub struct SiteId(tsk_id_t);
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct MutationId(tsk_id_t);

/// A migration ID
///
/// This is an integer referring to a row of an [``MigrationTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct MigrationId(tsk_id_t);

impl_id_traits!(NodeId);
impl_id_traits!(IndividualId);
impl_id_traits!(PopulationId);
impl_id_traits!(SiteId);
impl_id_traits!(MutationId);
impl_id_traits!(MigrationId);

// tskit defines this via a type cast
// in a macro. bindgen thus misses it.
Expand Down Expand Up @@ -210,6 +220,19 @@ pub use trees::{NodeTraversalOrder, Tree, TreeSequence};
#[cfg(any(doc, feature = "provenance"))]
pub mod provenance;

/// A provenance ID
///
/// This is an integer referring to a row of an [``ProvenanceTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[cfg(any(doc, feature = "provenance"))]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct ProvenanceId(tsk_id_t);

#[cfg(any(doc, feature = "provenance"))]
impl_id_traits!(ProvenanceId);

/// Handles return codes from low-level tskit functions.
///
/// When an error from the tskit C API is detected,
Expand Down
58 changes: 37 additions & 21 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::{tsk_id_t, TskitError};
use crate::{MigrationId, NodeId, PopulationId};

/// Row of a [`MigrationTable`]
pub struct MigrationTableRow {
pub id: tsk_id_t,
pub id: MigrationId,
pub left: f64,
pub right: f64,
pub node: tsk_id_t,
pub source: tsk_id_t,
pub dest: tsk_id_t,
pub node: NodeId,
pub source: PopulationId,
pub dest: PopulationId,
pub time: f64,
pub metadata: Option<Vec<u8>>,
}
Expand All @@ -30,7 +31,7 @@ impl PartialEq for MigrationTableRow {
fn make_migration_table_row(table: &MigrationTable, pos: tsk_id_t) -> Option<MigrationTableRow> {
if pos < table.num_rows() as tsk_id_t {
Some(MigrationTableRow {
id: pos,
id: pos.into(),
left: table.left(pos).unwrap(),
right: table.right(pos).unwrap(),
node: table.node(pos).unwrap(),
Expand Down Expand Up @@ -92,53 +93,68 @@ impl<'a> MigrationTable<'a> {
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn left(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.left)
pub fn left<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left)
}

/// Return the right coordinate for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn right(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.right)
pub fn right<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.right)
}

/// Return the node for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn node(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.source)
pub fn node<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.source, NodeId)
}

/// Return the source population for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn source(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node)
pub fn source<M: Into<MigrationId> + Copy>(
&'a self,
row: M,
) -> Result<PopulationId, TskitError> {
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_.node,
PopulationId
)
}

/// Return the destination population for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn dest(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.dest)
pub fn dest<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<PopulationId, TskitError> {
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_.dest,
PopulationId
)
}

/// Return the time of the migration event for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
pub fn time<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
}

/// Return the metadata for a given row.
Expand All @@ -148,9 +164,9 @@ impl<'a> MigrationTable<'a> {
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn metadata<T: metadata::MetadataRoundtrip>(
&'a self,
row: tsk_id_t,
row: MigrationId,
) -> Result<Option<T>, TskitError> {
let buffer = metadata_to_vector!(self, row)?;
let buffer = metadata_to_vector!(self, row.0)?;
decode_metadata_row!(T, buffer)
}

Expand All @@ -169,7 +185,7 @@ impl<'a> MigrationTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&self, r: tsk_id_t) -> Result<MigrationTableRow, TskitError> {
table_row_access!(r, self, make_migration_table_row)
pub fn row<M: Into<MigrationId> + Copy>(&self, r: M) -> Result<MigrationTableRow, TskitError> {
table_row_access!(r.into().0, self, make_migration_table_row)
}
}
31 changes: 20 additions & 11 deletions src/provenance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//! See [`Provenance`] for examples.

use crate::bindings as ll_bindings;
use crate::{tsk_id_t, tsk_size_t, TskitError};
use crate::{tsk_id_t, tsk_size_t, ProvenanceId, TskitError};

/// Enable provenance table access.
///
Expand Down Expand Up @@ -100,7 +100,7 @@ pub trait Provenance: crate::TableAccess {
/// # Parameters
///
/// * `record`: the provenance record
fn add_provenance(&mut self, record: &str) -> crate::TskReturnValue;
fn add_provenance(&mut self, record: &str) -> Result<ProvenanceId, TskitError>;
/// Return an immutable reference to the table, type [`ProvenanceTable`]
fn provenances(&self) -> ProvenanceTable;
/// Return an iterator over the rows of the [`ProvenanceTable`].
Expand All @@ -114,7 +114,7 @@ pub trait Provenance: crate::TableAccess {
/// Row of a [`ProvenanceTable`].
pub struct ProvenanceTableRow {
/// The row id
pub id: tsk_id_t,
pub id: ProvenanceId,
/// ISO-formatted time stamp
pub timestamp: String,
/// The provenance record
Expand All @@ -127,6 +127,12 @@ impl PartialEq for ProvenanceTableRow {
}
}

impl std::fmt::Display for ProvenanceId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "ProvenanceId({})", self.0)
}
}

impl std::fmt::Display for ProvenanceTableRow {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
Expand All @@ -140,7 +146,7 @@ impl std::fmt::Display for ProvenanceTableRow {
fn make_provenance_table_row(table: &ProvenanceTable, pos: tsk_id_t) -> Option<ProvenanceTableRow> {
if pos < table.num_rows() as tsk_id_t {
Some(ProvenanceTableRow {
id: pos,
id: pos.into(),
timestamp: table.timestamp(pos).unwrap(),
record: table.record(pos).unwrap(),
})
Expand Down Expand Up @@ -203,9 +209,9 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn timestamp(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
pub fn timestamp<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
match unsafe_tsk_ragged_char_column_access!(
row,
row.into().0,
0,
self.num_rows(),
self.table_.timestamp,
Expand All @@ -226,9 +232,9 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn record(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
pub fn record<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
match unsafe_tsk_ragged_char_column_access!(
row,
row.into().0,
0,
self.num_rows(),
self.table_.record,
Expand All @@ -249,11 +255,14 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&'a self, row: tsk_id_t) -> Result<ProvenanceTableRow, TskitError> {
if row < 0 {
pub fn row<P: Into<ProvenanceId> + Copy>(
&'a self,
row: P,
) -> Result<ProvenanceTableRow, TskitError> {
if row.into() < 0 {
Err(TskitError::IndexError)
} else {
match make_provenance_table_row(self, row) {
match make_provenance_table_row(self, row.into().0) {
Some(x) => Ok(x),
None => Err(TskitError::IndexError),
}
Expand Down
4 changes: 2 additions & 2 deletions src/table_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ impl crate::traits::NodeListGenerator for TableCollection {}

#[cfg(any(doc, feature = "provenance"))]
impl crate::provenance::Provenance for TableCollection {
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
if record.is_empty() {
return Err(TskitError::ValueError {
got: String::from("empty string slice"),
Expand All @@ -650,7 +650,7 @@ impl crate::provenance::Provenance for TableCollection {
record.len() as tsk_size_t,
)
};
handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
}

fn provenances(&self) -> crate::provenance::ProvenanceTable {
Expand Down
4 changes: 2 additions & 2 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ impl crate::traits::NodeListGenerator for TreeSequence {}

#[cfg(any(doc, feature = "provenance"))]
impl crate::provenance::Provenance for TreeSequence {
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
if record.is_empty() {
return Err(TskitError::ValueError {
got: String::from("empty string slice"),
Expand All @@ -1152,7 +1152,7 @@ impl crate::provenance::Provenance for TreeSequence {
record.len() as tsk_size_t,
)
};
handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
}

fn provenances(&self) -> crate::provenance::ProvenanceTable {
Expand Down