From b2312fa6965159803b06bb8eb38209f5e7fdeb39 Mon Sep 17 00:00:00 2001 From: Artyom Pavlov Date: Tue, 27 Aug 2024 12:56:21 +0300 Subject: [PATCH] sha2: improve RISC-V Zknh backends (#617) Annoyingly, RISC-V is really inconvenient when we have to deal with misaligned loads/stores. LLVM by default generates very [inefficient code](https://rust.godbolt.org/z/3Yaj4fq5o) which loads every byte separately and combines them into a 32/64 bit integer. The `ld` instruction "may" support misaligned loads and for Linux user-space it's even [guaranteed](https://www.kernel.org/doc/html/v6.10/arch/riscv/uabi.html#misaligned-accesses), but it can be (and IIUC often in practice is) "extremely slow", so we should not rely on it while writing performant code. After asking around, it looks like this mess is here to stay, so we have no choice but to work around it. To do that this PR introduces two separate paths for loading block data: aligned and misaligned. The aligned path should be the most common one. In the misaligned path we have to rely on inline assembly since we have to load some bits outside of the block. Additionally, this PR makes inlining in the `riscv-zknh` backend less aggressive, which makes generated binary code 3-4 times smaller at the cost of one additional branch. Generated assembly for RV64: - SHA-256, unrolled: https://rust.godbolt.org/z/GxPM8PE3P (2278 bytes) - SHA-256, compact: https://rust.godbolt.org/z/4KWrcve9E (538 bytes) - SHA-512, unrolled: https://rust.godbolt.org/z/Th8ro8Tbo (2278 bytes) - SHA-512: compact: https://rust.godbolt.org/z/dqrv48ax3 (530 bytes) --- .github/workflows/sha2.yml | 12 +- sha2/src/lib.rs | 1 + sha2/src/sha256.rs | 2 + sha2/src/sha256/riscv_zknh.rs | 152 ++++++++++++----------- sha2/src/sha256/riscv_zknh_compact.rs | 26 ++-- sha2/src/sha256/riscv_zknh_utils.rs | 80 ++++++++++++ sha2/src/sha512.rs | 2 + sha2/src/sha512/riscv_zknh.rs | 172 ++++++++++++-------------- sha2/src/sha512/riscv_zknh_compact.rs | 26 ++-- sha2/src/sha512/riscv_zknh_utils.rs | 129 +++++++++++++++++++ 10 files changed, 405 insertions(+), 197 deletions(-) create mode 100644 sha2/src/sha256/riscv_zknh_utils.rs create mode 100644 sha2/src/sha512/riscv_zknh_utils.rs diff --git a/.github/workflows/sha2.yml b/.github/workflows/sha2.yml index 3efb039e..fdb9e9a0 100644 --- a/.github/workflows/sha2.yml +++ b/.github/workflows/sha2.yml @@ -169,13 +169,13 @@ jobs: - run: cargo install cross --git https://github.com/cross-rs/cross - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb riscv32-zknh: runs-on: ubuntu-latest @@ -188,13 +188,13 @@ jobs: components: rust-src - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C 
target-feature=+zknh,+zbkb
       - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std
         env:
-          RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh
+          RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb
       - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std
         env:
-          RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh
+          RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb
 
   minimal-versions:
     uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master
diff --git a/sha2/src/lib.rs b/sha2/src/lib.rs
index 1efc78b6..9d9c9f44 100644
--- a/sha2/src/lib.rs
+++ b/sha2/src/lib.rs
@@ -10,6 +10,7 @@
     any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"),
     feature(riscv_ext_intrinsics)
 )]
+#![allow(clippy::needless_range_loop)]
 
 #[cfg(all(
     any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"),
diff --git a/sha2/src/sha256.rs b/sha2/src/sha256.rs
index 4e084902..6d5896f1 100644
--- a/sha2/src/sha256.rs
+++ b/sha2/src/sha256.rs
@@ -11,12 +11,14 @@ cfg_if::cfg_if! {
         sha2_backend = "riscv-zknh"
     ))] {
         mod riscv_zknh;
+        mod riscv_zknh_utils;
         use riscv_zknh::compress;
     } else if #[cfg(all(
         any(target_arch = "riscv32", target_arch = "riscv64"),
         sha2_backend = "riscv-zknh-compact"
     ))] {
         mod riscv_zknh_compact;
+        mod riscv_zknh_utils;
         use riscv_zknh_compact::compress;
     } else if #[cfg(target_arch = "aarch64")] {
         mod soft;
diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs
index fe950bdc..7477c640 100644
--- a/sha2/src/sha256/riscv_zknh.rs
+++ b/sha2/src/sha256/riscv_zknh.rs
@@ -5,8 +5,11 @@ use core::arch::riscv32::*;
 #[cfg(target_arch = "riscv64")]
 use core::arch::riscv64::*;
 
-#[cfg(not(target_feature = "zknh"))]
-compile_error!("riscv-zknh backend requires enabled zknh target feature");
+#[cfg(not(all(
+    target_feature = "zknh",
+    any(target_feature = "zbb", target_feature = "zbkb")
+)))]
+compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features");
 
 #[inline(always)]
 fn ch(x: u32, y: u32, z: u32) -> u32 {
@@ -18,8 +21,34 @@ fn maj(x: u32, y: u32, z: u32) -> u32 {
     (x & y) ^ (x & z) ^ (y & z)
 }
 
-#[allow(clippy::identity_op)]
-fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16]) {
+/// This function returns `k[R]`, but prevents the compiler from inlining the indexed value
+pub(super) fn opaque_load<const R: usize>(k: &[u32]) -> u32 {
+    assert!(R < k.len());
+    let dst;
+    #[cfg(target_arch = "riscv64")]
+    unsafe {
+        core::arch::asm!(
+            "lwu {dst}, 4*{R}({k})",
+            R = const R,
+            k = in(reg) k.as_ptr(),
+            dst = out(reg) dst,
+            options(pure, readonly, nostack, preserves_flags),
+        );
+    }
+    #[cfg(target_arch = "riscv32")]
+    unsafe {
+        core::arch::asm!(
+            "lw {dst}, 4*{R}({k})",
+            R = const R,
+            k = in(reg) k.as_ptr(),
+            dst = out(reg) dst,
+            options(pure, readonly, nostack, preserves_flags),
+        );
+    }
+    dst
+}
+
+fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) {
     let n = K32.len() - R;
     #[allow(clippy::identity_op)]
     let a = (n + 0) % 8;
@@ -34,92 +63,65 @@ fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16]) {
     state[h] = state[h]
         .wrapping_add(unsafe { sha256sum1(state[e]) })
         .wrapping_add(ch(state[e], state[f], state[g]))
-        // Force reading of constants from the static to prevent bad codegen
-        .wrapping_add(unsafe { core::ptr::read_volatile(&K32[R]) })
-        .wrapping_add(block[R % 16]);
+        .wrapping_add(opaque_load::<R>(k))
+        .wrapping_add(block[R]);
     state[d] = state[d].wrapping_add(state[h]);
     state[h] = state[h]
         .wrapping_add(unsafe { sha256sum0(state[a]) })
         .wrapping_add(maj(state[a], state[b], state[c]))
 }
 
-fn round_schedule<const R: usize>(state: &mut [u32; 8], block: &mut [u32; 16]) {
-    round::<R>(state, block);
+fn round_schedule<const R: usize>(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) {
+    round::<R>(state, block, k);
 
-    block[R % 16] = block[R % 16]
+    block[R] = block[R]
         .wrapping_add(unsafe { sha256sig1(block[(R + 14) % 16]) })
         .wrapping_add(block[(R + 9) % 16])
         .wrapping_add(unsafe { sha256sig0(block[(R + 1) % 16]) });
 }
 
+#[inline(always)]
 fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
     let s = &mut state.clone();
     let b = &mut block;
 
-    round_schedule::<0>(s, b);
-    round_schedule::<1>(s, b);
-    round_schedule::<2>(s, b);
-    round_schedule::<3>(s, b);
-    round_schedule::<4>(s, b);
-    round_schedule::<5>(s, b);
-    round_schedule::<6>(s, b);
-    round_schedule::<7>(s, b);
-    round_schedule::<8>(s, b);
-    round_schedule::<9>(s, b);
-    round_schedule::<10>(s, b);
-    round_schedule::<11>(s, b);
-    round_schedule::<12>(s, b);
-    round_schedule::<13>(s, b);
-    round_schedule::<14>(s, b);
-    round_schedule::<15>(s, b);
-    round_schedule::<16>(s, b);
-    round_schedule::<17>(s, b);
-    round_schedule::<18>(s, b);
-    round_schedule::<19>(s, b);
-    round_schedule::<20>(s, b);
-    round_schedule::<21>(s, b);
-    round_schedule::<22>(s, b);
-    round_schedule::<23>(s, b);
-    round_schedule::<24>(s, b);
-    round_schedule::<25>(s, b);
-    round_schedule::<26>(s, b);
-    round_schedule::<27>(s, b);
-    round_schedule::<28>(s, b);
-    round_schedule::<29>(s, b);
-    round_schedule::<30>(s, b);
-    round_schedule::<31>(s, b);
-    round_schedule::<32>(s, b);
-    round_schedule::<33>(s, b);
-    round_schedule::<34>(s, b);
-    round_schedule::<35>(s, b);
-    round_schedule::<36>(s, b);
-    round_schedule::<37>(s, b);
-    round_schedule::<38>(s, b);
-    round_schedule::<39>(s, b);
-    round_schedule::<40>(s, b);
-    round_schedule::<41>(s, b);
-    round_schedule::<42>(s, b);
-    round_schedule::<43>(s, b);
-    round_schedule::<44>(s, b);
-    round_schedule::<45>(s, b);
-    round_schedule::<46>(s, b);
-    round_schedule::<47>(s, b);
-    round::<48>(s, b);
-    round::<49>(s, b);
-    round::<50>(s, b);
-    round::<51>(s, b);
-    round::<52>(s, b);
-    round::<53>(s, b);
-    round::<54>(s, b);
-    round::<55>(s, b);
-    round::<56>(s, b);
-    round::<57>(s, b);
-    round::<58>(s, b);
-    round::<59>(s, b);
-    round::<60>(s, b);
-    round::<61>(s, b);
-    round::<62>(s, b);
-    round::<63>(s, b);
+    for i in 0..3 {
+        let k = &K32[16 * i..];
+        round_schedule::<0>(s, b, k);
+        round_schedule::<1>(s, b, k);
+        round_schedule::<2>(s, b, k);
+        round_schedule::<3>(s, b, k);
+        round_schedule::<4>(s, b, k);
+        round_schedule::<5>(s, b, k);
+        round_schedule::<6>(s, b, k);
+        round_schedule::<7>(s, b, k);
+        round_schedule::<8>(s, b, k);
+        round_schedule::<9>(s, b, k);
+        round_schedule::<10>(s, b, k);
+        round_schedule::<11>(s, b, k);
+        round_schedule::<12>(s, b, k);
+        round_schedule::<13>(s, b, k);
+        round_schedule::<14>(s, b, k);
+        round_schedule::<15>(s, b, k);
+    }
+
+    let k = &K32[48..];
+    round::<0>(s, b, k);
+    round::<1>(s, b, k);
+    round::<2>(s, b, k);
+    round::<3>(s, b, k);
+    round::<4>(s, b, k);
+    round::<5>(s, b, k);
+    round::<6>(s, b, k);
+    round::<7>(s, b, k);
+    round::<8>(s, b, k);
+    round::<9>(s, b, k);
+    round::<10>(s, b, k);
+    round::<11>(s, b, k);
+    round::<12>(s, b, k);
+    round::<13>(s, b, k);
+    round::<14>(s, b, k);
+    round::<15>(s, b, k);
 
     for i in 0..8 {
         state[i] = state[i].wrapping_add(s[i]);
@@ -127,7 +129,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
 }
 
 pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) {
-    for block in blocks.iter().map(super::to_u32s) {
+    for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
         compress_block(state, block);
     }
 }
diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs
index 98375cce..bba510a3 100644
--- a/sha2/src/sha256/riscv_zknh_compact.rs
+++ b/sha2/src/sha256/riscv_zknh_compact.rs
@@ -5,8 +5,11 @@ use core::arch::riscv32::*;
 #[cfg(target_arch = "riscv64")]
 use core::arch::riscv64::*;
 
-#[cfg(not(target_feature = "zknh"))]
-compile_error!("riscv-zknh backend requires enabled zknh target feature");
+#[cfg(not(all(
+    target_feature = "zknh",
+    any(target_feature = "zbb", target_feature = "zbkb")
+)))]
+compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features");
 
 #[inline(always)]
 fn ch(x: u32, y: u32, z: u32) -> u32 {
@@ -43,9 +46,7 @@ fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) {
 }
 
 #[inline(always)]
-fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) {
-    round(state, block, r);
-
+fn schedule(block: &mut [u32; 16], r: usize) {
     block[r % 16] = block[r % 16]
         .wrapping_add(unsafe { sha256sig1(block[(r + 14) % 16]) })
         .wrapping_add(block[(r + 9) % 16])
@@ -54,14 +55,13 @@ fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) {
 
 #[inline(always)]
 fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
-    let s = &mut state.clone();
-    let b = &mut block;
+    let mut s = *state;
 
-    for i in 0..48 {
-        round_schedule(s, b, i);
-    }
-    for i in 48..64 {
-        round(s, b, i);
+    for r in 0..64 {
+        round(&mut s, &block, r);
+        if r < 48 {
+            schedule(&mut block, r)
+        }
     }
 
     for i in 0..8 {
@@ -70,7 +70,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
 }
 
 pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) {
-    for block in blocks.iter().map(super::to_u32s) {
+    for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
         compress_block(state, block);
     }
 }
diff --git a/sha2/src/sha256/riscv_zknh_utils.rs b/sha2/src/sha256/riscv_zknh_utils.rs
new file mode 100644
index 00000000..d75a0b1c
--- /dev/null
+++ b/sha2/src/sha256/riscv_zknh_utils.rs
@@ -0,0 +1,80 @@
+use core::{arch::asm, ptr};
+
+#[inline(always)]
+pub(super) fn load_block(block: &[u8; 64]) -> [u32; 16] {
+    if block.as_ptr().cast::<u32>().is_aligned() {
+        load_aligned_block(block)
+    } else {
+        load_unaligned_block(block)
+    }
+}
+
+#[inline(always)]
+fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] {
+    let p: *const u32 = block.as_ptr().cast();
+    debug_assert!(p.is_aligned());
+    let mut res = [0u32; 16];
+    for i in 0..16 {
+        let val = unsafe { ptr::read(p.add(i)) };
+        res[i] = val.to_be();
+    }
+    res
+}
+
+#[inline(always)]
+fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] {
+    let offset = (block.as_ptr() as usize) % align_of::<u32>();
+    debug_assert_ne!(offset, 0);
+    let off1 = (8 * offset) % 32;
+    let off2 = (32 - off1) % 32;
+    let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast();
+
+    let mut left: u32;
+    let mut res = [0u32; 16];
+
+    /// Use LW instruction on RV32 and LWU on RV64
+    #[cfg(target_arch = "riscv32")]
+    macro_rules! lw {
+        ($r:literal) => {
+            concat!("lw ", $r)
+        };
+    }
+    #[cfg(target_arch = "riscv64")]
+    macro_rules! lw {
+        ($r:literal) => {
+            concat!("lwu ", $r)
+        };
+    }
+
+    unsafe {
+        asm!(
+            lw!("{left}, 0({bp})"),       // left = unsafe { ptr::read(bp) };
+            "srl {left}, {left}, {off1}", // left >>= off1;
+            bp = in(reg) bp,
+            off1 = in(reg) off1,
+            left = out(reg) left,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+
+    for i in 0..15 {
+        let right = unsafe { ptr::read(bp.add(1 + i)) };
+        res[i] = (left | (right << off2)).to_be();
+        left = right >> off1;
+    }
+
+    let right: u32;
+    unsafe {
+        asm!(
+            lw!("{right}, 16 * 4({bp})"),   // right = ptr::read(bp.add(16));
+            "sll {right}, {right}, {off2}", // right <<= off2;
+            bp = in(reg) bp,
+            off2 = in(reg) off2,
+            right = out(reg) right,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+    res[15] = (left | right).to_be();
+
+    res
+}
diff --git a/sha2/src/sha512.rs b/sha2/src/sha512.rs
index 08e802fc..20679266 100644
--- a/sha2/src/sha512.rs
+++ b/sha2/src/sha512.rs
@@ -11,12 +11,14 @@ cfg_if::cfg_if! {
         sha2_backend = "riscv-zknh"
     ))] {
         mod riscv_zknh;
+        mod riscv_zknh_utils;
         use riscv_zknh::compress;
     } else if #[cfg(all(
         any(target_arch = "riscv32", target_arch = "riscv64"),
         sha2_backend = "riscv-zknh-compact"
     ))] {
         mod riscv_zknh_compact;
+        mod riscv_zknh_utils;
         use riscv_zknh_compact::compress;
     } else if #[cfg(target_arch = "aarch64")] {
         mod soft;
diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs
index 31a327eb..7be35ee8 100644
--- a/sha2/src/sha512/riscv_zknh.rs
+++ b/sha2/src/sha512/riscv_zknh.rs
@@ -5,8 +5,11 @@ use core::arch::riscv32::*;
 #[cfg(target_arch = "riscv64")]
 use core::arch::riscv64::*;
 
-#[cfg(not(target_feature = "zknh"))]
-compile_error!("riscv-zknh backend requires enabled zknh target feature");
+#[cfg(not(all(
+    target_feature = "zknh",
+    any(target_feature = "zbb", target_feature = "zbkb")
+)))]
+compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features");
 
 #[cfg(target_arch = "riscv32")]
 unsafe fn sha512sum0(x: u64) -> u64 {
@@ -46,7 +49,40 @@ fn maj(x: u64, y: u64, z: u64) -> u64 {
     (x & y) ^ (x & z) ^ (y & z)
 }
 
-fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16]) {
+/// This function returns `k[R]`, but prevents the compiler from inlining the indexed value
+pub(super) fn opaque_load<const R: usize>(k: &[u64]) -> u64 {
+    use core::arch::asm;
+    assert!(R < k.len());
+    #[cfg(target_arch = "riscv64")]
+    unsafe {
+        let dst;
+        asm!(
+            "ld {dst}, {N}({k})",
+            N = const 8 * R,
+            k = in(reg) k.as_ptr(),
+            dst = out(reg) dst,
+            options(pure, readonly, nostack, preserves_flags),
+        );
+        dst
+    }
+    #[cfg(target_arch = "riscv32")]
+    unsafe {
+        let [hi, lo]: [u32; 2];
+        asm!(
+            "lw {lo}, {N1}({k})",
+            "lw {hi}, {N2}({k})",
+            N1 = const 8 * R,
+            N2 = const 8 * R + 4,
+            k = in(reg) k.as_ptr(),
+            lo = out(reg) lo,
+            hi = out(reg) hi,
+            options(pure, readonly, nostack, preserves_flags),
+        );
+        ((hi as u64) << 32) | (lo as u64)
+    }
+}
+
+fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) {
     let n = K64.len() - R;
     #[allow(clippy::identity_op)]
     let a = (n + 0) % 8;
@@ -61,19 +97,18 @@ fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16]) {
     state[h] = state[h]
         .wrapping_add(unsafe { sha512sum1(state[e]) })
         .wrapping_add(ch(state[e], state[f], state[g]))
-        // Force reading of constants from the static to prevent bad codegen
-        .wrapping_add(unsafe { core::ptr::read_volatile(&K64[R]) })
-        .wrapping_add(block[R % 16]);
+        .wrapping_add(opaque_load::<R>(k))
+        .wrapping_add(block[R]);
     state[d] = state[d].wrapping_add(state[h]);
     state[h] = state[h]
         .wrapping_add(unsafe { sha512sum0(state[a]) })
         .wrapping_add(maj(state[a], state[b], state[c]))
 }
 
-fn round_schedule<const R: usize>(state: &mut [u64; 8], block: &mut [u64; 16]) {
-    round::<R>(state, block);
+fn round_schedule<const R: usize>(state: &mut [u64; 8], block: &mut [u64; 16], k: &[u64]) {
+    round::<R>(state, block, k);
 
-    block[R % 16] = block[R % 16]
+    block[R] = block[R]
         .wrapping_add(unsafe { sha512sig1(block[(R + 14) % 16]) })
         .wrapping_add(block[(R + 9) % 16])
         .wrapping_add(unsafe { sha512sig0(block[(R + 1) % 16]) });
@@ -83,86 +118,43 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) {
     let s = &mut state.clone();
     let b = &mut block;
 
-    round_schedule::<0>(s, b);
-    round_schedule::<1>(s, b);
-    round_schedule::<2>(s, b);
-    round_schedule::<3>(s, b);
-    round_schedule::<4>(s, b);
-    round_schedule::<5>(s, b);
-    round_schedule::<6>(s, b);
-    round_schedule::<7>(s, b);
-    round_schedule::<8>(s, b);
-    round_schedule::<9>(s, b);
-    round_schedule::<10>(s, b);
-    round_schedule::<11>(s, b);
-    round_schedule::<12>(s, b);
-    round_schedule::<13>(s, b);
-    round_schedule::<14>(s, b);
-    round_schedule::<15>(s, b);
-    round_schedule::<16>(s, b);
-    round_schedule::<17>(s, b);
-    round_schedule::<18>(s, b);
-    round_schedule::<19>(s, b);
-    round_schedule::<20>(s, b);
-    round_schedule::<21>(s, b);
-    round_schedule::<22>(s, b);
-    round_schedule::<23>(s, b);
-    round_schedule::<24>(s, b);
-    round_schedule::<25>(s, b);
-    round_schedule::<26>(s, b);
-    round_schedule::<27>(s, b);
-    round_schedule::<28>(s, b);
-    round_schedule::<29>(s, b);
-    round_schedule::<30>(s, b);
-    round_schedule::<31>(s, b);
-    round_schedule::<32>(s, b);
-    round_schedule::<33>(s, b);
-    round_schedule::<34>(s, b);
-    round_schedule::<35>(s, b);
-    round_schedule::<36>(s, b);
-    round_schedule::<37>(s, b);
-    round_schedule::<38>(s, b);
-    round_schedule::<39>(s, b);
-    round_schedule::<40>(s, b);
-    round_schedule::<41>(s, b);
-    round_schedule::<42>(s, b);
-    round_schedule::<43>(s, b);
-    round_schedule::<44>(s, b);
-    round_schedule::<45>(s, b);
-    round_schedule::<46>(s, b);
-    round_schedule::<47>(s, b);
-    round_schedule::<48>(s, b);
-    round_schedule::<49>(s, b);
-    round_schedule::<50>(s, b);
-    round_schedule::<51>(s, b);
-    round_schedule::<52>(s, b);
-    round_schedule::<53>(s, b);
-    round_schedule::<54>(s, b);
-    round_schedule::<55>(s, b);
-    round_schedule::<56>(s, b);
-    round_schedule::<57>(s, b);
-    round_schedule::<58>(s, b);
-    round_schedule::<59>(s, b);
-    round_schedule::<60>(s, b);
-    round_schedule::<61>(s, b);
-    round_schedule::<62>(s, b);
-    round_schedule::<63>(s, b);
-    round::<64>(s, b);
-    round::<65>(s, b);
-    round::<66>(s, b);
-    round::<67>(s, b);
-    round::<68>(s, b);
-    round::<69>(s, b);
-    round::<70>(s, b);
-    round::<71>(s, b);
-    round::<72>(s, b);
-    round::<73>(s, b);
-    round::<74>(s, b);
-    round::<75>(s, b);
-    round::<76>(s, b);
-    round::<77>(s, b);
-    round::<78>(s, b);
-    round::<79>(s, b);
+    for i in 0..4 {
+        let k = &K64[16 * i..];
+        round_schedule::<0>(s, b, k);
+        round_schedule::<1>(s, b, k);
+        round_schedule::<2>(s, b, k);
+        round_schedule::<3>(s, b, k);
+        round_schedule::<4>(s, b, k);
+        round_schedule::<5>(s, b, k);
+        round_schedule::<6>(s, b, k);
+        round_schedule::<7>(s, b, k);
+        round_schedule::<8>(s, b, k);
+        round_schedule::<9>(s, b, k);
+        round_schedule::<10>(s, b, k);
+        round_schedule::<11>(s, b, k);
+        round_schedule::<12>(s, b, k);
+        round_schedule::<13>(s, b, k);
+        round_schedule::<14>(s, b, k);
+        round_schedule::<15>(s, b, k);
+    }
+
+    let k = &K64[64..];
+    round::<0>(s, b, k);
+    round::<1>(s, b, k);
+    round::<2>(s, b, k);
+    round::<3>(s, b, k);
+    round::<4>(s, b, k);
+    round::<5>(s, b, k);
+    round::<6>(s, b, k);
+    round::<7>(s, b, k);
+    round::<8>(s, b, k);
+    round::<9>(s, b, k);
+    round::<10>(s, b, k);
+    round::<11>(s, b, k);
+    round::<12>(s, b, k);
+    round::<13>(s, b, k);
+    round::<14>(s, b, k);
+    round::<15>(s, b, k);
 
     for i in 0..8 {
         state[i] = state[i].wrapping_add(s[i]);
@@ -170,7 +162,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) {
 }
 
 pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) {
-    for block in blocks.iter().map(super::to_u64s) {
+    for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
         compress_block(state, block);
     }
 }
diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs
index 92e984c5..729157c4 100644
--- a/sha2/src/sha512/riscv_zknh_compact.rs
+++ b/sha2/src/sha512/riscv_zknh_compact.rs
@@ -5,8 +5,11 @@ use core::arch::riscv32::*;
 #[cfg(target_arch = "riscv64")]
 use core::arch::riscv64::*;
 
-#[cfg(not(target_feature = "zknh"))]
-compile_error!("riscv-zknh backend requires enabled zknh target feature");
+#[cfg(not(all(
+    target_feature = "zknh",
+    any(target_feature = "zbb", target_feature = "zbkb")
+)))]
+compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features");
 
 #[cfg(target_arch = "riscv32")]
 unsafe fn sha512sum0(x: u64) -> u64 {
@@ -71,9 +74,7 @@ fn round(state: &mut [u64; 8], block: &[u64; 16], r: usize) {
 }
 
 #[inline(always)]
-fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) {
-    round(state, block, r);
-
+fn schedule(block: &mut [u64; 16], r: usize) {
     block[r % 16] = block[r % 16]
         .wrapping_add(unsafe { sha512sig1(block[(r + 14) % 16]) })
         .wrapping_add(block[(r + 9) % 16])
@@ -82,14 +83,13 @@ fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) {
 
 #[inline(always)]
 fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) {
-    let s = &mut state.clone();
-    let b = &mut block;
+    let mut s = *state;
 
-    for i in 0..64 {
-        round_schedule(s, b, i);
-    }
-    for i in 64..80 {
-        round(s, b, i);
+    for r in 0..80 {
+        round(&mut s, &block, r);
+        if r < 64 {
+            schedule(&mut block, r)
+        }
     }
 
     for i in 0..8 {
@@ -98,7 +98,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) {
 }
 
 pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) {
-    for block in blocks.iter().map(super::to_u64s) {
+    for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
         compress_block(state, block);
     }
 }
diff --git a/sha2/src/sha512/riscv_zknh_utils.rs b/sha2/src/sha512/riscv_zknh_utils.rs
new file mode 100644
index 00000000..0b474606
--- /dev/null
+++ b/sha2/src/sha512/riscv_zknh_utils.rs
@@ -0,0 +1,129 @@
+use core::{arch::asm, ptr};
+
+#[inline(always)]
+pub(super) fn load_block(block: &[u8; 128]) -> [u64; 16] {
+    if block.as_ptr().cast::<usize>().is_aligned() {
+        load_aligned_block(block)
+    } else {
+        load_unaligned_block(block)
+    }
+}
+
+#[cfg(target_arch = "riscv32")]
+fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] {
+    let p: *const [u32; 32] = block.as_ptr().cast();
+    debug_assert!(p.is_aligned());
+    let block = unsafe { &*p };
+    let mut res = [0u64; 16];
+    for i in 0..16 {
+        let a = block[2 * i].to_be() as u64;
+        let b = block[2 * i + 1].to_be() as u64;
+        res[i] = (a << 32) | b;
+    }
+    res
+}
+
+#[cfg(target_arch = "riscv64")]
+fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] {
+    let block_ptr: *const u64 = block.as_ptr().cast();
+    debug_assert!(block_ptr.is_aligned());
+    let mut res = [0u64; 16];
+    for i in 0..16 {
+        let val = unsafe { ptr::read(block_ptr.add(i)) };
+        res[i] = val.to_be();
+    }
+    res
+}
+
+#[cfg(target_arch = "riscv32")]
+fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] {
+    let offset = (block.as_ptr() as usize) % align_of::<u32>();
+    debug_assert_ne!(offset, 0);
+    let off1 = (8 * offset) % 32;
+    let off2 = (32 - off1) % 32;
+    let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast();
+
+    let mut left: u32;
+    let mut block32 = [0u32; 32];
+
+    unsafe {
+        asm!(
+            "lw {left}, 0({bp})",         // left = unsafe { ptr::read(bp) };
+            "srl {left}, {left}, {off1}", // left >>= off1;
+            bp = in(reg) bp,
+            off1 = in(reg) off1,
+            left = out(reg) left,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+
+    for i in 0..31 {
+        let right = unsafe { ptr::read(bp.add(1 + i)) };
+        block32[i] = left | (right << off2);
+        left = right >> off1;
+    }
+
+    let right: u32;
+    unsafe {
+        asm!(
+            "lw {right}, 32 * 4({bp})",     // right = ptr::read(bp.add(32));
+            "sll {right}, {right}, {off2}", // right <<= off2;
+            bp = in(reg) bp,
+            off2 = in(reg) off2,
+            right = out(reg) right,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+    block32[31] = left | right;
+
+    let mut block64 = [0u64; 16];
+    for i in 0..16 {
+        let a = block32[2 * i].to_be() as u64;
+        let b = block32[2 * i + 1].to_be() as u64;
+        block64[i] = (a << 32) | b;
+    }
+    block64
+}
+
+#[cfg(target_arch = "riscv64")]
+fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] {
+    let offset = (block.as_ptr() as usize) % align_of::<u64>();
+    debug_assert_ne!(offset, 0);
+    let off1 = (8 * offset) % 64;
+    let off2 = (64 - off1) % 64;
+    let bp: *const u64 = block.as_ptr().wrapping_sub(offset).cast();
+
+    let mut left: u64;
+    let mut res = [0u64; 16];
+
+    unsafe {
+        asm!(
+            "ld {left}, 0({bp})",         // left = unsafe { ptr::read(bp) };
+            "srl {left}, {left}, {off1}", // left >>= off1;
+            bp = in(reg) bp,
+            off1 = in(reg) off1,
+            left = out(reg) left,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+    for i in 0..15 {
+        let right = unsafe { ptr::read(bp.add(1 + i)) };
+        res[i] = (left | (right << off2)).to_be();
+        left = right >> off1;
+    }
+
+    let right: u64;
+    unsafe {
+        asm!(
+            "ld {right}, 16 * 8({bp})",     // right = ptr::read(bp.add(16));
+            "sll {right}, {right}, {off2}", // right <<= off2;
+            bp = in(reg) bp,
+            off2 = in(reg) off2,
+            right = out(reg) right,
+            options(pure, nostack, readonly, preserves_flags),
+        );
+    }
+    res[15] = (left | right).to_be();
+
+    res
+}
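Note on the new `load_block` helpers: both the SHA-256 and SHA-512 backends only need the message block as native words in big-endian order, and the aligned and unaligned branches differ solely in how those words are fetched. The sketch below is a portable reference for the SHA-256 case (plain Rust, no `asm!`; the name `load_block_reference` is illustrative and not part of this patch). Its aligned branch matches the helper above, while its unaligned branch falls back to assembling each word from bytes, i.e. roughly the codegen the commit message calls inefficient and which the asm-based `load_unaligned_block` avoids:

// Reference semantics for the SHA-256 `load_block`: the 16 big-endian u32
// words of a 64-byte block. The patch replaces the unaligned branch with
// word loads plus shifts in inline assembly, because those loads touch bytes
// of the surrounding aligned words outside the block and cannot be expressed
// in plain Rust.
fn load_block_reference(block: &[u8; 64]) -> [u32; 16] {
    let mut res = [0u32; 16];
    if block.as_ptr().cast::<u32>().is_aligned() {
        // Aligned path: direct word loads plus a byte swap (RISC-V is little-endian).
        let p: *const u32 = block.as_ptr().cast();
        for i in 0..16 {
            res[i] = unsafe { core::ptr::read(p.add(i)) }.to_be();
        }
    } else {
        // Unaligned fallback: assemble each word from individual bytes.
        for (i, chunk) in block.chunks_exact(4).enumerate() {
            res[i] = u32::from_be_bytes(chunk.try_into().unwrap());
        }
    }
    res
}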
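The `opaque_load` helpers replace the earlier `read_volatile` trick: each round constant is fetched with a plain `lw`/`lwu`/`ld` issued from inline assembly, so the compiler can no longer expand `K32[R]`/`K64[R]` inline (the "bad codegen" the old comment worked around), while the `pure, readonly` options still let it schedule the loads freely. A rough portable analogue is sketched below; it is illustrative only, not part of this patch, and carries weaker guarantees than the asm version:

use core::hint::black_box;

// Hypothetical stand-in for `opaque_load`: laundering the table reference
// through `black_box` also pushes the constant into a runtime load from the
// table instead of per-round immediates. `black_box` is only a best-effort
// hint, which is why the backends pin the exact load with inline assembly.
#[inline(always)]
fn opaque_load_hint<const R: usize>(k: &[u32; 64]) -> u32 {
    black_box(k)[R]
}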