Skip to content

Commit

Permalink
feat: Add tests for the REP3 ring implementation and fix minor bugs"
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Nov 25, 2024
1 parent e1c03f3 commit ae3408a
Show file tree
Hide file tree
Showing 16 changed files with 1,708 additions and 100 deletions.
6 changes: 5 additions & 1 deletion mpc-core/src/protocols/rep3/yao/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,16 @@ impl GarbledCircuits {
ys: &[G::Item],
) -> Result<Vec<G::Item>, G::Error> {
debug_assert_eq!(xs.len(), ys.len());
if xs.len() == 1 {
return Ok(vec![g.xor(&xs[0], &ys[0])?]);
}

let mut result = Vec::with_capacity(xs.len());

let (mut s, mut c) = Self::half_adder(g, &xs[0], &ys[0])?;
result.push(s);

for (x, y) in xs.iter().zip(ys.iter()).skip(1).take(xs.len() - 2) {
for (x, y) in xs.iter().zip(ys.iter()).take(xs.len() - 1).skip(1) {
let res = Self::full_adder(g, x, y, &c)?;
s = res.0;
c = res.1;
Expand Down
4 changes: 2 additions & 2 deletions mpc-core/src/protocols/rep3/yao/garbler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ use sha3::{Digest, Sha3_256};
/// This struct implements the garbler for replicated 3-party garbled circuits as described in [ABY3](https://eprint.iacr.org/2018/403.pdf).
pub struct Rep3Garbler<'a, N: Rep3Network> {
io_context: &'a mut IoContext<N>,
delta: WireMod2,
pub(crate) delta: WireMod2,
current_output: usize,
current_gate: usize,
rng: RngType,
pub(crate) rng: RngType,
hash: Sha3_256, // For the ID2 to match everything sent with one hash
circuit: Vec<[u8; 16]>,
}
Expand Down
4 changes: 2 additions & 2 deletions mpc-core/src/protocols/rep3/yao/streaming_garbler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ use sha3::{Digest, Sha3_256};
/// This struct implements the garbler for replicated 3-party garbled circuits as described in [ABY3](https://eprint.iacr.org/2018/403.pdf).
pub struct StreamingRep3Garbler<'a, N: Rep3Network> {
io_context: &'a mut IoContext<N>,
delta: WireMod2,
pub(crate) delta: WireMod2,
current_output: usize,
current_gate: usize,
rng: RngType,
pub(crate) rng: RngType,
hash: Sha3_256, // For the ID2 to match everything sent with one hash
}

Expand Down
102 changes: 96 additions & 6 deletions mpc-core/src/protocols/rep3_ring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,106 @@
//!
//! This module implements the rep3 share and combine operations for rings
use rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng};
use ring::{int_ring::IntRing2k, ring_impl::RingElement};

pub mod arithmetic;
pub mod binary;
pub mod conversion;
mod detail;
pub(crate) mod ring;

use ring::bit::Bit;
pub mod ring;
pub mod yao;

/// Shorthand type for a secret shared bit.
pub type Rep3BitShare = Rep3RingShare<Bit>;

pub type Rep3BitShare = Rep3RingShare<ring::bit::Bit>;
pub use arithmetic::types::Rep3RingShare;
pub mod yao;

/// Secret shares a ring element using replicated secret sharing and the provided random number generator. The ring element is split into three additive shares, where each party holds two. The outputs are of type [Rep3RingShare].
pub fn share_ring_element<T: IntRing2k, R: Rng + CryptoRng>(
val: RingElement<T>,
rng: &mut R,
) -> [Rep3RingShare<T>; 3]
where
Standard: Distribution<T>,
{
let a = rng.gen::<RingElement<T>>();
let b = rng.gen::<RingElement<T>>();
let c = val - a - b;
let share1 = Rep3RingShare::new_ring(a, c);
let share2 = Rep3RingShare::new_ring(b, a);
let share3 = Rep3RingShare::new_ring(c, b);
[share1, share2, share3]
}

/// Secret shares a vector of ring element using replicated secret sharing and the provided random number generator. The ring elements are split into three additive shares each, where each party holds two. The outputs are of type [Rep3RingShare].
pub fn share_ring_elements<T: IntRing2k, R: Rng + CryptoRng>(
vals: &[RingElement<T>],
rng: &mut R,
) -> [Vec<Rep3RingShare<T>>; 3]
where
Standard: Distribution<T>,
{
let mut shares1 = Vec::with_capacity(vals.len());
let mut shares2 = Vec::with_capacity(vals.len());
let mut shares3 = Vec::with_capacity(vals.len());
for val in vals {
let [share1, share2, share3] = share_ring_element(val.to_owned(), rng);
shares1.push(share1);
shares2.push(share2);
shares3.push(share3);
}
[shares1, shares2, shares3]
}

/// Secret shares a ring element using replicated secret sharing and the provided random number generator. The ring element is split into three binary shares, where each party holds two. The outputs are of type [Rep3RingShare].
pub fn share_ring_element_binary<T: IntRing2k, R: Rng + CryptoRng>(
val: RingElement<T>,
rng: &mut R,
) -> [Rep3RingShare<T>; 3]
where
Standard: Distribution<T>,
{
let a = rng.gen::<RingElement<T>>();
let b = rng.gen::<RingElement<T>>();
let c = val ^ a ^ b;
let share1 = Rep3RingShare::new_ring(a, c);
let share2 = Rep3RingShare::new_ring(b, a);
let share3 = Rep3RingShare::new_ring(c, b);
[share1, share2, share3]
}

//TODO RENAME ME TO COMBINE_ARITHMETIC_SHARE
/// Reconstructs a ring element from its arithmetic replicated shares.
pub fn combine_ring_element<T: IntRing2k>(
share1: Rep3RingShare<T>,
share2: Rep3RingShare<T>,
share3: Rep3RingShare<T>,
) -> RingElement<T> {
share1.a + share2.a + share3.a
}

/// Reconstructs a vector of ring elements from its arithmetic replicated shares.
/// # Panics
/// Panics if the provided `Vec` sizes do not match.
pub fn combine_ring_elements<T: IntRing2k>(
share1: Vec<Rep3RingShare<T>>,
share2: Vec<Rep3RingShare<T>>,
share3: Vec<Rep3RingShare<T>>,
) -> Vec<RingElement<T>> {
assert_eq!(share1.len(), share2.len());
assert_eq!(share2.len(), share3.len());

itertools::multizip((share1.into_iter(), share2.into_iter(), share3.into_iter()))
.map(|(x1, x2, x3)| x1.a + x2.a + x3.a)
.collect::<Vec<_>>()
}

//TODO RENAME ME TO COMBINE_BINARY_SHARE
/// Reconstructs a ring element from its binary replicated shares.
pub fn combine_ring_element_binary<T: IntRing2k>(
share1: Rep3RingShare<T>,
share2: Rep3RingShare<T>,
share3: Rep3RingShare<T>,
) -> RingElement<T> {
share1.a ^ share2.a ^ share3.a
}
72 changes: 24 additions & 48 deletions mpc-core/src/protocols/rep3_ring/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use types::Rep3RingShare;

use super::{
binary, conversion, detail,
ring::{int_ring::IntRing2k, ring_impl::RingElement},
ring::{bit::Bit, int_ring::IntRing2k, ring_impl::RingElement},
};

pub(super) mod ops;
Expand Down Expand Up @@ -353,7 +353,7 @@ pub fn lt<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
Expand All @@ -367,21 +367,21 @@ pub fn lt_public<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
// a < b is equivalent to !(a >= b)
let tmp = ge_public(lhs, rhs, io_context)?;
Ok(sub_public_by_shared(RingElement::one(), tmp, io_context.id))
Ok(!tmp)
}

/// Returns 1 if lhs <= rhs and 0 otherwise. Checks if one shared value is less than or equal to another shared value. The result is a shared value that has value 1 if the first shared value is less than or equal to the second shared value and 0 otherwise.
pub fn le<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
Expand All @@ -394,108 +394,84 @@ pub fn le_public<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
let res = detail::unsigned_ge_const_lhs(rhs, lhs, io_context)?;
conversion::bit_inject_from_bit(&res, io_context)
detail::unsigned_ge_const_lhs(rhs, lhs, io_context)
}

/// Returns 1 if lhs > rhs and 0 otherwise. Checks if one shared value is greater than another shared value. The result is a shared value that has value 1 if the first shared value is greater than the second shared value and 0 otherwise.
pub fn gt<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
// a > b is equivalent to !(a <= b)
let tmp = le(lhs, rhs, io_context)?;
Ok(sub_public_by_shared(RingElement::one(), tmp, io_context.id))
Ok(!tmp)
}

/// Returns 1 if lhs > rhs and 0 otherwise. Checks if a shared value is greater than the public value. The result is a shared value that has value 1 if the shared value is greater than the public value and 0 otherwise.
pub fn gt_public<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
// a > b is equivalent to !(a <= b)
let tmp = le_public(lhs, rhs, io_context)?;
Ok(sub_public_by_shared(RingElement::one(), tmp, io_context.id))
Ok(!tmp)
}

/// Returns 1 if lhs >= rhs and 0 otherwise. Checks if one shared value is greater than or equal to another shared value. The result is a shared value that has value 1 if the first shared value is greater than or equal to the second shared value and 0 otherwise.
pub fn ge<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
let res = detail::unsigned_ge(lhs, rhs, io_context)?;
conversion::bit_inject_from_bit(&res, io_context)
detail::unsigned_ge(lhs, rhs, io_context)
}

/// Returns 1 if lhs >= rhs and 0 otherwise. Checks if a shared value is greater than or equal to a public value. The result is a shared value that has value 1 if the shared value is greater than or equal to the public value and 0 otherwise.
pub fn ge_public<T: IntRing2k, N: Rep3Network>(
lhs: RingShare<T>,
rhs: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
where
Standard: Distribution<T>,
{
let res = detail::unsigned_ge_const_rhs(lhs, rhs, io_context)?;
conversion::bit_inject_from_bit(&res, io_context)
}

//TODO FN REMARK - I think we can skip the bit_inject.
//Circom has dedicated op codes for bool ops so we would know
//for bool_and/bool_or etc that we are a boolean value (and therefore
//bit len 1).
//
//We leave it like that and come back to that later. Maybe it doesn't matter...

/// Checks if two shared values are equal. The result is a shared value that has value 1 if the two shared values are equal and 0 otherwise.
pub fn eq<T: IntRing2k, N: Rep3Network>(
a: RingShare<T>,
b: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
let is_zero = eq_bit(a, b, io_context)?;
let res = conversion::bit_inject(&is_zero, io_context)?;
Ok(res)
detail::unsigned_ge_const_rhs(lhs, rhs, io_context)
}

/// Checks if a shared value is equal to a public value. The result is a shared value that has value 1 if the two values are equal and 0 otherwise.
pub fn eq_public<T: IntRing2k, N: Rep3Network>(
shared: RingShare<T>,
public: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
let public = promote_to_trivial_share(io_context.id, public);
eq(shared, public, io_context)
}

/// Same as eq but without using bit_inject on the result. Checks whether two shares are equal and return a binary share of 0 or 1. 1 means they are equal.
pub fn eq_bit<T: IntRing2k, N: Rep3Network>(
/// Checks if two shared values are equal. The result is a shared value that has value 1 if the two shared values are equal and 0 otherwise.
pub fn eq<T: IntRing2k, N: Rep3Network>(
a: RingShare<T>,
b: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
Expand All @@ -510,20 +486,20 @@ pub fn neq<T: IntRing2k, N: Rep3Network>(
a: RingShare<T>,
b: RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
let eq = eq(a, b, io_context)?;
Ok(sub_public_by_shared(RingElement::one(), eq, io_context.id))
Ok(!eq)
}

/// Checks if a shared value is not equal to a public value. The result is a shared value that has value 1 if the two values are not equal and 0 otherwise.
pub fn neq_public<T: IntRing2k, N: Rep3Network>(
shared: RingShare<T>,
public: RingElement<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
Expand All @@ -540,9 +516,9 @@ where
Standard: Distribution<T>,
{
let zero_share = RingShare::default();
let res = eq_bit(zero_share, a, io_context)?;
let res = eq(zero_share, a, io_context)?;
let x = open_bit(res, io_context)?;
Ok(x.is_one())
Ok(x.0.convert())
}

/// Computes `shared*2^public`. This is the same as `shared << public`.
Expand Down
9 changes: 6 additions & 3 deletions mpc-core/src/protocols/rep3_ring/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use super::{
arithmetic::RingShare,
ring::{int_ring::IntRing2k, ring_impl::RingElement},
ring::{bit::Bit, int_ring::IntRing2k, ring_impl::RingElement},
};
use crate::protocols::rep3::{
id::PartyID,
Expand Down Expand Up @@ -201,7 +201,7 @@ where
pub fn is_zero<T: IntRing2k, N: Rep3Network>(
x: &RingShare<T>,
io_context: &mut IoContext<N>,
) -> IoResult<RingShare<T>>
) -> IoResult<RingShare<Bit>>
where
Standard: Distribution<T>,
{
Expand All @@ -220,5 +220,8 @@ where
x = and(&(x & mask), &(y & mask), io_context)?;
}
// extract LSB
Ok(x & RingElement::one())
Ok(RingShare {
a: RingElement(Bit::new((x.a & RingElement::one()) == RingElement::one())),
b: RingElement(Bit::new((x.b & RingElement::one()) == RingElement::one())),
})
}
10 changes: 7 additions & 3 deletions mpc-core/src/protocols/rep3_ring/ring.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
pub(super) mod bit;
pub(super) mod int_ring;
pub(super) mod ring_impl;
//! Ring
//!
//! Contains traits and implementations for different rings Z_{2^k}
pub mod bit;
pub mod int_ring;
pub mod ring_impl;
Loading

0 comments on commit ae3408a

Please sign in to comment.