Skip to content

Commit

Permalink
feat: Add a garbled circuit based soritng algorithm (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 authored Nov 5, 2024
1 parent 3907a25 commit 7d38334
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 139 deletions.
1 change: 1 addition & 0 deletions mpc-core/src/protocols/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod arithmetic;
pub mod binary;
pub mod conversion;
mod detail;
pub mod gadgets;
pub mod id;
pub mod lut;
pub mod network;
Expand Down
5 changes: 5 additions & 0 deletions mpc-core/src/protocols/rep3/gadgets/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//! Gadgets
//!
//! This module contains some commonly used gadgets for the Rep3 protocol.
pub mod sort;
34 changes: 34 additions & 0 deletions mpc-core/src/protocols/rep3/gadgets/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//! Sort
//!
//! This module contains some oblivious sorting algorithms for the Rep3 protocol.
use crate::protocols::rep3::{
arithmetic::FieldShare,
network::{IoContext, Rep3Network},
yao::{self, circuits::GarbledCircuits},
IoResult,
};
use ark_ff::PrimeField;

/// Sorts the inputs using the Batcher's odd-even merge sort algorithm. Thereby, only the lowest `bitsize` bits are considered. The final results also only have bitsize bits each.
pub fn batcher_odd_even_merge_sort_yao<F: PrimeField, N: Rep3Network>(
inputs: &[FieldShare<F>],
io_context: &mut IoContext<N>,
bitsize: usize,
) -> IoResult<Vec<FieldShare<F>>> {
if bitsize > F::MODULUS_BIT_SIZE as usize {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Bit size is larger than field size",
))?;
}
let num_inputs = inputs.len();

yao::decompose_circuit_compose_blueprint!(
inputs,
io_context,
num_inputs,
GarbledCircuits::batcher_odd_even_merge_sort::<_, F>,
(bitsize)
)
}
221 changes: 109 additions & 112 deletions mpc-core/src/protocols/rep3/yao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ use super::{
use crate::protocols::rep3::id::PartyID;
use ark_ff::{PrimeField, Zero};
use circuits::GarbledCircuits;
use evaluator::Rep3Evaluator;
use fancy_garbling::{hash_wires, util::tweak2, BinaryBundle, WireLabel, WireMod2};
use garbler::Rep3Garbler;
use itertools::{izip, Itertools};
use num_bigint::BigUint;
use rand::{CryptoRng, Rng};
Expand Down Expand Up @@ -442,7 +440,7 @@ pub fn joint_input_arithmetic_added<F: PrimeField, N: Rep3Network>(
}

/// Transforms a vector of arithmetically shared inputs x = (x_1, x_2, x_3) into two yao shares x_1^Y, (x_2 + x_3)^Y. The used delta is an input to the function to allow for the same delta to be used for multiple conversions.
fn joint_input_arithmetic_added_many<F: PrimeField, N: Rep3Network>(
pub fn joint_input_arithmetic_added_many<F: PrimeField, N: Rep3Network>(
x: &[Rep3PrimeFieldShare<F>],
delta: Option<WireMod2>,
io_context: &mut IoContext<N>,
Expand Down Expand Up @@ -592,7 +590,7 @@ pub fn joint_input_binary_xored<F: PrimeField, N: Rep3Network>(
}

/// Lets the party with id2 input a vector field elements, which gets shared as Yao wires to the other parties.
fn input_field_id2_many<F: PrimeField, N: Rep3Network>(
pub fn input_field_id2_many<F: PrimeField, N: Rep3Network>(
x: Option<Vec<F>>,
delta: Option<WireMod2>,
n_inputs: usize,
Expand Down Expand Up @@ -677,126 +675,125 @@ pub fn decompose_arithmetic<F: PrimeField, N: Rep3Network>(
)
}

// TODO implement with streaming Garbler/Evaluator as well
// TODO implement with a2b/b2a as well
macro_rules! decompose_circuit_compose_blueprint {
($inputs:expr, $io_context:expr, $output_size:expr, $circuit:expr, ($( $args:expr ),*)) => {{
use $crate::protocols::rep3::id::PartyID;
use itertools::izip;
use $crate::protocols::rep3::yao;
use $crate::protocols::rep3::Rep3PrimeFieldShare;

/// Decomposes a vector of shared field element into chunks, which are also represented as shared field elements. Per field element, the total bit size of the shared chunks is given by total_bit_size_per_field, whereas each chunk has at most (i.e, the last chunk can be smaller) decompose_bit_size bits.
pub fn decompose_arithmetic_many<F: PrimeField, N: Rep3Network>(
inputs: &[Rep3PrimeFieldShare<F>],
io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> IoResult<Vec<Rep3PrimeFieldShare<F>>> {
let num_inputs = inputs.len();
let num_decomps_per_field = total_bit_size_per_field.div_ceil(decompose_bit_size);
let total_output_elements = num_decomps_per_field * num_inputs;
let delta = $io_context
.rngs
.generate_random_garbler_delta($io_context.id);

let delta = io_context.rngs.generate_random_garbler_delta(io_context.id);
let [x01, x2] = yao::joint_input_arithmetic_added_many($inputs, delta, $io_context)?;

let [x01, x2] = joint_input_arithmetic_added_many(inputs, delta, io_context)?;
let mut res = vec![Rep3PrimeFieldShare::zero_share(); $output_size];

let mut res = vec![Rep3PrimeFieldShare::zero_share(); total_output_elements];
match $io_context.id {
PartyID::ID0 => {
for res in res.iter_mut() {
let k3 = $io_context.rngs.bitcomp2.random_fes_3keys::<F>();
res.b = (k3.0 + k3.1 + k3.2).neg();
}

match io_context.id {
PartyID::ID0 => {
for res in res.iter_mut() {
let k3 = io_context.rngs.bitcomp2.random_fes_3keys::<F>();
res.b = (k3.0 + k3.1 + k3.2).neg();
}
// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = yao::input_field_id2_many::<F, _>(None, None, $output_size, $io_context)?;

// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = input_field_id2_many::<F, _>(None, None, total_output_elements, io_context)?;

let mut evaluator = Rep3Evaluator::new(io_context);
evaluator.receive_circuit()?;

let x1 = GarbledCircuits::decompose_field_element_many::<_, F>(
&mut evaluator,
&x01,
&x2,
&x23,
decompose_bit_size,
total_bit_size_per_field,
);
let x1 = GCUtils::garbled_circuits_error(x1)?;
let x1 = evaluator.output_to_id0_and_id1(x1.wires())?;

// Compose the bits
for (res, x1) in izip!(res.iter_mut(), x1.chunks(F::MODULUS_BIT_SIZE as usize)) {
res.a = GCUtils::bits_to_field(x1)?;
}
}
PartyID::ID1 => {
for res in res.iter_mut() {
let k2 = io_context.rngs.bitcomp1.random_fes_3keys::<F>();
res.a = (k2.0 + k2.1 + k2.2).neg();
}
let mut evaluator = yao::evaluator::Rep3Evaluator::new($io_context);
evaluator.receive_circuit()?;

// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = input_field_id2_many::<F, _>(None, None, total_output_elements, io_context)?;

let mut garbler =
Rep3Garbler::new_with_delta(io_context, delta.expect("Delta not provided"));

let x1 = GarbledCircuits::decompose_field_element_many::<_, F>(
&mut garbler,
&x01,
&x2,
&x23,
decompose_bit_size,
total_bit_size_per_field,
);
let x1 = GCUtils::garbled_circuits_error(x1)?;
let x1 = garbler.output_to_id0_and_id1(x1.wires())?;
let x1 = match x1 {
Some(x1) => x1,
None => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output received",
))?,
};
let x1 = $circuit(&mut evaluator, &x01, &x2, &x23, $($args),*);
let x1 = yao::GCUtils::garbled_circuits_error(x1)?;
let x1 = evaluator.output_to_id0_and_id1(x1.wires())?;

// Compose the bits
for (res, x1) in izip!(res.iter_mut(), x1.chunks(F::MODULUS_BIT_SIZE as usize)) {
res.b = GCUtils::bits_to_field(x1)?;
// Compose the bits
for (res, x1) in izip!(res.iter_mut(), x1.chunks(F::MODULUS_BIT_SIZE as usize)) {
res.a = yao::GCUtils::bits_to_field(x1)?;
}
}
}
PartyID::ID2 => {
let mut x23 = Vec::with_capacity(total_output_elements);
for res in res.iter_mut() {
let k2 = io_context.rngs.bitcomp1.random_fes_3keys::<F>();
let k3 = io_context.rngs.bitcomp2.random_fes_3keys::<F>();
let k2_comp = k2.0 + k2.1 + k2.2;
let k3_comp = k3.0 + k3.1 + k3.2;
x23.push(k2_comp + k3_comp);
res.a = k3_comp.neg();
res.b = k2_comp.neg();
PartyID::ID1 => {
for res in res.iter_mut() {
let k2 = $io_context.rngs.bitcomp1.random_fes_3keys::<F>();
res.a = (k2.0 + k2.1 + k2.2).neg();
}

// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = yao::input_field_id2_many::<F, _>(None, None, $output_size, $io_context)?;

let mut garbler =
yao::garbler::Rep3Garbler::new_with_delta($io_context, delta.expect("Delta not provided"));

let x1 = $circuit(&mut garbler, &x01, &x2, &x23, $($args),*);
let x1 = yao::GCUtils::garbled_circuits_error(x1)?;
let x1 = garbler.output_to_id0_and_id1(x1.wires())?;
let x1 = match x1 {
Some(x1) => x1,
None => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No output received",
))?,
};

// Compose the bits
for (res, x1) in izip!(res.iter_mut(), x1.chunks(F::MODULUS_BIT_SIZE as usize)) {
res.b = yao::GCUtils::bits_to_field(x1)?;
}
}
PartyID::ID2 => {
let mut x23 = Vec::with_capacity($output_size);
for res in res.iter_mut() {
let k2 = $io_context.rngs.bitcomp1.random_fes_3keys::<F>();
let k3 = $io_context.rngs.bitcomp2.random_fes_3keys::<F>();
let k2_comp = k2.0 + k2.1 + k2.2;
let k3_comp = k3.0 + k3.1 + k3.2;
x23.push(k2_comp + k3_comp);
res.a = k3_comp.neg();
res.b = k2_comp.neg();
}

// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = yao::input_field_id2_many(Some(x23), delta, $output_size, $io_context)?;

// TODO this can be parallelized with joint_input_arithmetic_added_many
let x23 = input_field_id2_many(Some(x23), delta, total_output_elements, io_context)?;

let mut garbler =
Rep3Garbler::new_with_delta(io_context, delta.expect("Delta not provided"));

let x1 = GarbledCircuits::decompose_field_element_many::<_, F>(
&mut garbler,
&x01,
&x2,
&x23,
decompose_bit_size,
total_bit_size_per_field,
);
let x1 = GCUtils::garbled_circuits_error(x1)?;
let x1 = garbler.output_to_id0_and_id1(x1.wires())?;
if x1.is_some() {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unexpected output received",
))?;
let mut garbler =
yao::garbler::Rep3Garbler::new_with_delta($io_context, delta.expect("Delta not provided"));

let x1 = $circuit(&mut garbler, &x01, &x2, &x23, $($args),*);
let x1 = yao::GCUtils::garbled_circuits_error(x1)?;
let x1 = garbler.output_to_id0_and_id1(x1.wires())?;
if x1.is_some() {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unexpected output received",
))?;
}
}
}
}

Ok(res)
Ok(res)
}};
}
pub(crate) use decompose_circuit_compose_blueprint;

// TODO implement with streaming Garbler/Evaluator as well
// TODO implement with a2b/b2a as well

/// Decomposes a vector of shared field element into chunks, which are also represented as shared field elements. Per field element, the total bit size of the shared chunks is given by total_bit_size_per_field, whereas each chunk has at most (i.e, the last chunk can be smaller) decompose_bit_size bits.
pub fn decompose_arithmetic_many<F: PrimeField, N: Rep3Network>(
inputs: &[Rep3PrimeFieldShare<F>],
io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> IoResult<Vec<Rep3PrimeFieldShare<F>>> {
let num_inputs = inputs.len();
let num_decomps_per_field = total_bit_size_per_field.div_ceil(decompose_bit_size);
let total_output_elements = num_decomps_per_field * num_inputs;

decompose_circuit_compose_blueprint!(
inputs,
io_context,
total_output_elements,
GarbledCircuits::decompose_field_element_many::<_, F>,
(decompose_bit_size, total_bit_size_per_field)
)
}
Loading

0 comments on commit 7d38334

Please sign in to comment.