Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[saffron] Add Diff type , methods, and tests #3006

Merged
merged 3 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions saffron/src/diff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use crate::utils::encode_for_domain;
use ark_ff::PrimeField;
use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain};
use rayon::prelude::*;
use thiserror::Error;
use tracing::instrument;

// sparse representation, keeping only the non-zero differences
#[derive(Clone, Debug, PartialEq)]
pub struct Diff<F: PrimeField> {
pub evaluation_diffs: Vec<Vec<(usize, F)>>,
}

#[derive(Debug, Error, Clone, PartialEq)]
pub enum DiffError {
#[error("Capacity Mismatch: maximum number of chunks is {max_number_chunks}, attempted to create {attempted}")]
CapacityMismatch {
max_number_chunks: usize,
attempted: usize,
},
}

impl<F: PrimeField> Diff<F> {
#[instrument(skip_all, level = "debug")]
pub fn create<D: EvaluationDomain<F>>(
domain: &D,
old: &[u8],
new: &[u8],
) -> Result<Diff<F>, DiffError> {
let old_elems: Vec<Vec<F>> = encode_for_domain(domain, old);
let mut new_elems: Vec<Vec<F>> = encode_for_domain(domain, new);
if old_elems.len() < new_elems.len() {
return Err(DiffError::CapacityMismatch {
max_number_chunks: old_elems.len(),
attempted: new_elems.len(),
});
}
if old_elems.len() > new_elems.len() {
let padding = vec![F::zero(); domain.size()];
new_elems.resize(old_elems.len(), padding);
}
Ok(Diff {
evaluation_diffs: new_elems
.par_iter()
.zip(old_elems)
.map(|(n, o)| {
n.iter()
.zip(o)
.enumerate()
.map(|(index, (a, b))| (index, *a - b))
.filter(|(_, x)| !x.is_zero())
.collect()
})
.collect(),
})
}

#[instrument(skip_all, level = "debug")]
pub fn as_evaluations(
&self,
domain: &Radix2EvaluationDomain<F>,
) -> Vec<Evaluations<F, Radix2EvaluationDomain<F>>> {
self.evaluation_diffs
.par_iter()
.map(|diff| {
let mut evals = vec![F::zero(); domain.size()];
diff.iter().for_each(|(j, val)| {
evals[*j] = *val;
});
Evaluations::from_vec_and_domain(evals, *domain)
})
.collect()
}
}

#[cfg(test)]
pub mod tests {
use super::*;
use crate::utils::{chunk_size_in_bytes, min_encoding_chunks, test_utils::UserData};
use ark_ff::Zero;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use mina_curves::pasta::Fp;
use once_cell::sync::Lazy;
use proptest::prelude::*;
use rand::Rng;

static DOMAIN: Lazy<Radix2EvaluationDomain<Fp>> =
Lazy::new(|| Radix2EvaluationDomain::new(1 << 16).unwrap());

pub fn randomize_data(threshold: f64, data: &[u8]) -> Vec<u8> {
let mut rng = rand::thread_rng();
data.iter()
.map(|b| {
let n = rng.gen::<f64>();
if n < threshold {
rng.gen::<u8>()
} else {
*b
}
})
.collect()
}

pub fn random_diff(UserData(xs): UserData) -> BoxedStrategy<(UserData, UserData)> {
let n_chunks = min_encoding_chunks(&*DOMAIN, &xs);
let max_byte_len = n_chunks * chunk_size_in_bytes(&*DOMAIN);
(0.0..=1.0, 0..=max_byte_len)
.prop_flat_map(move |(threshold, n)| {
let mut ys = randomize_data(threshold, &xs);
// NOTE: n could be less than xs.len(), in which case this is just truncation
ys.resize_with(n, rand::random);
Just((UserData(xs.clone()), UserData(ys)))
})
.boxed()
}

fn add(mut evals: Vec<Vec<Fp>>, diff: &Diff<Fp>) -> Vec<Vec<Fp>> {
evals
.par_iter_mut()
.zip(diff.evaluation_diffs.par_iter())
.for_each(|(eval_chunk, diff_chunk)| {
diff_chunk.iter().for_each(|(j, val)| {
eval_chunk[*j] += val;
});
});
evals.to_vec()
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]

fn test_allow_legal_updates((UserData(xs), UserData(ys)) in
(UserData::arbitrary().prop_flat_map(random_diff))
) {
let diff = Diff::<Fp>::create(&*DOMAIN, &xs, &ys);
prop_assert!(diff.is_ok());
let xs_elems = encode_for_domain(&*DOMAIN, &xs);
let ys_elems = {
let pad = vec![Fp::zero(); DOMAIN.size()];
let mut elems = encode_for_domain(&*DOMAIN, &ys);
elems.resize(xs_elems.len(), pad);
elems
};
let result = add(xs_elems.clone(), &diff.unwrap());
prop_assert_eq!(result, ys_elems);
}
}

// Check that we CAN'T construct a diff that requires more polynomial chunks than the original data
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_cannot_construct_bad_diff(
(threshold, (UserData(data), UserData(mut extra))) in (
0.0..1.0,
UserData::arbitrary().prop_flat_map(|UserData(d1)| {
UserData::arbitrary()
.prop_filter_map(
"length constraint", {
move |UserData(d2)| {
let combined = &[d1.as_slice(), d2.as_slice()].concat();
if min_encoding_chunks(&*DOMAIN, &d1) <
min_encoding_chunks(&*DOMAIN, combined) {
Some((UserData(d1.clone()), UserData(d2)))
} else {
None
}
}
}
)
})
)
) {
let mut ys = randomize_data(threshold, &data);
ys.append(&mut extra);
let diff = Diff::<Fp>::create(&*DOMAIN, &data, &ys);
prop_assert!(diff.is_err());
}
}
}
1 change: 1 addition & 0 deletions saffron/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod blob;
pub mod cli;
pub mod commitment;
pub mod diff;
pub mod env;
pub mod proof;
pub mod utils;
21 changes: 16 additions & 5 deletions saffron/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,19 @@ pub mod test_utils {
}
}

// returns the minimum number of polynomials required to encode the data
pub fn min_encoding_chunks<F: PrimeField, D: EvaluationDomain<F>>(domain: &D, xs: &[u8]) -> usize {
let m = F::MODULUS_BIT_SIZE as usize / 8;
let n = xs.len();
let num_field_elems = (n + m - 1) / m;
(num_field_elems + domain.size() - 1) / domain.size()
}

pub fn chunk_size_in_bytes<F: PrimeField, D: EvaluationDomain<F>>(domain: &D) -> usize {
let m = F::MODULUS_BIT_SIZE as usize / 8;
domain.size() * m
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -337,12 +350,10 @@ mod tests {
}
}

// The number of field elements required to encode the data, including the padding
fn padded_field_length(xs: &[u8]) -> usize {
let m = Fp::MODULUS_BIT_SIZE as usize / 8;
let n = xs.len();
let num_field_elems = (n + m - 1) / m;
let num_polys = (num_field_elems + DOMAIN.size() - 1) / DOMAIN.size();
DOMAIN.size() * num_polys
let n = min_encoding_chunks(&*DOMAIN, xs);
n * DOMAIN.size()
}

proptest! {
Expand Down
Loading