Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/simd_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,74 @@ pub fn pack_column(col: [F; PackedF::WIDTH]) -> PackedF {
PackedF::from_fn(|i| col[i])
}

/// Pack contiguous FieldArrays directly into a destination slice at the given offset.
///
/// For every element index `i` in `0..N`, builds one `PackedF` whose lane `lane` holds
/// `data[lane][i]` and stores it in `dest[offset + i]`.
/// Writing in place avoids creating an intermediate `[PackedF; N]` array.
///
/// # Arguments
/// * `dest` - Destination slice; `dest[offset..offset + N]` is overwritten
/// * `offset` - Starting index in `dest`
/// * `data` - Source slice of FieldArrays (must have length >= WIDTH)
///
/// # Panics
/// Panics if `dest.len() < offset + N` or `data.len() < PackedF::WIDTH`.
/// Both lengths are validated before the first write, so `dest` is never left
/// partially packed.
#[inline(always)]
pub fn pack_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
    // Slice once up front: a single range check per slice replaces one check per
    // access inside the loop and makes an undersized argument fail immediately.
    let dest = &mut dest[offset..offset + N];
    let data = &data[..PackedF::WIDTH];

    for (i, slot) in dest.iter_mut().enumerate() {
        *slot = PackedF::from_fn(|lane| data[lane][i]);
    }
}

/// Pack even-indexed FieldArrays (stride 2) directly into destination.
///
/// For every element index `i` in `0..N`, builds one `PackedF` whose lane `lane` holds
/// `data[2 * lane][i]` and stores it in `dest[offset + i]`.
/// Useful for packing left children from interleaved `[L0, R0, L1, R1, ...]` pairs.
///
/// # Arguments
/// * `dest` - Destination slice; `dest[offset..offset + N]` is overwritten
/// * `offset` - Starting index in `dest`
/// * `data` - Source slice of interleaved pairs (must have length >= 2 * WIDTH)
///
/// # Panics
/// Panics if `dest.len() < offset + N` or `data.len() < 2 * PackedF::WIDTH`.
/// Both lengths are validated before the first write, so `dest` is never left
/// partially packed.
#[inline(always)]
pub fn pack_even_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
    // Slice once up front: a single range check per slice replaces one check per
    // access inside the loop and makes an undersized argument fail immediately.
    let dest = &mut dest[offset..offset + N];
    let data = &data[..2 * PackedF::WIDTH];

    for (i, slot) in dest.iter_mut().enumerate() {
        *slot = PackedF::from_fn(|lane| data[2 * lane][i]);
    }
}

/// Pack odd-indexed FieldArrays (stride 2) directly into destination.
///
/// For every element index `i` in `0..N`, builds one `PackedF` whose lane `lane` holds
/// `data[2 * lane + 1][i]` and stores it in `dest[offset + i]`.
/// Useful for packing right children from interleaved `[L0, R0, L1, R1, ...]` pairs.
///
/// # Arguments
/// * `dest` - Destination slice; `dest[offset..offset + N]` is overwritten
/// * `offset` - Starting index in `dest`
/// * `data` - Source slice of interleaved pairs (must have length >= 2 * WIDTH)
///
/// # Panics
/// Panics if `dest.len() < offset + N` or `data.len() < 2 * PackedF::WIDTH`.
/// Both lengths are validated before the first write, so `dest` is never left
/// partially packed.
#[inline(always)]
pub fn pack_odd_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
    // Slice once up front: a single range check per slice replaces one check per
    // access inside the loop and makes an undersized argument fail immediately.
    let dest = &mut dest[offset..offset + N];
    let data = &data[..2 * PackedF::WIDTH];

    for (i, slot) in dest.iter_mut().enumerate() {
        *slot = PackedF::from_fn(|lane| data[2 * lane + 1][i]);
    }
}

/// Pack values generated by a function directly into destination.
///
/// For each element index `i` in `0..N`, generates a PackedF by calling
/// `f(i, lane)` for each SIMD lane and stores it in `dest[offset + i]`.
///
/// # Arguments
/// * `dest` - Destination slice; `dest[offset..offset + N]` is overwritten
/// * `offset` - Starting index in `dest`
/// * `f` - Function that takes (element_index, lane_index) and returns a field element
///
/// # Panics
/// Panics if `dest.len() < offset + N`. The length is validated before `f` is
/// called for the first time, so an undersized `dest` is never partially written.
#[inline(always)]
pub fn pack_fn_into<const N: usize>(
    dest: &mut [PackedF],
    offset: usize,
    f: impl Fn(usize, usize) -> F,
) {
    // Slice once up front so the range is checked a single time instead of on
    // every store, and so no generator work is wasted on a destination that is
    // too short.
    let dest = &mut dest[offset..offset + N];

    for (i, slot) in dest.iter_mut().enumerate() {
        *slot = PackedF::from_fn(|lane| f(i, lane));
    }
}

#[cfg(test)]
mod tests {
use crate::F;
Expand Down Expand Up @@ -210,5 +278,91 @@ mod tests {
// Verify they match
prop_assert_eq!(output1, output2);
}

#[test]
fn proptest_pack_into_matches_pack_array(
    seed in any::<u64>()
) {
    // Derive all randomness from the proptest-supplied seed so that a failing
    // case can be replayed; a thread-local RNG would make failures unreproducible.
    let mut rng = <rand::rngs::StdRng as rand::SeedableRng>::seed_from_u64(seed);

    // Generate random data
    let data: [FieldArray<7>; PackedF::WIDTH] = array::from_fn(|_| {
        FieldArray(array::from_fn(|_| rng.random()))
    });

    // Pack using pack_array
    let expected = pack_array(&data);

    // Pack using pack_into
    let mut dest = [PackedF::ZERO; 10];
    pack_into(&mut dest, 2, &data);

    // Verify they match at the offset
    for i in 0..7 {
        prop_assert_eq!(dest[2 + i], expected[i]);
    }

    // Slots outside `dest[2..9]` must not have been touched
    for i in [0, 1, 9] {
        prop_assert_eq!(dest[i], PackedF::ZERO);
    }
}

#[test]
fn proptest_pack_even_odd_into(
    seed in any::<u64>()
) {
    // Derive all randomness from the proptest-supplied seed so that a failing
    // case can be replayed; a thread-local RNG would make failures unreproducible.
    let mut rng = <rand::rngs::StdRng as rand::SeedableRng>::seed_from_u64(seed);

    // Generate interleaved pairs: [L0, R0, L1, R1, ...]
    let pairs: [FieldArray<5>; 2 * PackedF::WIDTH] = array::from_fn(|_| {
        FieldArray(array::from_fn(|_| rng.random()))
    });

    // Pack even (left children) and odd (right children)
    let mut dest = [PackedF::ZERO; 12];
    pack_even_into(&mut dest, 1, &pairs);
    pack_odd_into(&mut dest, 6, &pairs);

    // Verify even indices were packed correctly
    for i in 0..5 {
        for lane in 0..PackedF::WIDTH {
            prop_assert_eq!(
                dest[1 + i].as_slice()[lane],
                pairs[2 * lane][i],
                "Even packing mismatch at element {}, lane {}", i, lane
            );
        }
    }

    // Verify odd indices were packed correctly
    for i in 0..5 {
        for lane in 0..PackedF::WIDTH {
            prop_assert_eq!(
                dest[6 + i].as_slice()[lane],
                pairs[2 * lane + 1][i],
                "Odd packing mismatch at element {}, lane {}", i, lane
            );
        }
    }

    // Slots outside `dest[1..11]` must not have been touched
    for i in [0, 11] {
        prop_assert_eq!(dest[i], PackedF::ZERO);
    }
}

#[test]
fn proptest_pack_fn_into(
    _seed in any::<u64>()
) {
    // Single source of truth for the value expected at (element, lane);
    // used both to fill the buffer and to check it afterwards.
    let value_at = |elem: usize, lane: usize| F::from_u64((elem * 100 + lane) as u64);

    // Fill four packed elements starting at index 3 of an eight-slot buffer
    let mut packed = [PackedF::ZERO; 8];
    pack_fn_into::<4>(&mut packed, 3, value_at);

    // Every lane of every written element must hold the generated value
    for elem in 0..4 {
        let lanes = packed[3 + elem].as_slice();
        for lane in 0..PackedF::WIDTH {
            let want = value_at(elem, lane);
            prop_assert_eq!(
                lanes[lane],
                want,
                "pack_fn_into mismatch at element {}, lane {}", elem, lane
            );
        }
    }
}
}
}
124 changes: 63 additions & 61 deletions src/symmetric/tweak_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::TWEAK_SEPARATOR_FOR_TREE_HASH;
use crate::array::FieldArray;
use crate::poseidon2_16;
use crate::poseidon2_24;
use crate::simd_utils::{pack_array, unpack_array};
use crate::simd_utils::{pack_array, pack_even_into, pack_fn_into, pack_odd_into, unpack_array};
use crate::symmetric::prf::Pseudorandom;
use crate::symmetric::tweak_hash::chain;
use crate::{F, PackedF};
Expand Down Expand Up @@ -390,6 +390,11 @@ impl<
// Permutation for merging two inputs (width-24)
let perm = poseidon2_24();

// Offsets for assembling packed_input: [parameter | tweak | left | right]
let tweak_offset = PARAMETER_LEN;
let left_offset = PARAMETER_LEN + TWEAK_LEN;
let right_offset = PARAMETER_LEN + TWEAK_LEN + HASH_LEN;

// Process SIMD batches with in-place mutation
parents
.par_chunks_exact_mut(WIDTH)
Expand All @@ -398,34 +403,23 @@ impl<
.for_each(|(chunk_idx, (parents_chunk, children_chunk))| {
let parent_pos = (parent_start + chunk_idx * WIDTH) as u32;

// DIRECT PACKING
//
// Pack left children: children[0], children[2], children[4], ...
let packed_left: [PackedF; HASH_LEN] =
array::from_fn(|h| PackedF::from_fn(|lane| children_chunk[2 * lane].0[h]));

// Pack right children: children[1], children[3], children[5], ...
let packed_right: [PackedF; HASH_LEN] =
array::from_fn(|h| PackedF::from_fn(|lane| children_chunk[2 * lane + 1].0[h]));

// Pack tweaks directly (no intermediate scalar arrays)
let packed_tweak: [PackedF; TWEAK_LEN] = array::from_fn(|t_idx| {
PackedF::from_fn(|lane| {
Self::tree_tweak(level, parent_pos + lane as u32)
.to_field_elements::<TWEAK_LEN>()[t_idx]
})
});

// Assemble packed input: [parameter | tweak | left | right]
// Assemble packed input directly: [parameter | tweak | left | right]
let mut packed_input = [PackedF::ZERO; MERGE_COMPRESSION_WIDTH];

// Copy pre-packed parameter
packed_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);
packed_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN]
.copy_from_slice(&packed_tweak);
packed_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&packed_left);
packed_input[PARAMETER_LEN + TWEAK_LEN + HASH_LEN
..PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN]
.copy_from_slice(&packed_right);

// Pack tweaks directly into destination
pack_fn_into::<TWEAK_LEN>(&mut packed_input, tweak_offset, |t_idx, lane| {
Self::tree_tweak(level, parent_pos + lane as u32)
.to_field_elements::<TWEAK_LEN>()[t_idx]
});

// Pack left children (even indices) directly into destination
pack_even_into(&mut packed_input, left_offset, children_chunk);

// Pack right children (odd indices) directly into destination
pack_odd_into(&mut packed_input, right_offset, children_chunk);

// Compress all WIDTH parent pairs simultaneously
let packed_parents =
Expand Down Expand Up @@ -547,6 +541,10 @@ impl<
// Cache strategy: process one chain at a time to maximize locality.
// All epochs for that chain stay in registers across iterations.

// Offsets for chain compression: [parameter | tweak | current_value]
let chain_tweak_offset = PARAMETER_LEN;
let chain_value_offset = PARAMETER_LEN + TWEAK_LEN;

for (chain_index, packed_chain) in
packed_chains.iter_mut().enumerate().take(num_chains)
{
Expand All @@ -556,32 +554,25 @@ impl<
// Current position in the chain.
let pos = (step + 1) as u8;

// Generate tweaks for all epochs in this SIMD batch.
// Each lane gets a tweak specific to its epoch.
let packed_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| {
PackedF::from_fn(|lane| {
Self::chain_tweak(epoch_chunk[lane], chain_index as u8, pos)
.to_field_elements::<TWEAK_LEN>()[t_idx]
})
});

// Assemble the packed input for the hash function.
// Layout: [parameter | tweak | current_value]
let mut packed_input = [PackedF::ZERO; CHAIN_COMPRESSION_WIDTH];
let mut current_pos = 0;

// Copy parameter into the input buffer.
packed_input[current_pos..current_pos + PARAMETER_LEN]
.copy_from_slice(&packed_parameter);
current_pos += PARAMETER_LEN;
// Copy pre-packed parameter
packed_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);

// Copy tweak into the input buffer.
packed_input[current_pos..current_pos + TWEAK_LEN]
.copy_from_slice(&packed_tweak);
current_pos += TWEAK_LEN;
// Pack tweaks directly into destination
pack_fn_into::<TWEAK_LEN>(
&mut packed_input,
chain_tweak_offset,
|t_idx, lane| {
Self::chain_tweak(epoch_chunk[lane], chain_index as u8, pos)
.to_field_elements::<TWEAK_LEN>()[t_idx]
},
);

// Copy current chain value into the input buffer.
packed_input[current_pos..current_pos + HASH_LEN]
// Copy current chain value (already packed)
packed_input[chain_value_offset..chain_value_offset + HASH_LEN]
.copy_from_slice(packed_chain);

// Apply the hash function to advance the chain.
Expand All @@ -601,23 +592,34 @@ impl<
//
// This uses the sponge construction for variable-length input.

// Generate tree tweaks for all epochs.
// Level 0 indicates this is a bottom-layer leaf in the tree.
let packed_tree_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| {
PackedF::from_fn(|lane| {
// Assemble the sponge input.
// Layout: [parameter | tree_tweak | all_chain_ends]
let sponge_tweak_offset = PARAMETER_LEN;
let sponge_chains_offset = PARAMETER_LEN + TWEAK_LEN;
let sponge_input_len = PARAMETER_LEN + TWEAK_LEN + NUM_CHUNKS * HASH_LEN;

let mut packed_leaf_input = vec![PackedF::ZERO; sponge_input_len];

// Copy pre-packed parameter
packed_leaf_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);

// Pack tree tweaks directly (level 0 for bottom-layer leaves)
pack_fn_into::<TWEAK_LEN>(
&mut packed_leaf_input,
sponge_tweak_offset,
|t_idx, lane| {
Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::<TWEAK_LEN>()
[t_idx]
})
});
},
);

// Assemble the sponge input.
// Layout: [parameter | tree_tweak | all_chain_ends]
let packed_leaf_input: Vec<_> = packed_parameter
.iter()
.chain(packed_tree_tweak.iter())
.chain(packed_chains.iter().flatten())
.copied()
.collect();
// Copy all chain ends (already packed)
for (c_idx, chain) in packed_chains.iter().enumerate() {
packed_leaf_input
[sponge_chains_offset + c_idx * HASH_LEN
..sponge_chains_offset + (c_idx + 1) * HASH_LEN]
.copy_from_slice(chain);
}

// Apply the sponge hash to produce the leaf.
// This absorbs all chain ends and squeezes out the final hash.
Expand Down