Skip to content

Commit

Permalink
feat: Add MPC functionality to slice a shared value
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Jan 20, 2025
1 parent db1e449 commit 595e33b
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 0 deletions.
34 changes: 34 additions & 0 deletions mpc-core/src/protocols/rep3/yao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,19 @@ pub fn decompose_arithmetic<F: PrimeField, N: Rep3Network>(
)
}

/// Slices a shared field element at given indices (msb, lsb), both included in the slice.
/// Only consideres bitsize bits.
/// Result is thus [lo, slice, hi], where slice has all bits from lsb to msb, lo all bits smaller than lsb, and hi all bits greater msb up to bitsize.
pub fn slice_arithmetic<F: PrimeField, N: Rep3Network>(
input: Rep3PrimeFieldShare<F>,
io_context: &mut IoContext<N>,
msb: usize,
lsb: usize,
bitsize: usize,
) -> IoResult<Vec<Rep3PrimeFieldShare<F>>> {
slice_arithmetic_many(&[input], io_context, msb, lsb, bitsize)
}

/// Divides a vector of field elements by a power of 2, rounding down.
pub fn field_int_div_power_2_many<F: PrimeField, N: Rep3Network>(
inputs: &[Rep3PrimeFieldShare<F>],
Expand Down Expand Up @@ -934,3 +947,24 @@ pub fn decompose_arithmetic_many<F: PrimeField, N: Rep3Network>(
(decompose_bit_size, total_bit_size_per_field)
)
}

/// Slices a vector of shared field elements at given indices (msb, lsb), both included in the slice.
/// Only consideres bitsize bits.
/// Result (per input) is thus [lo, slice, hi], where slice has all bits from lsb to msb, lo all bits smaller than lsb, and hi all bits greater msb up to bitsize.
pub fn slice_arithmetic_many<F: PrimeField, N: Rep3Network>(
inputs: &[Rep3PrimeFieldShare<F>],
io_context: &mut IoContext<N>,
msb: usize,
lsb: usize,
bitsize: usize,
) -> IoResult<Vec<Rep3PrimeFieldShare<F>>> {
let num_inputs = inputs.len();
let total_output_elements = 3 * num_inputs;
decompose_circuit_compose_blueprint!(
inputs,
io_context,
total_output_elements,
GarbledCircuits::slice_field_element_many::<_, F>,
(msb, lsb, bitsize)
)
}
86 changes: 86 additions & 0 deletions mpc-core/src/protocols/rep3/yao/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,15 @@ impl GarbledCircuits {
// compose chunk_bits again
// For the bin addition, our input is not of size F::ModulusBitSize, thus we can optimize a little bit

if field_wires.is_empty() {
return Ok(rand_wires.to_owned());
}

let mut added = Vec::with_capacity(input_bitlen);

let xs = field_wires;
let ys = rand_wires;

let (mut s, mut c) = Self::half_adder(g, &xs[0], &ys[0])?;
added.push(s);

Expand Down Expand Up @@ -793,6 +798,87 @@ impl GarbledCircuits {
Ok(BinaryBundle::new(results))
}

/// Slices a field element (represented as two bitdecompositions wires_a, wires_b which need to be added first) at given indices (msb, lsb), both included in the slice. For the bitcomposition, wires_c are used.
fn slice_field_element<G: FancyBinary, F: PrimeField>(
g: &mut G,
wires_a: &[G::Item],
wires_b: &[G::Item],
wires_c: &[G::Item],
msb: usize,
lsb: usize,
bitsize: usize,
) -> Result<Vec<G::Item>, G::Error> {
debug_assert_eq!(wires_a.len(), wires_b.len());
let input_bitlen = wires_a.len();
debug_assert_eq!(input_bitlen, F::MODULUS_BIT_SIZE as usize);
debug_assert!(input_bitlen >= bitsize);
debug_assert!(msb >= lsb);
debug_assert!(msb < bitsize);
debug_assert_eq!(wires_c.len(), input_bitlen * 3);

let input_bits = Self::adder_mod_p_with_output_size::<_, F>(g, wires_a, wires_b, bitsize)?;

let mut rands = wires_c.chunks(input_bitlen);

let lo = Self::compose_field_element::<_, F>(g, &input_bits[..lsb], rands.next().unwrap())?;
let slice =
Self::compose_field_element::<_, F>(g, &input_bits[lsb..=msb], rands.next().unwrap())?;

let hi = if msb == bitsize {
Self::compose_field_element::<_, F>(g, &[], rands.next().unwrap())?
} else {
Self::compose_field_element::<_, F>(
g,
&input_bits[msb + 1..bitsize],
rands.next().unwrap(),
)?
};

let mut results = lo;
results.extend(slice);
results.extend(hi);

Ok(results)
}

/// Slices a vector of field elements (represented as two bitdecompositions wires_a, wires_b which need to be added first) at given indices (msb, lsb), both included in the slice. For the bitcomposition, wires_c are used.
pub(crate) fn slice_field_element_many<G: FancyBinary, F: PrimeField>(
g: &mut G,
wires_a: &BinaryBundle<G::Item>,
wires_b: &BinaryBundle<G::Item>,
wires_c: &BinaryBundle<G::Item>,
msb: usize,
lsb: usize,
bitsize: usize,
) -> Result<BinaryBundle<G::Item>, G::Error> {
debug_assert_eq!(wires_a.size(), wires_b.size());
let input_size = wires_a.size();
let input_bitlen = F::MODULUS_BIT_SIZE as usize;
let num_inputs = input_size / input_bitlen;

let total_output_elements = 3 * num_inputs;

debug_assert_eq!(input_size % input_bitlen, 0);
debug_assert!(input_bitlen >= bitsize);
debug_assert!(msb >= lsb);
debug_assert!(msb < bitsize);
debug_assert_eq!(wires_c.size(), input_bitlen * total_output_elements);

let mut results = Vec::with_capacity(wires_c.size());

for (chunk_a, chunk_b, chunk_c) in izip!(
wires_a.wires().chunks(input_bitlen),
wires_b.wires().chunks(input_bitlen),
wires_c.wires().chunks(input_bitlen * 3)
) {
let sliced =
Self::slice_field_element::<_, F>(g, chunk_a, chunk_b, chunk_c, msb, lsb, bitsize)?;
results.extend(sliced);
}

Ok(BinaryBundle::new(results))
}

fn unsigned_ge<G: FancyBinary>(
g: &mut G,
a: &[G::Item],
Expand Down
61 changes: 61 additions & 0 deletions tests/tests/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,67 @@ mod field_share {
assert_eq!(is_result, should_result);
}

fn rep3_slice_shared_field_many_via_yao_inner(msb: usize, lsb: usize, bitsize: usize) {
const VEC_SIZE: usize = 10;

let test_network = Rep3TestNetwork::default();
let mut rng = thread_rng();
let x = (0..VEC_SIZE)
.map(|_| ark_bn254::Fr::rand(&mut rng))
.collect_vec();
let x_shares = rep3::share_field_elements(&x, &mut rng);

let mut should_result = Vec::with_capacity(VEC_SIZE * 3);
let big_mask = (BigUint::from(1u64) << bitsize) - BigUint::one();
let hi_mask = (BigUint::one() << (bitsize - msb)) - BigUint::one();
let lo_mask = (BigUint::one() << lsb) - BigUint::one();
let slice_mask = (BigUint::one() << ((msb - lsb) as u32 + 1)) - BigUint::one();
let msb_plus_one = msb as u32 + 1;

for x in x.into_iter() {
let mut x: BigUint = x.into();
x &= &big_mask;
let hi = (&x >> msb_plus_one) & &hi_mask;
let lo = &x & &lo_mask;
let slice = (&x >> lsb) & &slice_mask;
assert_eq!(x, &lo + (&slice << lsb) + (&hi << msb_plus_one));
should_result.push(ark_bn254::Fr::from(lo));
should_result.push(ark_bn254::Fr::from(slice));
should_result.push(ark_bn254::Fr::from(hi));
}

let (tx1, rx1) = mpsc::channel();
let (tx2, rx2) = mpsc::channel();
let (tx3, rx3) = mpsc::channel();

for (net, tx, x) in izip!(
test_network.get_party_networks().into_iter(),
[tx1, tx2, tx3],
x_shares.into_iter()
) {
thread::spawn(move || {
let mut rep3 = IoContext::init(net).unwrap();

let decomposed =
yao::slice_arithmetic_many(&x, &mut rep3, msb, lsb, bitsize).unwrap();
tx.send(decomposed)
});
}

let result1 = rx1.recv().unwrap();
let result2 = rx2.recv().unwrap();
let result3 = rx3.recv().unwrap();
let is_result = rep3::combine_field_elements(&result1, &result2, &result3);
assert_eq!(is_result, should_result);
}

#[test]
fn rep3_slice_shared_field_many_via_yao() {
rep3_slice_shared_field_many_via_yao_inner(253, 0, 254);
rep3_slice_shared_field_many_via_yao_inner(100, 10, 254);
rep3_slice_shared_field_many_via_yao_inner(100, 10, 110);
}

#[test]
fn rep3_batcher_odd_even_merge_sort_via_yao() {
const VEC_SIZE: usize = 10;
Expand Down

0 comments on commit 595e33b

Please sign in to comment.