Skip to content

Commit 3ac3d50

Browse files
authored
* Establish a pattern for stronger ID types (#129)
* Implement the pattern via a macro * Add NodeId as the first concrete example This PR breaks API, although the changes required are small. See examples/forward_simulation.rs.
1 parent e19095d commit 3ac3d50

File tree

8 files changed

+186
-66
lines changed

8 files changed

+186
-66
lines changed

examples/forward_simulation.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ impl SimParams {
166166

167167
#[derive(Copy, Clone)]
168168
struct Diploid {
169-
node0: tskit::tsk_id_t,
170-
node1: tskit::tsk_id_t,
169+
node0: tskit::NodeId,
170+
node1: tskit::NodeId,
171171
}
172172

173173
struct Parents {
@@ -201,7 +201,7 @@ fn death_and_parents(
201201
}
202202
}
203203

204-
fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) {
204+
fn mendel(pnodes: &mut (tskit::NodeId, tskit::NodeId), rng: &mut StdRng) {
205205
let x: f64 = rng.gen();
206206
match x.partial_cmp(&0.5) {
207207
Some(std::cmp::Ordering::Less) => {
@@ -214,7 +214,7 @@ fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) {
214214

215215
fn crossover_and_record_edges_details(
216216
parent: Diploid,
217-
offspring_node: tskit::tsk_id_t,
217+
offspring_node: tskit::NodeId,
218218
params: &SimParams,
219219
tables: &mut tskit::TableCollection,
220220
rng: &mut StdRng,
@@ -269,7 +269,7 @@ fn crossover_and_record_edges_details(
269269

270270
fn crossover_and_record_edges(
271271
parents: &Parents,
272-
offspring_nodes: (tskit::tsk_id_t, tskit::tsk_id_t),
272+
offspring_nodes: (tskit::NodeId, tskit::NodeId),
273273
params: &SimParams,
274274
tables: &mut tskit::TableCollection,
275275
rng: &mut StdRng,
@@ -332,9 +332,9 @@ fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) {
332332
Ok(x) => match x {
333333
Some(idmap) => {
334334
for a in alive.iter_mut() {
335-
a.node0 = idmap[a.node0 as usize];
335+
a.node0 = idmap[usize::from(a.node0)];
336336
assert!(a.node0 != tskit::TSK_NULL);
337-
a.node1 = idmap[a.node1 as usize];
337+
a.node1 = idmap[usize::from(a.node1)];
338338
assert!(a.node1 != tskit::TSK_NULL);
339339
}
340340
}

src/_macros.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ macro_rules! unsafe_tsk_column_access {
3535
Ok(unsafe { *$array.offset($i as isize) })
3636
}
3737
}};
38+
($i: expr, $lo: expr, $hi: expr, $array: expr, $output_id_type: expr) => {{
39+
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
40+
Err($crate::error::TskitError::IndexError {})
41+
} else {
42+
Ok($output_id_type(unsafe { *$array.offset($i as isize) }))
43+
}
44+
}};
3845
}
3946

4047
macro_rules! unsafe_tsk_ragged_column_access {
@@ -220,6 +227,58 @@ macro_rules! tree_array_slice {
220227
};
221228
}
222229

230+
macro_rules! impl_id_traits {
231+
($idtype: ty) => {
232+
impl From<$crate::tsk_id_t> for $idtype {
233+
fn from(value: $crate::tsk_id_t) -> Self {
234+
Self(value)
235+
}
236+
}
237+
238+
impl $crate::IdIsNull for $idtype {
239+
fn is_null(&self) -> bool {
240+
self.0 == $crate::TSK_NULL
241+
}
242+
}
243+
244+
impl From<$idtype> for usize {
245+
fn from(value: $idtype) -> Self {
246+
value.0 as usize
247+
}
248+
}
249+
250+
impl From<$idtype> for $crate::tsk_id_t {
251+
fn from(value: $idtype) -> Self {
252+
value.0
253+
}
254+
}
255+
256+
impl PartialEq<$crate::tsk_id_t> for $idtype {
257+
fn eq(&self, other: &$crate::tsk_id_t) -> bool {
258+
self.0 == *other
259+
}
260+
}
261+
262+
impl PartialEq<$idtype> for $crate::tsk_id_t {
263+
fn eq(&self, other: &$idtype) -> bool {
264+
*self == other.0
265+
}
266+
}
267+
268+
impl PartialOrd<$crate::tsk_id_t> for $idtype {
269+
fn partial_cmp(&self, other: &$crate::tsk_id_t) -> Option<std::cmp::Ordering> {
270+
self.0.partial_cmp(other)
271+
}
272+
}
273+
274+
impl PartialOrd<$idtype> for $crate::tsk_id_t {
275+
fn partial_cmp(&self, other: &$idtype) -> Option<std::cmp::Ordering> {
276+
self.partial_cmp(&other.0)
277+
}
278+
}
279+
};
280+
}
281+
223282
/// Convenience macro to handle implementing
224283
/// [`crate::metadata::MetadataRoundtrip`]
225284
#[macro_export]

src/lib.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,56 @@ pub use bindings::tsk_flags_t;
9494
pub use bindings::tsk_id_t;
9595
pub use bindings::tsk_size_t;
9696

97+
/// A node ID
98+
///
99+
/// This is an integer referring to a row of a [``NodeTable``].
100+
/// The underlying type is [``tsk_id_t``].
101+
///
102+
/// # Examples
103+
///
104+
/// These examples illustrate using this type as something "integer-like".
105+
///
106+
/// ```
107+
/// use tskit::NodeId;
108+
/// use tskit::tsk_id_t;
109+
///
110+
/// let x: tsk_id_t = 1;
111+
/// let y: NodeId = NodeId::from(x);
112+
/// assert_eq!(x, y);
113+
/// assert_eq!(y, x);
114+
///
115+
/// assert!(y < x + 1);
116+
/// assert!(y <= x);
117+
/// assert!(x + 1 > y);
118+
/// assert!(x + 1 >= y);
119+
///
120+
/// let z: NodeId = NodeId::from(x);
121+
/// assert_eq!(y, z);
122+
/// ```
123+
///
124+
/// It is also possible to write functions accepting both the `NodeId`
125+
/// and `tsk_id_t`:
126+
///
127+
/// ```
128+
/// use tskit::NodeId;
129+
/// use tskit::tsk_id_t;
130+
///
131+
/// fn interesting<N: Into<NodeId>>(x: N) -> NodeId {
132+
/// x.into()
133+
/// }
134+
///
135+
/// let x: tsk_id_t = 0;
136+
/// assert_eq!(interesting(x), x);
137+
/// let x: NodeId = NodeId::from(0);
138+
/// assert_eq!(interesting(x), x);
139+
/// ```
140+
///
141+
#[repr(transparent)]
142+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
143+
pub struct NodeId(tsk_id_t);
144+
145+
impl_id_traits!(NodeId);
146+
97147
// tskit defines this via a type cast
98148
// in a macro. bindgen thus misses it.
99149
// See bindgen issue 316.
@@ -110,6 +160,7 @@ pub use node_table::{NodeTable, NodeTableRow};
110160
pub use population_table::{PopulationTable, PopulationTableRow};
111161
pub use site_table::{SiteTable, SiteTableRow};
112162
pub use table_collection::TableCollection;
163+
pub use traits::IdIsNull;
113164
pub use traits::NodeListGenerator;
114165
pub use traits::TableAccess;
115166
pub use traits::TskitTypeAccess;

src/mutation_table.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use crate::bindings as ll_bindings;
22
use crate::metadata;
3+
use crate::NodeId;
34
use crate::{tsk_id_t, tsk_size_t, TskitError};
45

56
/// Row of a [`MutationTable`]
67
pub struct MutationTableRow {
78
pub id: tsk_id_t,
89
pub site: tsk_id_t,
9-
pub node: tsk_id_t,
10+
pub node: NodeId,
1011
pub parent: tsk_id_t,
1112
pub time: f64,
1213
pub derived_state: Option<Vec<u8>>,
@@ -99,8 +100,8 @@ impl<'a> MutationTable<'a> {
99100
///
100101
/// Will return [``IndexError``](crate::TskitError::IndexError)
101102
/// if ``row`` is out of range.
102-
pub fn node(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
103-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node)
103+
pub fn node(&'a self, row: tsk_id_t) -> Result<NodeId, TskitError> {
104+
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node, NodeId)
104105
}
105106

106107
/// Return the ``parent`` value from row ``row`` of the table.

src/node_table.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use crate::bindings as ll_bindings;
22
use crate::metadata;
3+
use crate::NodeId;
34
use crate::{tsk_flags_t, tsk_id_t, TskitError};
45

56
/// Row of a [`NodeTable`]
67
pub struct NodeTableRow {
7-
pub id: tsk_id_t,
8+
pub id: NodeId,
89
pub time: f64,
910
pub flags: tsk_flags_t,
1011
pub population: tsk_id_t,
@@ -26,7 +27,7 @@ impl PartialEq for NodeTableRow {
2627
fn make_node_table_row(table: &NodeTable, pos: tsk_id_t) -> Option<NodeTableRow> {
2728
if pos < table.num_rows() as tsk_id_t {
2829
Some(NodeTableRow {
29-
id: pos,
30+
id: pos.into(),
3031
time: table.time(pos).unwrap(),
3132
flags: table.flags(pos).unwrap(),
3233
population: table.population(pos).unwrap(),
@@ -86,8 +87,8 @@ impl<'a> NodeTable<'a> {
8687
///
8788
/// Will return [``IndexError``](crate::TskitError::IndexError)
8889
/// if ``row`` is out of range.
89-
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
90-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
90+
pub fn time<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<f64, TskitError> {
91+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
9192
}
9293

9394
/// Return the ``flags`` value from row ``row`` of the table.
@@ -96,8 +97,8 @@ impl<'a> NodeTable<'a> {
9697
///
9798
/// Will return [``IndexError``](crate::TskitError::IndexError)
9899
/// if ``row`` is out of range.
99-
pub fn flags(&'a self, row: tsk_id_t) -> Result<tsk_flags_t, TskitError> {
100-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.flags)
100+
pub fn flags<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_flags_t, TskitError> {
101+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.flags)
101102
}
102103

103104
/// Mutable access to node flags.
@@ -116,8 +117,8 @@ impl<'a> NodeTable<'a> {
116117
///
117118
/// Will return [``IndexError``](crate::TskitError::IndexError)
118119
/// if ``row`` is out of range.
119-
pub fn population(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
120-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.population)
120+
pub fn population<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
121+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.population)
121122
}
122123

123124
/// Return the ``population`` value from row ``row`` of the table.
@@ -126,7 +127,7 @@ impl<'a> NodeTable<'a> {
126127
///
127128
/// Will return [``IndexError``](crate::TskitError::IndexError)
128129
/// if ``row`` is out of range.
129-
pub fn deme(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
130+
pub fn deme<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
130131
self.population(row)
131132
}
132133

@@ -136,8 +137,8 @@ impl<'a> NodeTable<'a> {
136137
///
137138
/// Will return [``IndexError``](crate::TskitError::IndexError)
138139
/// if ``row`` is out of range.
139-
pub fn individual(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
140-
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.individual)
140+
pub fn individual<N: Into<NodeId> + Copy>(&'a self, row: N) -> Result<tsk_id_t, TskitError> {
141+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.individual)
141142
}
142143

143144
pub fn metadata<T: metadata::MetadataRoundtrip>(
@@ -163,15 +164,15 @@ impl<'a> NodeTable<'a> {
163164
/// # Errors
164165
///
165166
/// [`TskitError::IndexError`] if `r` is out of range.
166-
pub fn row(&self, r: tsk_id_t) -> Result<NodeTableRow, TskitError> {
167-
table_row_access!(r, self, make_node_table_row)
167+
pub fn row<N: Into<NodeId> + Copy>(&self, r: N) -> Result<NodeTableRow, TskitError> {
168+
table_row_access!(r.into().0, self, make_node_table_row)
168169
}
169170

170171
/// Obtain a vector containing the indexes ("ids")
171172
/// of all nodes for which [`crate::TSK_NODE_IS_SAMPLE`]
172173
/// is `true`.
173-
pub fn samples_as_vector(&self) -> Vec<tsk_id_t> {
174-
let mut samples: Vec<tsk_id_t> = vec![];
174+
pub fn samples_as_vector(&self) -> Vec<NodeId> {
175+
let mut samples: Vec<NodeId> = vec![];
175176
for row in self.iter() {
176177
if row.flags & crate::TSK_NODE_IS_SAMPLE > 0 {
177178
samples.push(row.id);
@@ -185,8 +186,8 @@ impl<'a> NodeTable<'a> {
185186
pub fn create_node_id_vector(
186187
&self,
187188
mut f: impl FnMut(&crate::NodeTableRow) -> bool,
188-
) -> Vec<tsk_id_t> {
189-
let mut samples: Vec<tsk_id_t> = vec![];
189+
) -> Vec<NodeId> {
190+
let mut samples: Vec<NodeId> = vec![];
190191
for row in self.iter() {
191192
if f(&row) {
192193
samples.push(row.id);

0 commit comments

Comments
 (0)