Proof batching for breakdown reveal aggregation (#1323)
* Unlimited batch size for reveal aggregation

But avoid excessive memory allocations

* Batching for breakdown reveal aggregation

* Reduce TARGET_PROOF_SIZE for tests

* Add a test for growing `pending_records`

* More large batch tests

* Compact gate fixes

* Fixing step count blow-up

* More test fixes

* Fix a bug and adjust the test to catch it.

* Keep semi-honest for shuttle, don't do shuttle for malicious

* Optimize vec_chunks

---------

Co-authored-by: Alex Koshelev <koshelev@meta.com>
andyleiserson and akoshelev authored Oct 11, 2024
1 parent f7adc86 commit 41b057c
Showing 15 changed files with 434 additions and 119 deletions.
48 changes: 30 additions & 18 deletions ipa-core/src/protocol/basics/reveal.rs
@@ -1,3 +1,27 @@
// Several of the reveal impls use distinct type parameters for the value being revealed
// and the context-associated field.
//
// For MAC, this takes the form of distinct `V` and `CtxF` type parameters. For DZKP,
// this takes the form of a `V` type parameter different from the implicit `Boolean`
// used by the context.
//
// This decoupling is needed to support:
//
// 1. The PRF evaluation protocol, which uses `Fp25519` for the malicious context, but
// needs to reveal `RP25519` values.
// 2. The breakdown reveal aggregation protocol, which uses `Boolean` for the malicious
// context, but needs to reveal `BK` values.
//
// The malicious reveal protocol must check the shares being revealed for consistency,
// but doesn't care that they are in the same field as is used for the malicious
// context. Contrast with multiplication, which can only be supported in the malicious
// context's field.
//
// It also isn't necessary for `V` and `CtxF` to support the same vectorization
// dimension `N`, but the compiler would not be able to infer the value of a
// decoupled vectorization dimension for `CtxF` from context, so it's easier to
// make them the same absent a need for them to be different.

use std::{
future::Future,
iter::{repeat, zip},
@@ -8,7 +32,6 @@ use futures::{FutureExt, TryFutureExt};

use crate::{
error::Error,
ff::boolean::Boolean,
helpers::{Direction, MaybeFuture, Role},
protocol::{
boolean::step::TwoHundredFiftySixBitOpStep,
@@ -170,8 +193,6 @@ where
}
}

// Like the impl for `UpgradedMaliciousContext`, this impl uses distinct `V` and `CtxF` type
// parameters. See the comment on that impl for more details.
impl<'a, B, V, CtxF, const N: usize> Reveal<UpgradedSemiHonestContext<'a, B, CtxF>>
for Replicated<V, N>
where
@@ -194,12 +215,12 @@ where
}
}

impl<'a, B, const N: usize> Reveal<DZKPUpgradedSemiHonestContext<'a, B>> for Replicated<Boolean, N>
impl<'a, V, B, const N: usize> Reveal<DZKPUpgradedSemiHonestContext<'a, B>> for Replicated<V, N>
where
B: ShardBinding,
Boolean: Vectorizable<N>,
V: SharedValue + Vectorizable<N>,
{
type Output = <Boolean as Vectorizable<N>>::Array;
type Output = <V as Vectorizable<N>>::Array;

async fn generic_reveal<'fut>(
&'fut self,
@@ -270,15 +291,6 @@ where
}
}

// This impl uses distinct `V` and `CtxF` type parameters to support the PRF evaluation protocol,
// which uses `Fp25519` for the malicious context, but needs to reveal `RP25519` values. The
// malicious reveal protocol must check the shares being revealed for consistency, but doesn't care
// that they are in the same field as is used for the malicious context. Contrast with
// multiplication, which can only be supported in the malicious context's field.
//
// It also doesn't matter that `V` and `CtxF` support the same vectorization dimension `N`, but the
// compiler would not be able to infer the value of a decoupled vectorization dimension for `CtxF`
// from context, so it's easier to make them the same absent a need for them to be different.
impl<'a, V, const N: usize, CtxF> Reveal<UpgradedMaliciousContext<'a, CtxF>> for Replicated<V, N>
where
CtxF: ExtendableField,
@@ -321,12 +333,12 @@ where
}
}

impl<'a, B, const N: usize> Reveal<DZKPUpgradedMaliciousContext<'a, B>> for Replicated<Boolean, N>
impl<'a, V, B, const N: usize> Reveal<DZKPUpgradedMaliciousContext<'a, B>> for Replicated<V, N>
where
B: ShardBinding,
Boolean: Vectorizable<N>,
V: SharedValue + Vectorizable<N>,
{
type Output = <Boolean as Vectorizable<N>>::Array;
type Output = <V as Vectorizable<N>>::Array;

async fn generic_reveal<'fut>(
&'fut self,
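
A minimal, self-contained sketch of the type-parameter decoupling described in the new header comment above. All names here (`Vectorizable`, `Replicated`, `Context`, `reveal`) are illustrative stand-ins rather than the crate's real definitions; the sketch only shows why a single shared dimension `N` keeps inference workable.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins for the crate's traits and share types.
trait Vectorizable<const N: usize> {
    type Array: Default;
}

struct Replicated<V, const N: usize>(PhantomData<V>);
struct Context<CtxF>(PhantomData<CtxF>);

// `V` (the value being revealed) and `CtxF` (the context's field) are
// decoupled, but both use the single dimension `N`, which the compiler
// infers from the share argument alone.
fn reveal<V, CtxF, const N: usize>(
    _ctx: &Context<CtxF>,
    _share: &Replicated<V, N>,
) -> <V as Vectorizable<N>>::Array
where
    V: Vectorizable<N>,
    CtxF: Vectorizable<N>,
{
    // If `CtxF` instead had an independent bound `CtxF: Vectorizable<M>`,
    // nothing at the call site would determine `M`, and inference would fail.
    Default::default()
}
```
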
78 changes: 71 additions & 7 deletions ipa-core/src/protocol/context/batcher.rs
@@ -3,7 +3,12 @@ use std::{cmp::min, collections::VecDeque, future::Future};
use bitvec::{bitvec, prelude::BitVec};
use tokio::sync::watch;

use crate::{error::Error, helpers::TotalRecords, protocol::RecordId, sync::Mutex};
use crate::{
error::Error,
helpers::TotalRecords,
protocol::{context::dzkp_validator::TARGET_PROOF_SIZE, RecordId},
sync::Mutex,
};

/// Manages validation of batches of records for malicious protocols.
///
@@ -111,13 +116,14 @@ impl<'a, B> Batcher<'a, B> {
fn get_batch_by_offset(&mut self, batch_offset: usize) -> &mut BatchState<B> {
if self.batches.len() <= batch_offset {
self.batches.reserve(batch_offset - self.batches.len() + 1);
let pending_records_capacity = self.records_per_batch.min(TARGET_PROOF_SIZE);
while self.batches.len() <= batch_offset {
let (validation_result, _) = watch::channel::<bool>(false);
let state = BatchState {
batch: (self.batch_constructor)(self.first_batch + self.batches.len()),
validation_result,
pending_count: 0,
pending_records: bitvec![0; self.records_per_batch],
pending_records: bitvec![0; pending_records_capacity],
};
self.batches.push_back(Some(state));
}
@@ -152,10 +158,16 @@ impl<'a, B> Batcher<'a, B> {
let total_count = min(self.records_per_batch, remaining_records);
let record_offset_in_batch = usize::from(record_id) - first_record_in_batch;
let batch = self.get_batch_by_offset(batch_offset);
assert!(
!batch.pending_records[record_offset_in_batch],
"validate_record called twice for record {record_id}",
);
if batch.pending_records.len() <= record_offset_in_batch {
batch
.pending_records
.resize(record_offset_in_batch + 1, false);
} else {
assert!(
!batch.pending_records[record_offset_in_batch],
"validate_record called twice for record {record_id}",
);
}
// This assertion is stricter than the bounds check in `BitVec::set` when the
// batch size is not a multiple of 8, or for a partial final batch.
assert!(
@@ -273,7 +285,10 @@ impl<'a, B> Batcher<'a, B> {
mod tests {
use std::{future::ready, pin::pin};

use futures::future::{poll_immediate, try_join, try_join3, try_join4};
use futures::{
future::{join_all, poll_immediate, try_join, try_join3, try_join4},
FutureExt,
};

use super::*;

@@ -553,6 +568,55 @@ mod tests {
));
}

#[tokio::test]
async fn large_batch() {
// This test exercises the case where the preallocated size of `pending_records`
// is limited to `TARGET_PROOF_SIZE` and we need to grow it later.
let batcher = Batcher::new(
TARGET_PROOF_SIZE + 1,
TotalRecords::specified(TARGET_PROOF_SIZE + 1).unwrap(),
Box::new(|_| Vec::new()),
);

let mut futs = (0..TARGET_PROOF_SIZE)
.map(|i| {
batcher
.lock()
.unwrap()
.get_batch(RecordId::from(i))
.batch
.push(i);
batcher
.lock()
.unwrap()
.validate_record(RecordId::from(i), |_i, _b| async { unreachable!() })
.map(Result::unwrap)
.boxed()
})
.collect::<Vec<_>>();

batcher
.lock()
.unwrap()
.get_batch(RecordId::from(TARGET_PROOF_SIZE))
.batch
.push(TARGET_PROOF_SIZE);
futs.push(
batcher
.lock()
.unwrap()
.validate_record(RecordId::from(TARGET_PROOF_SIZE), |i, b| {
assert!(i == 0 && b.as_slice() == (0..=TARGET_PROOF_SIZE).collect::<Vec<_>>());
ready(Ok(()))
})
.map(Result::unwrap)
.boxed(),
);
join_all(futs).await;

assert!(batcher.lock().unwrap().is_empty());
}

#[test]
fn into_single_batch() {
let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new()));
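
The allocation strategy adopted above (preallocate `pending_records` no larger than `TARGET_PROOF_SIZE`, then grow on demand when a record offset lands past the preallocation, the path the new `large_batch` test exercises) can be sketched in isolation. This is a sketch under assumed names (`PendingTracker`, `mark_pending`), not the `Batcher` itself:

```rust
use bitvec::{bitvec, prelude::BitVec};

const TARGET_PROOF_SIZE: usize = 8192; // the test-sized value from this commit

struct PendingTracker {
    pending: BitVec,
}

impl PendingTracker {
    // Cap the upfront allocation so a batch declared with a huge (or even
    // usize::MAX) records_per_batch doesn't allocate proportionally.
    fn new(records_per_batch: usize) -> Self {
        let capacity = records_per_batch.min(TARGET_PROOF_SIZE);
        Self {
            pending: bitvec![0; capacity],
        }
    }

    // Grow only when a record offset actually lands beyond the preallocation.
    fn mark_pending(&mut self, offset: usize) {
        if self.pending.len() <= offset {
            self.pending.resize(offset + 1, false);
        } else {
            assert!(!self.pending[offset], "record {offset} validated twice");
        }
        self.pending.set(offset, true);
    }
}
```

Growing implies the bit cannot already be set, which is why the double-validation assert is only needed on the non-growing path, as in the `validate_record` change above.
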
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -38,7 +38,7 @@ impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> {
base_ctx: MaliciousContext<'a, B>,
) -> Self {
let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
let active_work = if records_per_batch == 1 {
let active_work = if records_per_batch == 1 || records_per_batch == usize::MAX {
// If records_per_batch is 1, let active_work be anything. This only happens
// in tests; there shouldn't be a risk of deadlocks with one record per
// batch; and UnorderedReceiver capacity (which is set from active_work)
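
Only the changed condition is visible in this hunk; the sketch below is a hypothetical reconstruction of the decision it feeds, under assumed names. The new `usize::MAX` arm matters because a caller that passes `usize::MAX` as the batch size (as the `large_single_batch` test in the next file does) gets a single unbounded batch, which `active_work` could never be sized to cover:

```rust
use std::num::NonZeroUsize;

// Hypothetical sketch; everything except the condition itself is assumed,
// since the diff doesn't show the surrounding function.
fn choose_active_work(
    records_per_batch: usize,
    any_active_work: NonZeroUsize,
    cover_batch: impl Fn(usize) -> NonZeroUsize,
) -> NonZeroUsize {
    if records_per_batch == 1 || records_per_batch == usize::MAX {
        // One record per batch cannot deadlock on validation, and a single
        // unbounded batch cannot be covered by any finite active_work, so
        // any value is acceptable here.
        any_active_work
    } else {
        // Otherwise active_work must be large enough to cover a whole batch.
        cover_batch(records_per_batch)
    }
}
```
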
85 changes: 76 additions & 9 deletions ipa-core/src/protocol/context/dzkp_validator.rs
@@ -34,6 +34,22 @@ const BIT_ARRAY_LEN: usize = 256;
const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1;
const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize;

// The target size of a zero-knowledge proof, in GF(2) multiplies. Seven intermediate
// values are stored for each multiply, so the amount of memory required is 7 times this
// value.
//
// To enable computing a read size for `OrderingSender` that achieves good network
// utilization, the number of records in a proof must be a power of two. Protocols
// typically compute the size of a proof batch by dividing TARGET_PROOF_SIZE by
// an approximate number of multiplies per record, and then rounding up to a power
// of two. Thus, it is not necessary for TARGET_PROOF_SIZE to be a power of two.
//
// A smaller value is used for tests, to enable covering some corner cases with a
// reasonable runtime. Some of these tests use TARGET_PROOF_SIZE directly, so for tests
// it does need to be a power of two.
#[cfg(test)]
pub const TARGET_PROOF_SIZE: usize = 8192;
#[cfg(not(test))]
pub const TARGET_PROOF_SIZE: usize = 50_000_000;

/// `MultiplicationInputsBlock` is a block of fixed size of intermediate values
@@ -257,7 +273,7 @@ impl MultiplicationInputsBatch {
// records.
let capacity_bits = usize::min(
TARGET_PROOF_SIZE,
max_multiplications * multiplication_bit_size,
max_multiplications.saturating_mul(multiplication_bit_size),
);
Self {
first_record,
@@ -295,7 +311,7 @@ impl MultiplicationInputsBatch {
// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)),
usize::from(record_id) < self.max_multiplications + usize::from(self.first_record),
"record_id out of range in insert_segment. record {record_id} is beyond \
segment of length {} starting at {}",
self.max_multiplications,
@@ -326,9 +342,7 @@ impl MultiplicationInputsBatch {

// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record))
);
assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record));

// panics when record_id is less than first_record
let id_within_batch = usize::from(record_id) - usize::from(self.first_record);
@@ -377,9 +391,7 @@ impl MultiplicationInputsBatch {

// panics when record_id is out of bounds
assert!(record_id >= self.first_record);
assert!(
record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record))
);
assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record));

let id_within_batch = usize::from(record_id) - usize::from(self.first_record);
let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT;
@@ -866,7 +878,7 @@ mod tests {
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
Vectorizable,
},
seq_join::seq_join,
seq_join::{seq_join, SeqJoin},
sharding::NotSharded,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
};
@@ -1254,6 +1266,61 @@
}
}

#[tokio::test]
async fn large_batch() {
multi_select_malicious::<BA8>(2 * TARGET_PROOF_SIZE, 2 * TARGET_PROOF_SIZE).await;
}

// Similar to multi_select_malicious, but instead of using `validated_seq_join`, passes
// `usize::MAX` as the batch size and does a single `v.validate()`.
#[tokio::test]
async fn large_single_batch() {
let count: usize = TARGET_PROOF_SIZE + 1;
let mut rng = thread_rng();

let bit: Vec<Boolean> = repeat_with(|| rng.gen::<Boolean>()).take(count).collect();
let a: Vec<BA8> = repeat_with(|| rng.gen()).take(count).collect();
let b: Vec<BA8> = repeat_with(|| rng.gen()).take(count).collect();

let [ab0, ab1, ab2]: [Vec<Replicated<BA8>>; 3] = TestWorld::default()
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
let v = ctx
.set_total_records(count)
.dzkp_validator(TEST_DZKP_STEPS, usize::MAX);
let m_ctx = v.context();

let result = seq_join(
m_ctx.active_work(),
stream::iter(inputs).enumerate().map(
|(i, (bit_share, (a_share, b_share)))| {
let m_ctx = m_ctx.clone();
async move {
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
.await
}
},
),
)
.try_collect()
.await
.unwrap();

v.validate().await.unwrap();

result
},
)
.await;

let ab: Vec<BA8> = [ab0, ab1, ab2].reconstruct();

for i in 0..count {
assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] });
}
}

#[tokio::test]
#[should_panic(expected = "ContextUnsafe(\"DZKPMaliciousContext\")")]
async fn missing_validate() {
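
The sizing rule described by the new `TARGET_PROOF_SIZE` comment (divide the target by an estimated number of multiplies per record, then round up to a power of two) amounts to the following sketch; the function name and the per-record estimate are illustrative, not taken from the crate:

```rust
const TARGET_PROOF_SIZE: usize = 50_000_000;

// Aim for roughly TARGET_PROOF_SIZE GF(2) multiplies per proof, expressed as
// a record count rounded up to a power of two. TARGET_PROOF_SIZE itself need
// not be a power of two; only the resulting record count must be.
fn records_per_proof_batch(multiplies_per_record: usize) -> usize {
    (TARGET_PROOF_SIZE / multiplies_per_record)
        .max(1)
        .next_power_of_two()
}

#[test]
fn sizing() {
    // At ~2,000 multiplies per record, batches hold 32,768 records: a power
    // of two even though TARGET_PROOF_SIZE is not.
    assert_eq!(records_per_proof_batch(2_000), 32_768);
}
```

The switch from `*` to `saturating_mul` in `MultiplicationInputsBatch` guards the same regime: with `max_multiplications == usize::MAX`, a plain multiply would overflow before the `min` with `TARGET_PROOF_SIZE` could cap the allocation.
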
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/dp/mod.rs
@@ -171,7 +171,7 @@ where
let aggregation_input = Box::pin(stream::iter(vector_input_to_agg.into_iter()).map(Ok));
// Step 3: Call `aggregate_values` to sum up Bernoulli noise.
let noise_vector: Result<BitDecomposed<AdditiveShare<Boolean, { B }>>, Error> =
aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli).await;
aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli, None).await;
noise_vector
}
/// `apply_dp_noise` takes the noise distribution parameters (`num_bernoulli` and in the future `quantization_scale`)
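
The only change in this file threads one more argument through `aggregate_values`. Its declaration isn't shown in this diff, so the following is purely a hypothetical sketch of the shape such a parameter could take, namely an optional batching hint that this caller declines with `None`:

```rust
// Hypothetical signature sketch only; the real parameter's name and type are
// not visible in this diff.
async fn aggregate_values<OV, const B: usize>(
    // ...existing parameters elided...
    batch_hint: Option<usize>,
) {
    // With `None`, the implementation falls back to its own default sizing.
    let _records_per_batch = batch_hint.unwrap_or(usize::MAX);
}
```
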
(Diff truncated; the remaining 10 changed files are not shown.)
