Skip to content

Commit

Permalink
Use batcher for DZKPs. Enable tests of DZKP batching. (#1250)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson authored Sep 7, 2024
1 parent bc05dbd commit 8cc3fb2
Show file tree
Hide file tree
Showing 12 changed files with 409 additions and 535 deletions.
11 changes: 2 additions & 9 deletions ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
protocol::{
basics::{mul::semi_honest::multiplication_protocol, SecureMul},
context::{
dzkp_field::DZKPCompatibleField, dzkp_validator::Segment, Context, DZKPContext,
dzkp_field::DZKPCompatibleField, dzkp_validator::Segment, Context,
DZKPUpgradedMaliciousContext,
},
prss::SharedRandomness,
Expand Down Expand Up @@ -81,11 +81,10 @@ impl<'a, F: Field + DZKPCompatibleField<N>, const N: usize>
#[cfg(all(test, unit_test))]
mod test {
use crate::{
error::Error,
ff::boolean::Boolean,
protocol::{
basics::SecureMul,
context::{dzkp_validator::DZKPValidator, Context, DZKPContext, UpgradableContext},
context::{dzkp_validator::DZKPValidator, Context, UpgradableContext},
RecordId,
},
rand::{thread_rng, Rng},
Expand All @@ -109,15 +108,9 @@ mod test {
.await
.unwrap();

// batch contains elements
assert!(matches!(mctx.is_verified(), Err(Error::ContextUnsafe(_))));

// validate all elements in the batch
validator.validate().await.unwrap();

// batch is empty now
assert!(mctx.is_verified().is_ok());

result
})
.await;
Expand Down
23 changes: 9 additions & 14 deletions ipa-core/src/protocol/basics/reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::{

use embed_doc_image::embed_doc_image;
use futures::{FutureExt, TryFutureExt};
use ipa_step::{Step, StepNarrow};

use crate::{
error::Error,
Expand All @@ -14,10 +13,10 @@ use crate::{
protocol::{
boolean::step::TwoHundredFiftySixBitOpStep,
context::{
dzkp_validator::DZKPValidator, Context, DZKPUpgradedMaliciousContext,
DZKPUpgradedSemiHonestContext, UpgradedMaliciousContext, UpgradedSemiHonestContext,
Context, DZKPContext, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradedMaliciousContext, UpgradedSemiHonestContext,
},
Gate, RecordId,
RecordId,
},
secret_sharing::{
replicated::{
Expand Down Expand Up @@ -382,21 +381,17 @@ where
S::generic_reveal(v, ctx, record_id, excluded)
}

pub async fn validated_partial_reveal<'fut, V, S, STEP>(
validator: V,
step: &'fut STEP,
pub async fn validated_partial_reveal<'fut, C, S>(
ctx: C,
record_id: RecordId,
excluded: Role,
v: &'fut S,
) -> Result<Option<<S as Reveal<V::Context>>::Output>, Error>
) -> Result<Option<<S as Reveal<C>>::Output>, Error>
where
V: DZKPValidator + 'fut,
S: Reveal<V::Context> + Send + Sync + ?Sized,
STEP: Step + Send + Sync + 'static,
Gate: StepNarrow<STEP>,
C: DZKPContext + 'fut,
S: Reveal<C> + Send + Sync + ?Sized,
{
let ctx = validator.context().narrow(step);
validator.validate_record(record_id).await?;
ctx.validate_record(record_id).await?;
partial_reveal(ctx, record_id, excluded, v).await
}

Expand Down
145 changes: 117 additions & 28 deletions ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,25 @@ use crate::{
/// 3. Either:
/// a. Call `Batcher::validate_record` for each record.
/// b. Call `Batcher::into_single_batch` once.
///
/// The `Batcher` may panic if an attempt is made to continue using
/// a batch after it has already been validated.
pub(super) struct Batcher<'a, B> {
batches: VecDeque<BatchState<B>>,
/// Outstanding batches.
///
/// Normally, `batches` are validated in order off the front, but the `Option` is
/// necessary to support validation requests arriving out of order. There is
/// no memory overhead for the `Option` as long as there is a
/// [niche](https://rustc-dev-guide.rust-lang.org/appendix/glossary.html#niche)
/// somewhere in `BatchState<B>`.
batches: VecDeque<Option<BatchState<B>>>,

/// Absolute index of the first element of `batches`.
first_batch: usize,
records_per_batch: usize,
total_records: Option<usize>,

/// Used to initialize new batches.
batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
}

Expand All @@ -39,6 +53,26 @@ pub(super) struct BatchState<B> {
pending_records: BitVec,
}

trait ExpectBatch {
type Ok;

/// Specialized `Option::expect` for batch-related values.
///
/// Constructs an error message based on the supplied context.
fn expect_not_yet_validated(self, batch_index: usize) -> Self::Ok;
}

impl<T> ExpectBatch for Option<T> {
type Ok = T;

fn expect_not_yet_validated(self, batch_index: usize) -> T {
let Some(value) = self else {
panic!("Attempting to access batch {batch_index}, which has already been validated.");
};
value
}
}

// Helper for `Batcher::validate_record` and `Batcher::is_ready_for_validation`.
enum Ready<B> {
No(watch::Receiver<bool>),
Expand All @@ -65,14 +99,9 @@ impl<'a, B> Batcher<'a, B> {

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
let Some(batch_offset) = batch_index.checked_sub(self.first_batch) else {
panic!(
"Batches should be processed in order. Attempting to retrieve batch {batch_index}. \
The oldest active batch is batch {}.",
self.first_batch,
)
};
batch_offset
batch_index
.checked_sub(self.first_batch)
.expect_not_yet_validated(batch_index)
}

fn get_batch_by_offset(&mut self, batch_offset: usize) -> &mut BatchState<B> {
Expand All @@ -86,13 +115,17 @@ impl<'a, B> Batcher<'a, B> {
pending_count: 0,
pending_records: bitvec![0; self.records_per_batch],
};
self.batches.push_back(state);
self.batches.push_back(Some(state));
}
}

&mut self.batches[batch_offset]
self.batches[batch_offset]
.as_mut()
.expect_not_yet_validated(self.first_batch + batch_offset)
}

/// # Panics
/// If the requested batch has already been validated.
pub fn get_batch(&mut self, record_id: RecordId) -> &mut BatchState<B> {
self.get_batch_by_offset(self.batch_offset(record_id))
}
Expand All @@ -118,26 +151,29 @@ impl<'a, B> Batcher<'a, B> {
batch.pending_records.set(record_offset_in_batch, true);
batch.pending_count += 1;
if batch.pending_count == total_count {
// I am not sure if this is okay, or if we need to tolerate batch validation requests
// arriving out of order. (If we do, I think we would still want to actually fulfill
// the validations in order.)
assert_eq!(
batch_offset,
0,
"Batches should be processed in order. \
Batch {batch_index} is ready for validation, but the first batch is {first}.",
first = self.first_batch,
);
assert!(batch.pending_records[0..total_count].all());
tracing::info!("batch {batch_index} is ready for validation");
let batch = self.batches.pop_front().unwrap();
self.first_batch += 1;
let batch;
if batch_offset == 0 {
batch = self.batches.pop_front().unwrap();
self.first_batch += 1;
// Also remove any batches that completed out of order
while let Some(None) = self.batches.front() {
self.batches.pop_front();
self.first_batch += 1;
}
} else {
batch = self.batches[batch_offset].take();
}
let batch = batch.expect_not_yet_validated(self.first_batch + batch_offset);
Ok(Ready::Yes { batch_index, batch })
} else {
Ok(Ready::No(batch.validation_result.subscribe()))
}
}

/// # Panics
/// If the requested batch has already been validated.
pub fn validate_record<VF, Fut>(
&mut self,
record_id: RecordId,
Expand Down Expand Up @@ -196,19 +232,22 @@ impl<'a, B> Batcher<'a, B> {
///
/// # Panics
/// If the batcher contains more than one batch.
#[allow(dead_code)]
pub fn into_single_batch(mut self) -> B {
assert!(self.first_batch == 0);
assert!(self.batches.len() <= 1);
let batch_index = 0;
match self.batches.pop_back() {
Some(state) => state.batch,
Some(state) => {
state
.expect_not_yet_validated(self.first_batch + batch_index)
.batch
}
None => (self.batch_constructor)(0),
}
}

#[allow(dead_code)]
pub fn iter(&self) -> impl Iterator<Item = &BatchState<B>> {
self.batches.iter()
pub fn is_empty(&self) -> bool {
self.batches.len() == 0
}
}

Expand Down Expand Up @@ -264,6 +303,51 @@ mod tests {
};

assert_eq!(results.await.unwrap(), ((), (), (), ()));
assert!(batcher.lock().unwrap().is_empty());
}

#[tokio::test]
async fn validates_batches_out_of_order() {
// Test the case where the batches arrive for validation out of order. Under
// normal usage, this is unlikely, but has been observed to happen in e.g.
// test_malicious_convert_to_fp25519.
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
.lock()
.unwrap()
.get_batch(RecordId::from(i))
.batch
.push(i);
}

let fut0 = batcher
.lock()
.unwrap()
.validate_record(RecordId::from(2), |_i, _b| async { unreachable!() });
let fut1 = batcher
.lock()
.unwrap()
.validate_record(RecordId::from(3), |i, b| {
assert!(i == 1 && b.as_slice() == [2, 3]);
ready(Ok(()))
});
try_join(fut0, fut1).await.unwrap();
let fut2 = batcher
.lock()
.unwrap()
.validate_record(RecordId::from(0), |_i, _b| async { unreachable!() });
let fut3 = batcher
.lock()
.unwrap()
.validate_record(RecordId::from(1), |i, b| {
assert!(i == 0 && b.as_slice() == [0, 1]);
ready(Ok(()))
});
try_join(fut2, fut3).await.unwrap();

assert!(batcher.lock().unwrap().is_empty());
}

#[tokio::test]
Expand Down Expand Up @@ -311,6 +395,8 @@ mod tests {

assert!(matches!(fut3.await, Ok(())));
assert!(matches!(poll_immediate(&mut fut2).await, Some(Ok(()))));

assert!(batcher.lock().unwrap().is_empty());
}

#[tokio::test]
Expand Down Expand Up @@ -361,6 +447,8 @@ mod tests {

assert!(matches!(fut3.await, Ok(())));
assert!(matches!(poll_immediate(&mut fut2).await, Some(Ok(()))));

assert!(batcher.lock().unwrap().is_empty());
}

#[tokio::test]
Expand Down Expand Up @@ -390,6 +478,7 @@ mod tests {
};

assert_eq!(results.await.unwrap(), ((), (), ()));
assert!(batcher.lock().unwrap().is_empty());
}

#[tokio::test]
Expand Down
Loading

0 comments on commit 8cc3fb2

Please sign in to comment.