Skip to content

feat: column slice getters for tables #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,38 @@ macro_rules! optional_container_comparison {
};
}

macro_rules! build_table_column_slice_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&self) -> &[$cast] {
// Caveat: num_rows is u64 but we need usize
// The conversion is fallible but unlikely.
let num_rows =
usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed");
let ptr = self.as_ref().$column as *const $cast;
// SAFETY: tables are initialzed, num rows comes
// from the C back end.
unsafe { std::slice::from_raw_parts(ptr, num_rows) }
}
};
}

macro_rules! build_table_column_slice_mut_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&mut self) -> &mut [$cast] {
// Caveat: num_rows is u64 but we need usize
// The conversion is fallible but unlikely.
let num_rows =
usize::try_from(self.num_rows()).expect("conversion of num_rows to usize failed");
let ptr = self.as_ref().$column as *mut $cast;
// SAFETY: tables are initialzed, num rows comes
// from the C back end.
unsafe { std::slice::from_raw_parts_mut(ptr, num_rows) }
}
};
}

#[cfg(test)]
mod test {
use crate::error::TskitError;
Expand Down
25 changes: 25 additions & 0 deletions src/edge_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,31 @@ impl EdgeTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice of [`f64`]
=> left, left_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice of [`f64`]
=> right, right_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice, NodeId);
build_table_column_slice_getter!(
/// Get the parent column as a slice of [`crate::bindings::tsk_id_t`]
=> parent, parent_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the child column as a slice
=> child, child_slice, NodeId);
build_table_column_slice_getter!(
/// Get the child column as a slice of [`crate::bindings::tsk_id_t`]
=> child, child_slice_raw, ll_bindings::tsk_id_t);
}

build_owned_table_type!(
Expand Down
7 changes: 7 additions & 0 deletions src/individual_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,13 @@ match tables.individuals().metadata::<MutationMetadata>(0.into())
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice, IndividualFlags);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice_raw, ll_bindings::tsk_flags_t);
}

build_owned_table_type!(
Expand Down
37 changes: 37 additions & 0 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,43 @@ impl MigrationTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice, Position);
build_table_column_slice_getter!(
/// Get the left column as a slice
=> left, left_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice, Position);
build_table_column_slice_getter!(
/// Get the right column as a slice
=> right, right_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice, NodeId);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the source column as a slice
=> source, source_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the source column as a slice
=> source, source_slice_raw, ll_bindings::tsk_id_t);
build_table_column_slice_getter!(
/// Get the dest column as a slice
=> dest, dest_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the dest column as a slice
=> dest, dest_slice_raw, ll_bindings::tsk_id_t);
}

build_owned_table_type!(
Expand Down
25 changes: 25 additions & 0 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,31 @@ impl MutationTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice, NodeId);
build_table_column_slice_getter!(
/// Get the node column as a slice
=> node, node_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the site column as a slice
=> site, site_slice, SiteId);
build_table_column_slice_getter!(
/// Get the site column as a slice
=> site, site_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice, MutationId);
build_table_column_slice_getter!(
/// Get the parent column as a slice
=> parent, parent_slice_raw, crate::tsk_id_t);
}

build_owned_table_type!(
Expand Down
37 changes: 37 additions & 0 deletions src/node_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,43 @@ impl NodeTable {
.map(|row| row.id)
.collect::<Vec<_>>()
}

build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice, Time);
build_table_column_slice_getter!(
/// Get the time column as a slice
=> time, time_slice_raw, f64);
build_table_column_slice_mut_getter!(
/// Get the time column as a mutable slice
=> time, time_slice_mut, Time);
build_table_column_slice_mut_getter!(
/// Get the time column as a mutable slice
=> time, time_slice_raw_mut, f64);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice, NodeFlags);
build_table_column_slice_getter!(
/// Get the flags column as a slice
=> flags, flags_slice_raw, ll_bindings::tsk_flags_t);
build_table_column_slice_mut_getter!(
/// Get the flags column as a mutable slice
=> flags, flags_slice_mut, NodeFlags);
build_table_column_slice_mut_getter!(
/// Get the flags column as a mutable slice
=> flags, flags_slice_raw_mut, ll_bindings::tsk_flags_t);
build_table_column_slice_getter!(
/// Get the individual column as a slice
=> individual, individual_slice, IndividualId);
build_table_column_slice_getter!(
/// Get the individual column as a slice
=> individual, individual_slice_raw, crate::tsk_id_t);
build_table_column_slice_getter!(
/// Get the population column as a slice
=> population, population_slice, PopulationId);
build_table_column_slice_getter!(
/// Get the population column as a slice
=> population, population_slice_raw, crate::tsk_id_t);
}

build_owned_table_type!(
Expand Down
7 changes: 7 additions & 0 deletions src/site_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,13 @@ impl SiteTable {
};
Some(view)
}

build_table_column_slice_getter!(
/// Get the position column as a slice
=> position, position_slice, Position);
build_table_column_slice_getter!(
/// Get the position column as a slice
=> position, position_slice_raw, f64);
}

build_owned_table_type!(
Expand Down
122 changes: 115 additions & 7 deletions tests/test_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,58 @@ mod test_adding_rows_without_metadata {
Err(e) => panic!("Err from tables.{}: {:?}", stringify!(adder), e)
}
assert_eq!(tables.$table().iter().count(), 2);
tables
}
}};
}

macro_rules! compare_column_to_raw_column {
($table: expr, $col: ident, $raw: ident) => {
assert_eq!(
$table.$col().len(),
usize::try_from($table.num_rows()).unwrap()
);
assert_eq!(
$table.$raw().len(),
usize::try_from($table.num_rows()).unwrap()
);
assert!($table
.$col()
.iter()
.zip($table.$raw().iter())
.all(|(a, b)| a == b))
};
}

macro_rules! compare_column_to_row {
($table: expr, $col: ident, $target: ident) => {
assert!($table
.$col()
.iter()
.zip($table.iter())
.all(|(c, r)| c == &r.$target));
};
}

// NOTE: all functions arguments for adding rows are Into<T>
// where T is one of our new types.
// Further, functions taking multiple inputs of T are defined
// as X: Into<T>, X2: Into<T>, etc., allowing mix-and-match.

#[test]
fn test_adding_edge() {
add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child
{
let tables = add_row_without_metadata!(edges, add_edge, 0.1, 0.5, 0, 1); // left, right, parent, child
compare_column_to_raw_column!(tables.edges(), left_slice, left_slice_raw);
compare_column_to_raw_column!(tables.edges(), right_slice, right_slice_raw);
compare_column_to_raw_column!(tables.edges(), parent_slice, parent_slice_raw);
compare_column_to_raw_column!(tables.edges(), child_slice, child_slice_raw);

compare_column_to_row!(tables.edges(), left_slice, left);
compare_column_to_row!(tables.edges(), right_slice, right);
compare_column_to_row!(tables.edges(), parent_slice, parent);
compare_column_to_row!(tables.edges(), child_slice, child);
}
add_row_without_metadata!(edges, add_edge, tskit::Position::from(0.1), 0.5, 0, 1); // left, right, parent, child
add_row_without_metadata!(edges, add_edge, 0.1, tskit::Position::from(0.5), 0, 1); // left, right, parent, child
add_row_without_metadata!(
Expand All @@ -105,8 +145,30 @@ mod test_adding_rows_without_metadata {

#[test]
fn test_adding_node() {
add_row_without_metadata!(nodes, add_node, 0, 0.1, -1, -1); // flags, time, population,
// individual
{
let tables =
add_row_without_metadata!(nodes, add_node, tskit::TSK_NODE_IS_SAMPLE, 0.1, -1, -1); // flags, time, population,
// individual
assert!(tables
.nodes()
.flags_slice()
.iter()
.zip(tables.nodes().flags_slice_raw().iter())
.all(|(a, b)| a.bits() == *b));
compare_column_to_raw_column!(tables.nodes(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.nodes(), population_slice, population_slice_raw);
compare_column_to_raw_column!(tables.nodes(), individual_slice, individual_slice_raw);

assert!(tables
.nodes()
.flags_slice()
.iter()
.zip(tables.nodes().iter())
.all(|(c, r)| c == &r.flags));
compare_column_to_row!(tables.nodes(), time_slice, time);
compare_column_to_row!(tables.nodes(), population_slice, population);
compare_column_to_row!(tables.nodes(), individual_slice, individual);
}
add_row_without_metadata!(
nodes,
add_node,
Expand All @@ -120,7 +182,11 @@ mod test_adding_rows_without_metadata {
#[test]
fn test_adding_site() {
// No ancestral state
add_row_without_metadata!(sites, add_site, 2. / 3., None);
{
let tables = add_row_without_metadata!(sites, add_site, 2. / 3., None);
compare_column_to_raw_column!(tables.sites(), position_slice, position_slice_raw);
compare_column_to_row!(tables.sites(), position_slice, position);
}
add_row_without_metadata!(sites, add_site, tskit::Position::from(2. / 3.), None);
add_row_without_metadata!(sites, add_site, 2. / 3., Some(&[1_u8]));
add_row_without_metadata!(
Expand All @@ -136,14 +202,40 @@ mod test_adding_rows_without_metadata {
// site, node, parent mutation, time, derived_state
// Each value is a different Into<T> so we skip doing
// permutations
add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None);
{
let tables = add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, None);
compare_column_to_raw_column!(tables.mutations(), node_slice, node_slice_raw);
compare_column_to_raw_column!(tables.mutations(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.mutations(), site_slice, site_slice_raw);
compare_column_to_raw_column!(tables.mutations(), parent_slice, parent_slice_raw);

compare_column_to_row!(tables.mutations(), node_slice, node);
compare_column_to_row!(tables.mutations(), time_slice, time);
compare_column_to_row!(tables.mutations(), site_slice, site);
compare_column_to_row!(tables.mutations(), parent_slice, parent);
}

add_row_without_metadata!(mutations, add_mutation, 0, 0, -1, 0.0, Some(&[23_u8]));
}

#[test]
fn test_adding_individual() {
// flags, location, parents
add_row_without_metadata!(individuals, add_individual, 0, None, None);
{
let tables = add_row_without_metadata!(individuals, add_individual, 0, None, None);
assert!(tables
.individuals()
.flags_slice()
.iter()
.zip(tables.individuals().flags_slice_raw().iter())
.all(|(a, b)| a.bits() == *b));
assert!(tables
.individuals()
.flags_slice()
.iter()
.zip(tables.individuals().iter())
.all(|(c, r)| c == &r.flags));
}
add_row_without_metadata!(
individuals,
add_individual,
Expand Down Expand Up @@ -179,7 +271,23 @@ mod test_adding_rows_without_metadata {
fn test_adding_migration() {
// migration table
// (left, right), node, (source, dest), time
add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0);
{
let tables =
add_row_without_metadata!(migrations, add_migration, (0., 1.), 0, (0, 1), 0.0);
compare_column_to_raw_column!(tables.migrations(), left_slice, left_slice_raw);
compare_column_to_raw_column!(tables.migrations(), right_slice, right_slice_raw);
compare_column_to_raw_column!(tables.migrations(), node_slice, node_slice_raw);
compare_column_to_raw_column!(tables.migrations(), time_slice, time_slice_raw);
compare_column_to_raw_column!(tables.migrations(), source_slice, source_slice_raw);
compare_column_to_raw_column!(tables.migrations(), dest_slice, dest_slice_raw);

compare_column_to_row!(tables.migrations(), left_slice, left);
compare_column_to_row!(tables.migrations(), right_slice, right);
compare_column_to_row!(tables.migrations(), node_slice, node);
compare_column_to_row!(tables.migrations(), time_slice, time);
compare_column_to_row!(tables.migrations(), source_slice, source);
compare_column_to_row!(tables.migrations(), dest_slice, dest);
}
add_row_without_metadata!(
migrations,
add_migration,
Expand Down