Skip to content

Commit 9baff07

Browse files
authored
Add MigrationId and ProvenanceId (#137)
1 parent 7713ef6 commit 9baff07

File tree

5 files changed

+84
-36
lines changed

5 files changed

+84
-36
lines changed

src/lib.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,21 @@ pub struct SiteId(tsk_id_t);
178178
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
179179
pub struct MutationId(tsk_id_t);
180180

181+
/// A migration ID
182+
///
183+
/// This is an integer referring to a row of an [``MigrationTable``].
184+
///
185+
/// The features for this type follow the same pattern as for [``NodeId``]
186+
#[repr(transparent)]
187+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
188+
pub struct MigrationId(tsk_id_t);
189+
181190
impl_id_traits!(NodeId);
182191
impl_id_traits!(IndividualId);
183192
impl_id_traits!(PopulationId);
184193
impl_id_traits!(SiteId);
185194
impl_id_traits!(MutationId);
195+
impl_id_traits!(MigrationId);
186196

187197
// tskit defines this via a type cast
188198
// in a macro. bindgen thus misses it.
@@ -210,6 +220,19 @@ pub use trees::{NodeTraversalOrder, Tree, TreeSequence};
210220
#[cfg(any(doc, feature = "provenance"))]
211221
pub mod provenance;
212222

223+
/// A provenance ID
224+
///
225+
/// This is an integer referring to a row of an [``ProvenanceTable``].
226+
///
227+
/// The features for this type follow the same pattern as for [``NodeId``]
228+
#[cfg(any(doc, feature = "provenance"))]
229+
#[repr(transparent)]
230+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
231+
pub struct ProvenanceId(tsk_id_t);
232+
233+
#[cfg(any(doc, feature = "provenance"))]
234+
impl_id_traits!(ProvenanceId);
235+
213236
/// Handles return codes from low-level tskit functions.
214237
///
215238
/// When an error from the tskit C API is detected,

src/migration_table.rs

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use crate::bindings as ll_bindings;
22
use crate::metadata;
33
use crate::{tsk_id_t, TskitError};
4+
use crate::{MigrationId, NodeId, PopulationId};
45

56
/// Row of a [`MigrationTable`]
67
pub struct MigrationTableRow {
7-
pub id: tsk_id_t,
8+
pub id: MigrationId,
89
pub left: f64,
910
pub right: f64,
10-
pub node: tsk_id_t,
11-
pub source: tsk_id_t,
12-
pub dest: tsk_id_t,
11+
pub node: NodeId,
12+
pub source: PopulationId,
13+
pub dest: PopulationId,
1314
pub time: f64,
1415
pub metadata: Option<Vec<u8>>,
1516
}
@@ -30,7 +31,7 @@ impl PartialEq for MigrationTableRow {
3031
fn make_migration_table_row(table: &MigrationTable, pos: tsk_id_t) -> Option<MigrationTableRow> {
3132
if pos < table.num_rows() as tsk_id_t {
3233
Some(MigrationTableRow {
33-
id: pos,
34+
id: pos.into(),
3435
left: table.left(pos).unwrap(),
3536
right: table.right(pos).unwrap(),
3637
node: table.node(pos).unwrap(),
@@ -92,53 +93,68 @@ impl<'a> MigrationTable<'a> {
9293
/// # Errors
9394
///
9495
/// * [`TskitError::IndexError`] if `row` is out of range.
95-
pub fn left(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
96-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.left)
96+
pub fn left<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
97+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left)
9798
}
9899

99100
/// Return the right coordinate for a given row.
100101
///
101102
/// # Errors
102103
///
103104
/// * [`TskitError::IndexError`] if `row` is out of range.
104-
pub fn right(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
105-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.right)
105+
pub fn right<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
106+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.right)
106107
}
107108

108109
/// Return the node for a given row.
109110
///
110111
/// # Errors
111112
///
112113
/// * [`TskitError::IndexError`] if `row` is out of range.
113-
pub fn node(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
114-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.source)
114+
pub fn node<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<NodeId, TskitError> {
115+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.source, NodeId)
115116
}
116117

117118
/// Return the source population for a given row.
118119
///
119120
/// # Errors
120121
///
121122
/// * [`TskitError::IndexError`] if `row` is out of range.
122-
pub fn source(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
123-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node)
123+
pub fn source<M: Into<MigrationId> + Copy>(
124+
&'a self,
125+
row: M,
126+
) -> Result<PopulationId, TskitError> {
127+
unsafe_tsk_column_access!(
128+
row.into().0,
129+
0,
130+
self.num_rows(),
131+
self.table_.node,
132+
PopulationId
133+
)
124134
}
125135

126136
/// Return the destination population for a given row.
127137
///
128138
/// # Errors
129139
///
130140
/// * [`TskitError::IndexError`] if `row` is out of range.
131-
pub fn dest(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
132-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.dest)
141+
pub fn dest<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<PopulationId, TskitError> {
142+
unsafe_tsk_column_access!(
143+
row.into().0,
144+
0,
145+
self.num_rows(),
146+
self.table_.dest,
147+
PopulationId
148+
)
133149
}
134150

135151
/// Return the time of the migration event for a given row.
136152
///
137153
/// # Errors
138154
///
139155
/// * [`TskitError::IndexError`] if `row` is out of range.
140-
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
141-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
156+
pub fn time<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
157+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
142158
}
143159

144160
/// Return the metadata for a given row.
@@ -148,9 +164,9 @@ impl<'a> MigrationTable<'a> {
148164
/// * [`TskitError::IndexError`] if `row` is out of range.
149165
pub fn metadata<T: metadata::MetadataRoundtrip>(
150166
&'a self,
151-
row: tsk_id_t,
167+
row: MigrationId,
152168
) -> Result<Option<T>, TskitError> {
153-
let buffer = metadata_to_vector!(self, row)?;
169+
let buffer = metadata_to_vector!(self, row.0)?;
154170
decode_metadata_row!(T, buffer)
155171
}
156172

@@ -169,7 +185,7 @@ impl<'a> MigrationTable<'a> {
169185
/// # Errors
170186
///
171187
/// [`TskitError::IndexError`] if `r` is out of range.
172-
pub fn row(&self, r: tsk_id_t) -> Result<MigrationTableRow, TskitError> {
173-
table_row_access!(r, self, make_migration_table_row)
188+
pub fn row<M: Into<MigrationId> + Copy>(&self, r: M) -> Result<MigrationTableRow, TskitError> {
189+
table_row_access!(r.into().0, self, make_migration_table_row)
174190
}
175191
}

src/provenance.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//! See [`Provenance`] for examples.
1212
1313
use crate::bindings as ll_bindings;
14-
use crate::{tsk_id_t, tsk_size_t, TskitError};
14+
use crate::{tsk_id_t, tsk_size_t, ProvenanceId, TskitError};
1515

1616
/// Enable provenance table access.
1717
///
@@ -100,7 +100,7 @@ pub trait Provenance: crate::TableAccess {
100100
/// # Parameters
101101
///
102102
/// * `record`: the provenance record
103-
fn add_provenance(&mut self, record: &str) -> crate::TskReturnValue;
103+
fn add_provenance(&mut self, record: &str) -> Result<ProvenanceId, TskitError>;
104104
/// Return an immutable reference to the table, type [`ProvenanceTable`]
105105
fn provenances(&self) -> ProvenanceTable;
106106
/// Return an iterator over the rows of the [`ProvenanceTable`].
@@ -114,7 +114,7 @@ pub trait Provenance: crate::TableAccess {
114114
/// Row of a [`ProvenanceTable`].
115115
pub struct ProvenanceTableRow {
116116
/// The row id
117-
pub id: tsk_id_t,
117+
pub id: ProvenanceId,
118118
/// ISO-formatted time stamp
119119
pub timestamp: String,
120120
/// The provenance record
@@ -127,6 +127,12 @@ impl PartialEq for ProvenanceTableRow {
127127
}
128128
}
129129

130+
impl std::fmt::Display for ProvenanceId {
131+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
132+
write!(f, "ProvenanceId({})", self.0)
133+
}
134+
}
135+
130136
impl std::fmt::Display for ProvenanceTableRow {
131137
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
132138
write!(
@@ -140,7 +146,7 @@ impl std::fmt::Display for ProvenanceTableRow {
140146
fn make_provenance_table_row(table: &ProvenanceTable, pos: tsk_id_t) -> Option<ProvenanceTableRow> {
141147
if pos < table.num_rows() as tsk_id_t {
142148
Some(ProvenanceTableRow {
143-
id: pos,
149+
id: pos.into(),
144150
timestamp: table.timestamp(pos).unwrap(),
145151
record: table.record(pos).unwrap(),
146152
})
@@ -203,9 +209,9 @@ impl<'a> ProvenanceTable<'a> {
203209
/// # Errors
204210
///
205211
/// [`TskitError::IndexError`] if `r` is out of range.
206-
pub fn timestamp(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
212+
pub fn timestamp<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
207213
match unsafe_tsk_ragged_char_column_access!(
208-
row,
214+
row.into().0,
209215
0,
210216
self.num_rows(),
211217
self.table_.timestamp,
@@ -226,9 +232,9 @@ impl<'a> ProvenanceTable<'a> {
226232
/// # Errors
227233
///
228234
/// [`TskitError::IndexError`] if `r` is out of range.
229-
pub fn record(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
235+
pub fn record<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
230236
match unsafe_tsk_ragged_char_column_access!(
231-
row,
237+
row.into().0,
232238
0,
233239
self.num_rows(),
234240
self.table_.record,
@@ -249,11 +255,14 @@ impl<'a> ProvenanceTable<'a> {
249255
/// # Errors
250256
///
251257
/// [`TskitError::IndexError`] if `r` is out of range.
252-
pub fn row(&'a self, row: tsk_id_t) -> Result<ProvenanceTableRow, TskitError> {
253-
if row < 0 {
258+
pub fn row<P: Into<ProvenanceId> + Copy>(
259+
&'a self,
260+
row: P,
261+
) -> Result<ProvenanceTableRow, TskitError> {
262+
if row.into() < 0 {
254263
Err(TskitError::IndexError)
255264
} else {
256-
match make_provenance_table_row(self, row) {
265+
match make_provenance_table_row(self, row.into().0) {
257266
Some(x) => Ok(x),
258267
None => Err(TskitError::IndexError),
259268
}

src/table_collection.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ impl crate::traits::NodeListGenerator for TableCollection {}
633633

634634
#[cfg(any(doc, feature = "provenance"))]
635635
impl crate::provenance::Provenance for TableCollection {
636-
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
636+
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
637637
if record.is_empty() {
638638
return Err(TskitError::ValueError {
639639
got: String::from("empty string slice"),
@@ -650,7 +650,7 @@ impl crate::provenance::Provenance for TableCollection {
650650
record.len() as tsk_size_t,
651651
)
652652
};
653-
handle_tsk_return_value!(rv)
653+
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
654654
}
655655

656656
fn provenances(&self) -> crate::provenance::ProvenanceTable {

src/trees.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ impl crate::traits::NodeListGenerator for TreeSequence {}
11351135

11361136
#[cfg(any(doc, feature = "provenance"))]
11371137
impl crate::provenance::Provenance for TreeSequence {
1138-
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
1138+
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
11391139
if record.is_empty() {
11401140
return Err(TskitError::ValueError {
11411141
got: String::from("empty string slice"),
@@ -1152,7 +1152,7 @@ impl crate::provenance::Provenance for TreeSequence {
11521152
record.len() as tsk_size_t,
11531153
)
11541154
};
1155-
handle_tsk_return_value!(rv)
1155+
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
11561156
}
11571157

11581158
fn provenances(&self) -> crate::provenance::ProvenanceTable {

0 commit comments

Comments
 (0)