From 66ae1089bf14f77cc729bf5669cc0fbb079c2d11 Mon Sep 17 00:00:00 2001 From: Noah Stiltner Date: Sat, 8 Jun 2024 08:54:38 -0500 Subject: [PATCH] turned add64! into add_counter! that accepts a bool; now it passes tests --- chacha20/src/backends/neon.rs | 44 +++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/chacha20/src/backends/neon.rs b/chacha20/src/backends/neon.rs index e411a72..a4c9f62 100644 --- a/chacha20/src/backends/neon.rs +++ b/chacha20/src/backends/neon.rs @@ -99,12 +99,17 @@ impl ParBlocksSizeUser for Backend { type ParBlocksSize = U4; } -macro_rules! add64 { - ($a:expr, $b:expr) => { - vreinterpretq_u32_u64(vaddq_u64( - vreinterpretq_u64_u32($a), - vreinterpretq_u64_u32($b), - )) +/// Adds a counter row with either 32-bit or 64-bit addition +macro_rules! add_counter { + ($a:expr, $b:expr, $is_32_bit:expr) => { + if $is_32_bit { + vaddq_u32($a, $b) + } else { + vreinterpretq_u32_u64(vaddq_u64( + vreinterpretq_u64_u32($a), + vreinterpretq_u64_u32($b), + )) + } }; } @@ -124,7 +129,11 @@ impl StreamBackend for Backend { self.gen_par_ks_blocks(&mut par); *block = par[0]; unsafe { - self.state[3] = add64!(state3, vld1q_u32([1, 0, 0, 0].as_ptr())); + self.state[3] = add_counter!( + state3, + vld1q_u32([1, 0, 0, 0].as_ptr()), + V::USES_U32_COUNTER + ); } } @@ -137,19 +146,19 @@ impl StreamBackend for Backend { self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[0]), + add_counter!(self.state[3], self.ctrs[0], V::USES_U32_COUNTER), ], [ self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[1]), + add_counter!(self.state[3], self.ctrs[1], V::USES_U32_COUNTER), ], [ self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[2]), + add_counter!(self.state[3], self.ctrs[2], V::USES_U32_COUNTER), ], ]; @@ -178,7 +187,7 @@ impl StreamBackend for Backend { } add_assign_vec!( blocks[block][3], - add64!(self.state[3], self.ctrs[block - 1]) + 
+ add_counter!(self.state[3], self.ctrs[block - 1], V::USES_U32_COUNTER) ); // write @@ -190,7 +199,7 @@ impl StreamBackend for Backend { } } //self.state[3] = vaddq_u32(self.state[3], self.ctrs[3]); - self.state[3] = add64!(self.state[3], self.ctrs[3]); + self.state[3] = add_counter!(self.state[3], self.ctrs[3], V::USES_U32_COUNTER); } } } @@ -233,19 +242,19 @@ impl Backend { self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[0]), + add_counter!(self.state[3], self.ctrs[0], V::USES_U32_COUNTER), ], [ self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[1]), + add_counter!(self.state[3], self.ctrs[1], V::USES_U32_COUNTER), ], [ self.state[0], self.state[1], self.state[2], - add64!(self.state[3], self.ctrs[2]), + add_counter!(self.state[3], self.ctrs[2], V::USES_U32_COUNTER), ], ]; @@ -260,7 +269,8 @@ impl Backend { add_assign_vec!(blocks[block][state_row], self.state[state_row]); } if block > 0 { - blocks[block][3] = add64!(blocks[block][3], self.ctrs[block - 1]); + blocks[block][3] = + add_counter!(blocks[block][3], self.ctrs[block - 1], V::USES_U32_COUNTER); } // write blocks to buffer for state_row in 0..4 { @@ -271,7 +281,7 @@ impl Backend { } dest_ptr = dest_ptr.add(64); } - self.state[3] = add64!(self.state[3], self.ctrs[3]); + self.state[3] = add_counter!(self.state[3], self.ctrs[3], V::USES_U32_COUNTER); } }