Skip to content

Commit

Permalink
sha2: improve RISC-V Zknh backends (#617)
Browse files Browse the repository at this point in the history
Annoyingly, RISC-V is really inconvenient when we have to deal with
misaligned loads/stores. LLVM by default generates very [inefficient
code](https://rust.godbolt.org/z/3Yaj4fq5o) which loads every byte
separately and combines them into a 32/64 bit integer. The `ld`
instruction "may" support misaligned loads and for Linux user-space it's
even
[guaranteed](https://www.kernel.org/doc/html/v6.10/arch/riscv/uabi.html#misaligned-accesses),
but it can be (and IIUC often in practice is) "extremely slow", so we
should not rely on it while writing performant code.

After asking around, it looks like this mess is here to stay, so we have
no choice but to work around it. To do that this PR introduces two
separate paths for loading block data: aligned and misaligned. The
aligned path should be the most common one. In the misaligned path we
have to rely on inline assembly since we have to load some bits outside
of the block.

Additionally, this PR makes inlining in the `riscv-zknh` backend less
aggressive, which makes the generated binary code 3-4 times smaller at the
cost of one additional branch.

Generated assembly for RV64:
- SHA-256, unrolled: https://rust.godbolt.org/z/GxPM8PE3P (2278 bytes)
- SHA-256, compact: https://rust.godbolt.org/z/4KWrcve9E (538 bytes)
- SHA-512, unrolled: https://rust.godbolt.org/z/Th8ro8Tbo (2278 bytes)
- SHA-512, compact: https://rust.godbolt.org/z/dqrv48ax3 (530 bytes)
  • Loading branch information
newpavlov authored Aug 27, 2024
1 parent b5c99d3 commit b2312fa
Show file tree
Hide file tree
Showing 10 changed files with 405 additions and 197 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/sha2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ jobs:
- run: cargo install cross --git https://github.com/cross-rs/cross
- run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb
- run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb
- run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb

riscv32-zknh:
runs-on: ubuntu-latest
Expand All @@ -188,13 +188,13 @@ jobs:
components: rust-src
- run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb
- run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb
- run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std
env:
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh
RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb

minimal-versions:
uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master
Expand Down
1 change: 1 addition & 0 deletions sha2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"),
feature(riscv_ext_intrinsics)
)]
#![allow(clippy::needless_range_loop)]

#[cfg(all(
any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"),
Expand Down
2 changes: 2 additions & 0 deletions sha2/src/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ cfg_if::cfg_if! {
sha2_backend = "riscv-zknh"
))] {
mod riscv_zknh;
mod riscv_zknh_utils;
use riscv_zknh::compress;
} else if #[cfg(all(
any(target_arch = "riscv32", target_arch = "riscv64"),
sha2_backend = "riscv-zknh-compact"
))] {
mod riscv_zknh_compact;
mod riscv_zknh_utils;
use riscv_zknh_compact::compress;
} else if #[cfg(target_arch = "aarch64")] {
mod soft;
Expand Down
152 changes: 77 additions & 75 deletions sha2/src/sha256/riscv_zknh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ use core::arch::riscv32::*;
#[cfg(target_arch = "riscv64")]
use core::arch::riscv64::*;

#[cfg(not(target_feature = "zknh"))]
compile_error!("riscv-zknh backend requires enabled zknh target feature");
#[cfg(not(all(
target_feature = "zknh",
any(target_feature = "zbb", target_feature = "zbkb")
)))]
compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features");

#[inline(always)]
fn ch(x: u32, y: u32, z: u32) -> u32 {
Expand All @@ -18,8 +21,34 @@ fn maj(x: u32, y: u32, z: u32) -> u32 {
(x & y) ^ (x & z) ^ (y & z)
}

#[allow(clippy::identity_op)]
fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16]) {
/// Returns `k[R]`, but prevents the compiler from inlining the indexed value.
///
/// The word is fetched with an inline-assembly load (`lwu` on RV64, `lw` on
/// RV32) at byte offset `4 * R` from `k.as_ptr()`, so the optimizer only sees
/// an opaque memory read instead of a known constant.
///
/// # Panics
///
/// Panics if `R >= k.len()`.
pub(super) fn opaque_load<const R: usize>(k: &[u32]) -> u32 {
// Bounds check: guarantees the raw load below stays inside `k`.
assert!(R < k.len());
let dst;
// SAFETY: `R < k.len()` was asserted above, so the 4-byte load at
// `k.as_ptr() + 4 * R` reads an element of `k`. The asm only reads memory
// (`pure`, `readonly`) and does not touch the stack (`nostack`).
#[cfg(target_arch = "riscv64")]
unsafe {
core::arch::asm!(
"lwu {dst}, 4*{R}({k})",
R = const R,
k = in(reg) k.as_ptr(),
dst = out(reg) dst,
options(pure, readonly, nostack, preserves_flags),
);
}
// SAFETY: same argument as for the RV64 variant above; only the load
// mnemonic differs.
#[cfg(target_arch = "riscv32")]
unsafe {
core::arch::asm!(
"lw {dst}, 4*{R}({k})",
R = const R,
k = in(reg) k.as_ptr(),
dst = out(reg) dst,
options(pure, readonly, nostack, preserves_flags),
);
}
dst
}

fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) {
let n = K32.len() - R;
#[allow(clippy::identity_op)]
let a = (n + 0) % 8;
Expand All @@ -34,100 +63,73 @@ fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16]) {
state[h] = state[h]
.wrapping_add(unsafe { sha256sum1(state[e]) })
.wrapping_add(ch(state[e], state[f], state[g]))
// Force reading of constants from the static to prevent bad codegen
.wrapping_add(unsafe { core::ptr::read_volatile(&K32[R]) })
.wrapping_add(block[R % 16]);
.wrapping_add(opaque_load::<R>(k))
.wrapping_add(block[R]);
state[d] = state[d].wrapping_add(state[h]);
state[h] = state[h]
.wrapping_add(unsafe { sha256sum0(state[a]) })
.wrapping_add(maj(state[a], state[b], state[c]))
}

fn round_schedule<const R: usize>(state: &mut [u32; 8], block: &mut [u32; 16]) {
round::<R>(state, block);
fn round_schedule<const R: usize>(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) {
round::<R>(state, block, k);

block[R % 16] = block[R % 16]
block[R] = block[R]
.wrapping_add(unsafe { sha256sig1(block[(R + 14) % 16]) })
.wrapping_add(block[(R + 9) % 16])
.wrapping_add(unsafe { sha256sig0(block[(R + 1) % 16]) });
}

#[inline(always)]
fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
let s = &mut state.clone();
let b = &mut block;

round_schedule::<0>(s, b);
round_schedule::<1>(s, b);
round_schedule::<2>(s, b);
round_schedule::<3>(s, b);
round_schedule::<4>(s, b);
round_schedule::<5>(s, b);
round_schedule::<6>(s, b);
round_schedule::<7>(s, b);
round_schedule::<8>(s, b);
round_schedule::<9>(s, b);
round_schedule::<10>(s, b);
round_schedule::<11>(s, b);
round_schedule::<12>(s, b);
round_schedule::<13>(s, b);
round_schedule::<14>(s, b);
round_schedule::<15>(s, b);
round_schedule::<16>(s, b);
round_schedule::<17>(s, b);
round_schedule::<18>(s, b);
round_schedule::<19>(s, b);
round_schedule::<20>(s, b);
round_schedule::<21>(s, b);
round_schedule::<22>(s, b);
round_schedule::<23>(s, b);
round_schedule::<24>(s, b);
round_schedule::<25>(s, b);
round_schedule::<26>(s, b);
round_schedule::<27>(s, b);
round_schedule::<28>(s, b);
round_schedule::<29>(s, b);
round_schedule::<30>(s, b);
round_schedule::<31>(s, b);
round_schedule::<32>(s, b);
round_schedule::<33>(s, b);
round_schedule::<34>(s, b);
round_schedule::<35>(s, b);
round_schedule::<36>(s, b);
round_schedule::<37>(s, b);
round_schedule::<38>(s, b);
round_schedule::<39>(s, b);
round_schedule::<40>(s, b);
round_schedule::<41>(s, b);
round_schedule::<42>(s, b);
round_schedule::<43>(s, b);
round_schedule::<44>(s, b);
round_schedule::<45>(s, b);
round_schedule::<46>(s, b);
round_schedule::<47>(s, b);
round::<48>(s, b);
round::<49>(s, b);
round::<50>(s, b);
round::<51>(s, b);
round::<52>(s, b);
round::<53>(s, b);
round::<54>(s, b);
round::<55>(s, b);
round::<56>(s, b);
round::<57>(s, b);
round::<58>(s, b);
round::<59>(s, b);
round::<60>(s, b);
round::<61>(s, b);
round::<62>(s, b);
round::<63>(s, b);
for i in 0..3 {
let k = &K32[16 * i..];
round_schedule::<0>(s, b, k);
round_schedule::<1>(s, b, k);
round_schedule::<2>(s, b, k);
round_schedule::<3>(s, b, k);
round_schedule::<4>(s, b, k);
round_schedule::<5>(s, b, k);
round_schedule::<6>(s, b, k);
round_schedule::<7>(s, b, k);
round_schedule::<8>(s, b, k);
round_schedule::<9>(s, b, k);
round_schedule::<10>(s, b, k);
round_schedule::<11>(s, b, k);
round_schedule::<12>(s, b, k);
round_schedule::<13>(s, b, k);
round_schedule::<14>(s, b, k);
round_schedule::<15>(s, b, k);
}

let k = &K32[48..];
round::<0>(s, b, k);
round::<1>(s, b, k);
round::<2>(s, b, k);
round::<3>(s, b, k);
round::<4>(s, b, k);
round::<5>(s, b, k);
round::<6>(s, b, k);
round::<7>(s, b, k);
round::<8>(s, b, k);
round::<9>(s, b, k);
round::<10>(s, b, k);
round::<11>(s, b, k);
round::<12>(s, b, k);
round::<13>(s, b, k);
round::<14>(s, b, k);
round::<15>(s, b, k);

for i in 0..8 {
state[i] = state[i].wrapping_add(s[i]);
}
}

pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) {
for block in blocks.iter().map(super::to_u32s) {
for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
compress_block(state, block);
}
}
26 changes: 13 additions & 13 deletions sha2/src/sha256/riscv_zknh_compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ use core::arch::riscv32::*;
#[cfg(target_arch = "riscv64")]
use core::arch::riscv64::*;

#[cfg(not(target_feature = "zknh"))]
compile_error!("riscv-zknh backend requires enabled zknh target feature");
#[cfg(not(all(
target_feature = "zknh",
any(target_feature = "zbb", target_feature = "zbkb")
)))]
compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features");

#[inline(always)]
fn ch(x: u32, y: u32, z: u32) -> u32 {
Expand Down Expand Up @@ -43,9 +46,7 @@ fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) {
}

#[inline(always)]
fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) {
round(state, block, r);

fn schedule(block: &mut [u32; 16], r: usize) {
block[r % 16] = block[r % 16]
.wrapping_add(unsafe { sha256sig1(block[(r + 14) % 16]) })
.wrapping_add(block[(r + 9) % 16])
Expand All @@ -54,14 +55,13 @@ fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) {

#[inline(always)]
fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
let s = &mut state.clone();
let b = &mut block;
let mut s = *state;

for i in 0..48 {
round_schedule(s, b, i);
}
for i in 48..64 {
round(s, b, i);
for r in 0..64 {
round(&mut s, &block, r);
if r < 48 {
schedule(&mut block, r)
}
}

for i in 0..8 {
Expand All @@ -70,7 +70,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) {
}

pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) {
for block in blocks.iter().map(super::to_u32s) {
for block in blocks.iter().map(super::riscv_zknh_utils::load_block) {
compress_block(state, block);
}
}
80 changes: 80 additions & 0 deletions sha2/src/sha256/riscv_zknh_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use core::{arch::asm, ptr};

/// Converts a 64-byte block into sixteen `u32` words.
///
/// Picks the loading strategy from the address of `block`: when the pointer
/// is aligned for `u32` the words are read directly via `load_aligned_block`,
/// otherwise the work is delegated to `load_unaligned_block`.
#[inline(always)]
pub(super) fn load_block(block: &[u8; 64]) -> [u32; 16] {
    let word_aligned = block.as_ptr().cast::<u32>().is_aligned();
    if !word_aligned {
        return load_unaligned_block(block);
    }
    load_aligned_block(block)
}

/// Reads a `u32`-aligned 64-byte block as sixteen 32-bit words.
///
/// Every word is read in native byte order and then passed through
/// `u32::to_be`, so the returned values equal the big-endian interpretation
/// of each 4-byte group of `block`.
///
/// The caller must pass a block whose address is aligned for `u32`; this is
/// only verified in debug builds.
#[inline(always)]
fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] {
    let words = block.as_ptr().cast::<u32>();
    debug_assert!(words.is_aligned());
    core::array::from_fn(|idx| {
        // SAFETY: `block` spans 64 readable bytes, i.e. exactly 16 `u32`
        // words, and `idx < 16`. Alignment is the caller's obligation
        // (checked above in debug builds).
        let word = unsafe { ptr::read(words.add(idx)) };
        word.to_be()
    })
}

/// Loads a 64-byte block whose address is NOT aligned for `u32`.
///
/// Instead of issuing misaligned loads, the block is covered by 17 aligned
/// words starting at `bp` (the block address rounded down to a `u32`
/// boundary). Each output word is stitched together from two neighbouring
/// aligned words with shifts and then passed through `u32::to_be`.
///
/// The first and the last aligned word straddle the block boundary, i.e. they
/// contain bytes outside of `block`. Those two loads are performed with
/// inline assembly; the 15 words in between lie fully inside `block` and are
/// read with `ptr::read`.
///
/// NOTE(review): the shift directions presumably rely on a little-endian
/// target — confirm if this is ever reused for a big-endian one.
///
/// The caller must pass a block whose address is not `u32`-aligned; this is
/// only verified in debug builds.
#[inline(always)]
fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] {
    let offset = (block.as_ptr() as usize) % align_of::<u32>();
    debug_assert_ne!(offset, 0);
    // Shift amounts in bits. With `offset` in 1..=3 both are already below
    // 32; the `% 32` presumably just makes that visible to the compiler.
    let off1 = (8 * offset) % 32;
    let off2 = (32 - off1) % 32;
    // Aligned base pointer. It may point before the start of `block`, hence
    // wrapping arithmetic is used for every offset computed from it.
    let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast();

    let mut left: u32;
    let mut res = [0u32; 16];

    /// Use LW instruction on RV32 and LWU on RV64
    #[cfg(target_arch = "riscv32")]
    macro_rules! lw {
        ($r:literal) => {
            concat!("lw ", $r)
        };
    }
    #[cfg(target_arch = "riscv64")]
    macro_rules! lw {
        ($r:literal) => {
            concat!("lwu ", $r)
        };
    }

    // SAFETY: `bp` is `u32`-aligned and the loaded word overlaps the first
    // bytes of `block`. The bytes that lie before `block` are shifted out
    // right away, so the result depends only on data inside `block`.
    unsafe {
        asm!(
            lw!("{left}, 0({bp})"),         // left = unsafe { ptr::read(bp) };
            "srl {left}, {left}, {off1}",   // left >>= off1;
            bp = in(reg) bp,
            off1 = in(reg) off1,
            left = out(reg) left,
            options(pure, nostack, readonly, preserves_flags),
        );
    }

    for i in 0..15 {
        // `wrapping_add` instead of `add`: `bp` itself may be outside of the
        // allocation that contains `block`, and `add` requires the whole
        // range between `bp` and the result to be in bounds.
        //
        // SAFETY: the aligned words 1..=15 cover bytes 4..64 counted from
        // `bp`, which lie inside `block` because `0 < offset < 4`.
        let right = unsafe { ptr::read(bp.wrapping_add(1 + i)) };
        res[i] = (left | (right << off2)).to_be();
        left = right >> off1;
    }

    let right: u32;
    // SAFETY: the word at `bp + 64` is `u32`-aligned and overlaps the last
    // bytes of `block`. The bytes that lie past `block` are shifted out right
    // away, so the result depends only on data inside `block`.
    unsafe {
        asm!(
            lw!("{right}, 16 * 4({bp})"),     // right = ptr::read(bp.add(16));
            "sll {right}, {right}, {off2}",   // right <<= off2;
            bp = in(reg) bp,
            off2 = in(reg) off2,
            right = out(reg) right,
            options(pure, nostack, readonly, preserves_flags),
        );
    }
    res[15] = (left | right).to_be();

    res
}
Loading

0 comments on commit b2312fa

Please sign in to comment.