Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change syntax of azip! to be similar to for loops #626

Merged
merged 1 commit into from
Sep 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ fn add_2d_zip_alloc(bench: &mut test::Bencher) {
let b = Array::<i32, _>::zeros((ADD2DSZ, ADD2DSZ));
bench.iter(|| unsafe {
let mut c = Array::uninitialized(a.dim());
azip!(a, b, mut c in { *c = a + b });
azip!((&a in &a, &b in &b, c in &mut c) *c = a + b);
c
});
}
Expand Down
4 changes: 2 additions & 2 deletions benches/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn chunk2x2_iter_sum(bench: &mut Bencher) {
let chunksz = (2, 2);
let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim());
bench.iter(|| {
azip!(ref a (a.exact_chunks(chunksz)), mut sum in {
azip!((a in a.exact_chunks(chunksz), sum in &mut sum) {
*sum = a.iter().sum::<f32>();
});
});
Expand All @@ -24,7 +24,7 @@ fn chunk2x2_sum(bench: &mut Bencher) {
let chunksz = (2, 2);
let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim());
bench.iter(|| {
azip!(ref a (a.exact_chunks(chunksz)), mut sum in {
azip!((a in a.exact_chunks(chunksz), sum in &mut sum) {
*sum = a.sum();
});
});
Expand Down
4 changes: 2 additions & 2 deletions benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ fn sum_3_azip(bench: &mut Bencher) {
let c = vec![1; ZIPSZ];
bench.iter(|| {
let mut s = 0;
azip!(a, b, c in {
azip!((&a in &a, &b in &b, &c in &c) {
s += a + b + c;
});
s
Expand All @@ -182,7 +182,7 @@ fn vector_sum_3_azip(bench: &mut Bencher) {
let b = vec![1.; ZIPSZ];
let mut c = vec![1.; ZIPSZ];
bench.iter(|| {
azip!(a, b, mut c in {
azip!((&a in &a, &b in &b, c in &mut c) {
*c += a + b;
});
});
Expand Down
2 changes: 1 addition & 1 deletion benches/par_rayon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn add(bench: &mut Bencher) {
let c = Array2::<f64>::zeros((ADDN, ADDN));
let d = Array2::<f64>::zeros((ADDN, ADDN));
bench.iter(|| {
azip!(mut a, b, c, d in {
azip!((a in &mut a, &b in &b, &c in &c, &d in &d) {
*a += b.exp() + c.exp() + d.exp();
});
});
Expand Down
8 changes: 4 additions & 4 deletions examples/zip_many.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ fn main() {

{
let a = a.view_mut().reversed_axes();
azip!(mut a (a), b (b.t()) in { *a = b });
azip!((a in a, &b in b.t()) *a = b);
}
assert_eq!(a, b);

azip!(mut a, b, c in { *a = b + c; });
azip!((a in &mut a, &b in &b, &c in c) *a = b + c);
assert_eq!(a, &b + &c);

// sum of each row
let ax = Axis(0);
let mut sums = Array::zeros(a.len_of(ax));
azip!(mut sums, ref a (a.axis_iter(ax)) in { *sums = a.sum() });
azip!((s in &mut sums, a in a.axis_iter(ax)) *s = a.sum());

// sum of each chunk
let chunk_sz = (2, 2);
let nchunks = (n / chunk_sz.0, n / chunk_sz.1);
let mut sums = Array::zeros(nchunks);
azip!(mut sums, ref a (a.exact_chunks(chunk_sz)) in { *sums = a.sum() });
azip!((s in &mut sums, a in a.exact_chunks(chunk_sz)) *s = a.sum());

// Let's imagine we split to parallelize
{
Expand Down
6 changes: 3 additions & 3 deletions src/doc/ndarray_for_numpy_users/coord_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
//! let bunge = Array2::<f64>::ones((3, nelems));
//!
//! let mut rmat = Array::zeros((3, 3, nelems).f());
//! azip!(mut rmat (rmat.axis_iter_mut(Axis(2))), ref bunge (bunge.axis_iter(Axis(1))) in {
//! azip!((mut rmat in rmat.axis_iter_mut(Axis(2)), bunge in bunge.axis_iter(Axis(1))) {
//! let s1 = bunge[0].sin();
//! let c1 = bunge[0].cos();
//! let s2 = bunge[1].sin();
Expand All @@ -129,8 +129,8 @@
//! let eye2d = Array2::<f64>::eye(3);
//!
//! let mut rotated = Array3::<f64>::zeros((3, 3, nelems).f());
//! azip!(mut rotated (rotated.axis_iter_mut(Axis(2)), rmat (rmat.axis_iter(Axis(2)))) in {
//! rotated.assign({ &rmat.dot(&eye2d) });
//! azip!((mut rotated in rotated.axis_iter_mut(Axis(2)), rmat in rmat.axis_iter(Axis(2))) {
//! rotated.assign(&rmat.dot(&eye2d));
//! });
//! }
//! ```
2 changes: 1 addition & 1 deletion src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ where
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
for (i, subview) in self.axis_iter(axis).enumerate() {
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
azip!(mut mean, mut sum_sq, x (subview) in {
azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
let delta = x - *mean;
*mean = *mean + delta / count;
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
Expand Down
124 changes: 45 additions & 79 deletions src/zip/zipmacro.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#[macro_export]
/// Array zip macro: lock step function application across several arrays and
/// producers.
///
Expand All @@ -7,41 +6,32 @@
/// This example:
///
/// ```rust,ignore
/// azip!(mut a, b, c in { *a = b + c })
/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c);
/// ```
///
/// Is equivalent to:
///
/// ```rust,ignore
/// Zip::from(&mut a).and(&b).and(&c).apply(|a, &b, &c| {
/// *a = b + c;
/// *a = b + c
/// });
///
/// ```
///
/// Explanation of the shorthand for captures:
///
/// + `mut a`: the producer is `&mut a` and the variable pattern is `mut a`.
/// + `b`: the producer is `&b` and the variable pattern is `&b` (same for `c`).
/// The syntax is either
///
/// The syntax is `azip!(` *[* `index` *pattern* `,`*] capture [*`,` *capture [*`,` *...] ]* `in {` *expression* `})`
/// where the captures are a sequence of pattern-like items that indicate which
/// arrays are used for the zip. The *expression* is evaluated elementwise,
/// with the value of an element from each producer in their respective variable.
/// `azip!((` *pat* `in` *expr* `,` *[* *pat* `in` *expr* `,` ... *]* `)` *body_expr* `)`
///
/// More capture rules:
/// or, to use `Zip::indexed` instead of `Zip::from`,
///
/// + `ref c`: the producer is `&c` and the variable pattern is `c`.
/// + `mut a (expr)`: the producer is `expr` and the variable pattern is `mut a`.
/// + `b (expr)`: the producer is `expr` and the variable pattern is `&b`.
/// + `ref c (expr)`: the producer is `expr` and the variable pattern is `c`.
/// `azip!((index` *pat* `,` *pat* `in` *expr* `,` *[* *pat* `in` *expr* `,` ... *]* `)` *body_expr* `)`
///
/// Special rule:
///
/// + `index i`: Use `Zip::indexed` instead. `i` is a pattern -- it can be
/// a single variable name or something else that pattern matches the index.
/// This rule must be the first if it is used, and it must be followed by
/// at least one other rule.
/// The *expr* are expressions whose types must implement `IntoNdProducer`, the
/// *pat* are the patterns of the parameters to the closure called by
/// `Zip::apply`, and *body_expr* is the body of the closure called by
/// `Zip::apply`. You can think of each *pat* `in` *expr* as being analogous to
/// the `pat in expr` of a normal loop `for pat in expr { statements }`: a
/// pattern, followed by `in`, followed by an expression that implements
/// `IntoNdProducer` (analogous to `IntoIterator` for a `for` loop).
///
/// **Panics** if any of the arrays are not of the same shape.
///
Expand All @@ -68,12 +58,12 @@
///
/// // Example 1: Compute a simple ternary operation:
/// // elementwise addition of b and c, stored in a
/// azip!(mut a, b, c in { *a = b + c });
/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c);
///
/// assert_eq!(a, &b + &c);
///
/// // Example 2: azip!() with index
/// azip!(index (i, j), b, c in {
/// azip!((index (i, j), &b in &b, &c in &c) {
/// a[[i, j]] = b - c;
/// });
///
Expand All @@ -87,80 +77,56 @@
/// assert_eq!(a, &b * &c);
///
///
/// // Since this function borrows its inputs, captures must use the x (x) pattern
/// // to avoid the macro's default rule that autorefs the producer.
/// // Since this function borrows its inputs, the `IntoNdProducer`
/// // expressions don't need to explicitly include `&mut` or `&`.
/// fn borrow_multiply(a: &mut M, b: &M, c: &M) {
/// azip!(mut a (a), b (b), c (c) in { *a = b * c });
/// azip!((a in a, &b in b, &c in c) *a = b * c);
/// }
///
///
/// // Example 4: using azip!() with a `ref` rule
/// // Example 4: using azip!() without dereference in pattern.
/// //
/// // Create a new array `totals` with one entry per row of `a`.
/// // Use azip to traverse the rows of `a` and assign to the corresponding
/// // entry in `totals` with the sum across each row.
/// //
/// // The row is an array view; use the 'ref' rule on the row, to avoid the
/// // default which is to dereference the produced item.
/// let mut totals = Array1::zeros(a.nrows());
///
/// azip!(mut totals, ref row (a.genrows()) in {
/// *totals = row.sum();
/// });
/// // The row is an array view; it doesn't need to be dereferenced.
/// let mut totals = Array1::zeros(a.rows());
/// azip!((totals in &mut totals, row in a.genrows()) *totals = row.sum());
///
/// // Check the result against the built in `.sum_axis()` along axis 1.
/// assert_eq!(totals, a.sum_axis(Axis(1)));
/// }
///
/// ```
#[macro_export]
macro_rules! azip {
// Build Zip Rule (index)
(@parse [index => $a:expr, $($aa:expr,)*] $t1:tt in $t2:tt) => {
$crate::azip!(@finish ($crate::Zip::indexed($a)) [$($aa,)*] $t1 in $t2)
};
// Build Zip Rule (no index)
(@parse [$a:expr, $($aa:expr,)*] $t1:tt in $t2:tt) => {
$crate::azip!(@finish ($crate::Zip::from($a)) [$($aa,)*] $t1 in $t2)
};
// Build Finish Rule (both)
(@finish ($z:expr) [$($aa:expr,)*] [$($p:pat,)+] in { $($t:tt)*}) => {
#[allow(unused_mut)]
($z)
$(
.and($aa)
)*
.apply(|$($p),+| {
$($t)*
})
};
// parsing stack: [expressions] [patterns] (one per operand)
// index uses empty [] -- must be first
(@parse [] [] index $i:pat, $($t:tt)*) => {
$crate::azip!(@parse [index =>] [$i,] $($t)*);
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] mut $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* mut $x,] $($t)*);
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] mut $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &mut $x,] [$($pats)* mut $x,] $($t)*);
// Indexed with a single producer and no trailing comma.
((index $index:pat, $first_pat:pat in $first_prod:expr) $body:expr) => {
$crate::Zip::indexed($first_prod).apply(|$index, $first_pat| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] , $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)*] [$($pats)*] $($t)*);
// Indexed with more than one producer and no trailing comma.
((index $index:pat, $first_pat:pat in $first_prod:expr, $($pat:pat in $prod:expr),*) $body:expr) => {
$crate::Zip::indexed($first_prod)
$(.and($prod))*
.apply(|$index, $first_pat, $($pat),*| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] ref $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* $x,] $($t)*);
// Indexed with trailing comma.
((index $index:pat, $($pat:pat in $prod:expr),+,) $body:expr) => {
azip!((index $index, $($pat in $prod),+) $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] ref $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &$x,] [$($pats)* $x,] $($t)*);
// Unindexed with a single producer and no trailing comma.
(($first_pat:pat in $first_prod:expr) $body:expr) => {
$crate::Zip::from($first_prod).apply(|$first_pat| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* &$x,] $($t)*);
// Unindexed with more than one producer and no trailing comma.
(($first_pat:pat in $first_prod:expr, $($pat:pat in $prod:expr),*) $body:expr) => {
$crate::Zip::from($first_prod)
$(.and($prod))*
.apply(|$first_pat, $($pat),*| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &$x,] [$($pats)* &$x,] $($t)*);
// Unindexed with trailing comma.
(($($pat:pat in $prod:expr),+,) $body:expr) => {
azip!(($($pat in $prod),+) $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $($t:tt)*) => { };
($($t:tt)*) => {
$crate::azip!(@parse [] [] $($t)*);
}
}
18 changes: 9 additions & 9 deletions tests/azip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use std::mem::swap;
fn test_azip1() {
let mut a = Array::zeros(62);
let mut x = 0;
azip!(mut a in { *a = x; x += 1; });
azip!((a in &mut a) { *a = x; x += 1; });
assert_equal(cloned(&a), 0..a.len());
}

#[test]
fn test_azip2() {
let mut a = Array::zeros((5, 7));
let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. / (i + 2 * j) as f32);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
assert_eq!(a, b);
}

Expand All @@ -35,7 +35,7 @@ fn test_azip2_1() {
let mut a = Array::zeros((5, 7));
let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32);
let b = b.slice(s![..;-1, 3..]);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
assert_eq!(a, b);
}

Expand All @@ -44,7 +44,7 @@ fn test_azip2_3() {
let mut b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32);
let mut c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32));
let a = b.clone();
azip!(mut b, mut c in { swap(b, c) });
azip!((b in &mut b, c in &mut c) swap(b, c));
assert_eq!(a, c);
assert!(a != b);
}
Expand All @@ -58,7 +58,7 @@ fn test_azip2_sum() {
for i in 0..2 {
let ax = Axis(i);
let mut b = Array::zeros(c.len_of(ax));
azip!(mut b, ref c (c.axis_iter(ax)) in { *b = c.sum() });
azip!((b in &mut b, c in c.axis_iter(ax)) *b = c.sum());
assert_abs_diff_eq!(b, c.sum_axis(Axis(1 - i)), epsilon = 1e-6);
}
}
Expand All @@ -75,7 +75,7 @@ fn test_azip3_slices() {
*elt = i as f32;
}

azip!(mut a (&mut a[..]), b (&b[..]), mut c (&mut c[..]) in {
azip!((a in &mut a[..], b in &b[..], c in &mut c[..]) {
*a += b / 10.;
*c = a.sin();
});
Expand Down Expand Up @@ -115,7 +115,7 @@ fn test_zip_dim_mismatch_1() {
let mut d = a.raw_dim();
d[0] += 1;
let b = Array::from_shape_fn(d, |(i, j)| 1. / (i + 2 * j) as f32);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
}

// Test that Zip handles memory layout correctly for
Expand All @@ -136,7 +136,7 @@ fn test_contiguous_but_not_c_or_f() {
let correct_012 = a[[0, 1, 2]] + b[[0, 1, 2]];

let mut ans = Array::zeros(a.dim().f());
azip!(mut ans, a, b in { *ans = a + b });
azip!((ans in &mut ans, &a in &a, &b in &b) *ans = a + b);
println!("{:?}", a);
println!("{:?}", b);
println!("{:?}", ans);
Expand Down Expand Up @@ -200,7 +200,7 @@ fn test_indices_2() {
}

let mut count = 0;
azip!(index i, a1 in {
azip!((index i, &a1 in &a1) {
count += 1;
assert_eq!(a1, i);
});
Expand Down