Skip to content

Commit

Permalink
Add Batched Select for devices and tensor_ops (#182)
Browse files Browse the repository at this point in the history
* device select now uses traits

* Rework devices/select. Adding Broadcasted select

* Making DeviceSelect take Input & Output and index is associated

* Temp commit of select function

* Improving modes of select

* Removing R from DeviceSelect

* tensor_ops select working

* Going back to Select1 trait for tensor_ops impls

* Adding SelectBatchAx0

* Adding test for select_batch

* Fixing select test

* Fixing doctest for select

* Filling out docstrings for select

* Adding comments
  • Loading branch information
coreylowman authored Sep 18, 2022
1 parent fad9123 commit 133e2d6
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 114 deletions.
201 changes: 122 additions & 79 deletions src/devices/select.rs
Original file line number Diff line number Diff line change
@@ -1,107 +1,140 @@
//! Implementations of selecting either 1 or Z elements from an axis of an nd array.
//!
//! # Implementation Details
//! There are three cases to handle:
//! There are four cases to handle:
//!
//! ## Selecting 1 element from the 0th axis
//! ## Selecting 1 element from the 0th axis [select_modes::Index]
//!
//! Just index into input using the single index and assign to output.
//!
//! ## Selecting Z elements from the 0th axis
//! ## Selecting Z elements from the 0th axis [select_modes::Index]
//!
//! Just index into input for each index and assign to `output[z]`.
//!
//! ## Selecting either 1 or Z elements from a non-zero axis
//! ## Selecting either 1 or Z elements from a non-zero axis [select_modes::Recurse]
//!
//! In this case all three arrays (input, indices, and output) have the same size along the 0th axis.
//! Do a for loop over the 0th axis and recurse!
//!
//! ## Broadcasted select [select_modes::Broadcast]
//!
//! In this case only the indices & output are indexed. The input is broadcasted by
//! not indexing into it.

use super::{Cpu, ForEachElement};
use crate::arrays::CountElements;

/// Select values from `T` using `Indices` and producing `R` along a single `AXIS`.
pub trait SelectAlongAxis<T: CountElements, Indices, R: CountElements, const AXIS: isize> {
/// Marker types used to tell otherwise-ambiguous trait implementations apart.
/// Callers must name the mode that describes the kind of selection being done.
pub(crate) mod select_modes {
    /// Mode: perform the selection on the current axis.
    pub struct Index;

    /// Mode: step through the current axis element by element and apply
    /// mode `M` to each sub-array.
    pub struct Recurse<M>(std::marker::PhantomData<*const M>);

    /// Mode: step through the current axis of the indices only, reusing the
    /// whole input for every step, and apply mode `M` below that.
    pub struct Broadcast<M>(std::marker::PhantomData<*const M>);
}

use select_modes::{Broadcast, Index, Recurse};

/// Mode alias: select directly on the 0th axis.
pub(crate) type SelectAx0 = select_modes::Index;
/// Mode alias: recurse over one leading axis, then select (i.e. select on axis 1).
pub(crate) type SelectAx1 = select_modes::Recurse<SelectAx0>;
/// Mode alias: recurse over two leading axes, then select (i.e. select on axis 2).
pub(crate) type SelectAx2 = select_modes::Recurse<SelectAx1>;
/// Mode alias: recurse over three leading axes, then select (i.e. select on axis 3).
pub(crate) type SelectAx3 = select_modes::Recurse<SelectAx2>;
/// Mode alias: broadcast the input across the leading axis of the indices,
/// then select on the input's 0th axis for each entry.
pub(crate) type BSelectAx1 = select_modes::Broadcast<SelectAx0>;

/// Select values from `T` using indices `I`. `Mode` is used to disambiguate the impl.
pub trait DeviceSelect<T, I, Mode> {
type Result;

/// Equivalent to pseudocode `out = inp[indices]`
fn select_axis(inp: &T, indices: &Indices, out: &mut R);
fn select_axis(inp: &T, indices: &I, out: &mut Self::Result);

/// `inp[indices] += out`
fn select_add(inp: &mut T, indices: &Indices, out: &R);
fn select_add(inp: &mut T, indices: &I, out: &Self::Result);
}

macro_rules! select_01 {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, usize, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &usize, out: &mut $DstTy) {
// Select 1 element from 0th axis.
impl<T, const M: usize> DeviceSelect<[T; M], usize, Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = T;

fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) {
*out = inp[*indices];
}
fn select_add(inp: &mut $SrcTy, indices: &usize, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) {
Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b);
}
}
};
}

macro_rules! select_0z {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, [usize; Z], $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &[usize; Z], out: &mut $DstTy) {
// Select Z elements from 0th axis.
impl<T, const M: usize, const Z: usize> DeviceSelect<[T; M], [usize; Z], Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = [T; Z];

fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) {
for z in 0..Z {
out[z] = inp[indices[z]];
}
}
fn select_add(inp: &mut $SrcTy, indices: &[usize; Z], out: &$DstTy) {
fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) {
for z in 0..Z {
Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b);
}
}
}
};
}

macro_rules! select_nz {
($Axis:expr, $SrcTy:tt, $IndTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, $IndTy, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &$IndTy, out: &mut $DstTy) {
// Select elements from non-zero axis
impl<T, I, const M: usize, SubMode> DeviceSelect<[T; M], [I; M], Recurse<SubMode>> for Cpu
where
Self: DeviceSelect<T, I, SubMode>,
{
type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) {
for m in 0..M {
Self::select_axis(&inp[m], &indices[m], &mut out[m]);
}
}
fn select_add(inp: &mut $SrcTy, indices: &$IndTy, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) {
for m in 0..M {
Self::select_add(&mut inp[m], &indices[m], &out[m]);
}
}
}
};
}

// 1d
select_01!(-1, [f32; M], f32, { M });
select_0z!(-1, [f32; M], [f32; Z], {M, Z});

// 2d
select_01!(0, [[f32; N]; M], [f32; N], {M, N});
select_0z!(0, [[f32; N]; M], [[f32; N]; Z], {M, N, Z});
select_nz!(-1, [[f32; N]; M], [usize; M], [f32; M], {M, N});
select_nz!(-1, [[f32; N]; M], [[usize; Z]; M], [[f32; Z]; M], {M, N, Z});

// 3d
select_01!(0, [[[f32; O]; N]; M], [[f32; O]; N], {M, N, O});
select_0z!(0, [[[f32; O]; N]; M], [[[f32; O]; N]; Z], {M, N, O, Z});
select_nz!(1, [[[f32; O]; N]; M], [usize; M], [[f32; O]; M], {M, N, O});
select_nz!(1, [[[f32; O]; N]; M], [[usize; Z]; M], [[[f32; O]; Z]; M], {M, N, O, Z});
select_nz!(-1, [[[f32; O]; N]; M], [[usize; N]; M], [[f32; N]; M], {M, N, O});
select_nz!(-1, [[[f32; O]; N]; M], [[[usize; Z]; N]; M], [[[f32; Z]; N]; M], {M, N, O, Z});

// 4d
select_01!(0, [[[[f32; P]; O]; N]; M], [[[f32; P]; O]; N], {M, N, O, P});
select_0z!(0, [[[[f32; P]; O]; N]; M], [[[[f32; P]; O]; N]; Z], {M, N, O, P, Z});
select_nz!(1, [[[[f32; P]; O]; N]; M], [usize; M], [[[f32; P]; O]; M], {M, N, O, P});
select_nz!(1, [[[[f32; P]; O]; N]; M], [[usize; Z]; M], [[[[f32; P]; O]; Z]; M], {M, N, O, P, Z});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[usize; N]; M], [[[f32; P]; N]; M], {M, N, O, P});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[[usize; Z]; N]; M], [[[[f32; P]; Z]; N]; M], {M, N, O, P, Z});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[usize; O]; N]; M], [[[f32; O]; N]; M], {M, N, O, P});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[[usize; Z]; O]; N]; M], [[[[f32; Z]; O]; N]; M], {M, N, O, P, Z});
// Broadcasted select: the indices and the output share a leading axis of size `M`,
// while the input is never indexed at this level. The same `inp` is handed to the
// `SubMode` implementation once per entry along that axis.
impl<T, I, const M: usize, SubMode> DeviceSelect<T, [I; M], Broadcast<SubMode>> for Cpu
where
    Self: DeviceSelect<T, I, SubMode>,
{
    type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

    /// For every position `m`, computes `out[m] = inp[indices[m]]` via the `SubMode` impl.
    fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) {
        for (sub_indices, sub_out) in indices.iter().zip(out.iter_mut()) {
            <Self as DeviceSelect<T, I, SubMode>>::select_axis(inp, sub_indices, sub_out);
        }
    }

    /// For every position `m`, performs `inp[indices[m]] += out[m]` via the `SubMode` impl.
    /// All positions accumulate into the same `inp`.
    fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) {
        for (sub_indices, sub_out) in indices.iter().zip(out.iter()) {
            <Self as DeviceSelect<T, I, SubMode>>::select_add(&mut *inp, sub_indices, sub_out);
        }
    }
}

#[cfg(test)]
mod tests {
Expand All @@ -110,17 +143,17 @@ mod tests {

#[test]
fn test_select_1d_0() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
let a: [f32; 3] = [1.0, 2.0, 3.0];
let mut b: f32 = ZeroElements::ZEROS;
Cpu::select_axis(&a, &1, &mut b);
assert_eq!(b, 2.0);
}

#[test]
fn test_select_1d_0z() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
let a: [f32; 3] = [1.0f32, 2.0, 3.0];
let mut b: [f32; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]);
}

Expand All @@ -129,41 +162,51 @@ mod tests {
#[test]
fn test_select_2d_0() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
let mut b: [f32; 3] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, [1.0, 2.0, 3.0]);
}

#[test]
fn test_select_2d_0z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1], &mut b);
let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1], &mut b);
assert_eq!(b, [a[0], a[0], a[1]]);
}

#[test]
fn test_select_2d_1() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(&a, &[0, 1], &mut b);
let mut b: [f32; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 1], &mut b);
assert_eq!(b, [1.0, 5.0]);
}

#[test]
fn test_select_2d_1z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]);
}

#[test]
fn test_select_broadcast_2d() {
    // A single (2, 1) input is reused for each of the 4 rows of indices;
    // every row picks 3 entries from the input's 0th axis.
    let inp = [[1.0], [2.0]];
    let idx: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
    let mut out: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS;
    <Cpu as DeviceSelect<_, _, Broadcast<Index>>>::select_axis(&inp, &idx, &mut out);
    #[rustfmt::skip]
    let expected: [[[f32; 1]; 3]; 4] = [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]];
    assert_eq!(out, expected);
}

#[test]
fn test_select_add_2d() {
let mut a = [[0.0; 3]; 2];
let b = [[1.0, 3.0], [5.0, 5.0]];
let i = [[0, 2], [1, 1]];
Cpu::select_add(&mut a, &i, &b);
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_add(&mut a, &i, &b);
assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]);
}

Expand All @@ -177,40 +220,40 @@ mod tests {
#[test]
fn test_select_3d_0() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
let mut b: [[f32; 3]; 2] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, A_3D[0]);
}

#[test]
fn test_select_3d_0z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]);
}

#[test]
fn test_select_3d_1() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]);
}

#[test]
fn test_select_3d_1z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]);
}

#[test]
fn test_select_3d_2() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[1, 0], [0, 1], [0, 0], [1, 1]],
&mut b,
Expand All @@ -230,7 +273,7 @@ mod tests {
fn test_select_3d_2z() {
let a = A_3D;
let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]],
&mut b,
Expand Down
Loading

0 comments on commit 133e2d6

Please sign in to comment.