Skip to content

Commit

Permalink
Merge pull request #1251 from private-attribution/andy/malicious-aggregation
Browse files Browse the repository at this point in the history

Malicious aggregation and DP noising
  • Loading branch information
andyleiserson authored Sep 7, 2024
2 parents 8cc3fb2 + da67550 commit 9676fbe
Show file tree
Hide file tree
Showing 12 changed files with 277 additions and 122 deletions.
9 changes: 9 additions & 0 deletions ipa-core/src/protocol/basics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>>
{
}

// Marker impl: enables the boolean protocols for single Boolean shares under the
// malicious (DZKP-validated) context — presumably at the trait's default chunk
// width; confirm against the BooleanProtocols declaration.
impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>> for AdditiveShare<Boolean> {}

// Used for aggregation tests
impl<'a, B: ShardBinding> BooleanProtocols<UpgradedSemiHonestContext<'a, B, Boolean>, 8>
for AdditiveShare<Boolean, 8>
Expand All @@ -105,6 +107,11 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>,
{
}

// Marker impl: PRF_CHUNK-wide vectorized Boolean shares under the malicious
// (DZKP-validated) context.
impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>, PRF_CHUNK>
for AdditiveShare<Boolean, PRF_CHUNK>
{
}

impl<'a, B: ShardBinding> BooleanProtocols<UpgradedSemiHonestContext<'a, B, Boolean>, AGG_CHUNK>
for AdditiveShare<Boolean, AGG_CHUNK>
{
Expand Down Expand Up @@ -147,6 +154,8 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>,
{
}

// Marker impl: 32-wide vectorized Boolean shares under the malicious
// (DZKP-validated) context.
impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>, 32> for AdditiveShare<Boolean, 32> {}

const_assert_eq!(
AGG_CHUNK,
256,
Expand Down
53 changes: 36 additions & 17 deletions ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio::sync::watch;

use crate::{
error::Error,
helpers::TotalRecords,
protocol::RecordId,
sync::{Arc, Mutex},
};
Expand Down Expand Up @@ -34,7 +35,7 @@ pub(super) struct Batcher<'a, B> {
/// Absolute index of the first element of `batches`.
first_batch: usize,
records_per_batch: usize,
total_records: Option<usize>,
total_records: TotalRecords,

/// Used to initialize new batches.
batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
Expand Down Expand Up @@ -83,20 +84,24 @@ enum Ready<B> {
}

impl<'a, B> Batcher<'a, B> {
pub fn new(
pub fn new<T: Into<TotalRecords>>(
records_per_batch: usize,
total_records: Option<usize>,
total_records: T,
batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
) -> Arc<Mutex<Self>> {
Arc::new(Mutex::new(Self {
batches: VecDeque::new(),
first_batch: 0,
records_per_batch,
total_records,
total_records: total_records.into(),
batch_constructor,
}))
}

/// Replaces this batcher's total record count.
///
/// The incoming value is converted via `Into<TotalRecords>` and applied through
/// `TotalRecords::overwrite`, so whatever rules govern overwriting an existing
/// count are enforced by that method, not here.
pub fn set_total_records<T: Into<TotalRecords>>(&mut self, total_records: T) {
    let updated = self.total_records.overwrite(total_records.into());
    self.total_records = updated;
}

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
batch_index
Expand Down Expand Up @@ -131,7 +136,7 @@ impl<'a, B> Batcher<'a, B> {
}

fn is_ready_for_validation(&mut self, record_id: RecordId) -> Result<Ready<B>, Error> {
let Some(total_records) = self.total_records else {
let Some(total_records) = self.total_records.count() else {
return Err(Error::MissingTotalRecords(String::from("validate_record")));
};

Expand All @@ -148,10 +153,24 @@ impl<'a, B> Batcher<'a, B> {
let total_count = min(self.records_per_batch, remaining_records);
let record_offset_in_batch = usize::from(record_id) - first_record_in_batch;
let batch = self.get_batch_by_offset(batch_offset);
assert!(
!batch.pending_records[record_offset_in_batch],
"validate_record called twice for record {record_id}",
);
// This assertion is stricter than the bounds check in `BitVec::set` when the
// batch size is not a multiple of 8, or for a partial final batch.
assert!(
record_offset_in_batch < total_count,
"record offset {record_offset_in_batch} exceeds batch size {total_count}",
);
batch.pending_records.set(record_offset_in_batch, true);
batch.pending_count += 1;
if batch.pending_count == total_count {
assert!(batch.pending_records[0..total_count].all());
assert!(
batch.pending_records[0..total_count].all(),
"Expected batch of {total_count} records to be ready for validation, but only have {:?}.",
&batch.pending_records[0..total_count],
);
tracing::info!("batch {batch_index} is ready for validation");
let batch;
if batch_offset == 0 {
Expand Down Expand Up @@ -261,7 +280,7 @@ mod tests {

#[test]
fn makes_batches() {
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));
let mut batcher = batcher.lock().unwrap();

for i in 0..4 {
Expand All @@ -280,7 +299,7 @@ mod tests {

#[tokio::test]
async fn validates_batches() {
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));
let results = {
let mut batcher = batcher.lock().unwrap();

Expand Down Expand Up @@ -311,7 +330,7 @@ mod tests {
// Test the case where the batches arrive for validation out of order. Under
// normal usage, this is unlikely, but has been observed to happen in e.g.
// test_malicious_convert_to_fp25519.
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
Expand Down Expand Up @@ -352,7 +371,7 @@ mod tests {

#[tokio::test]
async fn validates_batches_async() {
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
Expand Down Expand Up @@ -401,7 +420,7 @@ mod tests {

#[tokio::test]
async fn validation_failure() {
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
Expand Down Expand Up @@ -453,7 +472,7 @@ mod tests {

#[tokio::test]
async fn handles_partial_final_batch() {
let batcher = Batcher::new(2, Some(3), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 3, Box::new(|_| Vec::new()));
let results = {
let mut batcher = batcher.lock().unwrap();

Expand Down Expand Up @@ -483,7 +502,7 @@ mod tests {

#[tokio::test]
async fn requires_total_records_in_validate_record() {
let batcher = Batcher::new(2, None, Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new()));
let result = {
let mut batcher = batcher.lock().unwrap();
batcher.get_batch(RecordId::FIRST).batch.push(0);
Expand All @@ -496,7 +515,7 @@ mod tests {

#[tokio::test]
async fn record_id_out_of_range() {
let batcher = Batcher::new(2, Some(1), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 1, Box::new(|_| Vec::new()));

for i in 0..2 {
batcher
Expand All @@ -520,7 +539,7 @@ mod tests {

#[test]
fn into_single_batch() {
let batcher = Batcher::new(2, None, Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new()));

for i in 0..2 {
batcher
Expand All @@ -538,7 +557,7 @@ mod tests {
#[test]
#[should_panic(expected = "assertion failed: self.batches.len() <= 1")]
fn into_single_batch_fails_with_multiple_batches() {
let batcher = Batcher::new(2, None, Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
Expand All @@ -556,7 +575,7 @@ mod tests {
#[tokio::test]
#[should_panic(expected = "assertion failed: self.first_batch == 0")]
async fn into_single_batch_fails_after_first_batch() {
let batcher = Batcher::new(2, Some(4), Box::new(|_| Vec::new()));
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));

for i in 0..4 {
batcher
Expand Down
21 changes: 20 additions & 1 deletion ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use futures::{stream, Future, FutureExt, Stream, StreamExt};
use crate::{
error::{BoxError, Error},
ff::{Fp61BitPrime, U128Conversions},
helpers::TotalRecords,
protocol::{
context::{
batcher::Batcher,
Expand Down Expand Up @@ -600,6 +601,11 @@ pub trait DZKPValidator: Send + Sync {

fn context(&self) -> Self::Context;

/// Sets the total number of records for this validator.
///
/// Required before using the `validate_record` API, if the total was not
/// already set on the context used to create the validator.
fn set_total_records<T: Into<TotalRecords>>(&mut self, total_records: T);

/// Validates all of the multiplies associated with this validator.
///
/// Only one of the `DZKPValidator::validate` or the `DZKPContext::validate_record`
Expand Down Expand Up @@ -677,6 +683,10 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> {
self.context.clone()
}

// No-op: the semi-honest validator performs no validation work, so it never
// consults the record count.
fn set_total_records<T: Into<TotalRecords>>(&mut self, _total_records: T) {
// Intentionally empty — semi-honest validation ignores the total.
}

async fn validate(self) -> Result<(), Error> {
Ok(())
}
Expand Down Expand Up @@ -705,6 +715,15 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> {
self.protocol_ctx.clone()
}

/// Forwards the new total record count to the underlying batcher.
///
/// Panics if `batcher_ref` is `None` (the batcher has already been taken —
/// presumably after validation consumed it; confirm against `validate`) or if
/// the batcher mutex is poisoned.
fn set_total_records<T: Into<TotalRecords>>(&mut self, total_records: T) {
    let batcher = self.batcher_ref.as_ref().unwrap();
    batcher.lock().unwrap().set_total_records(total_records);
}

async fn validate(mut self) -> Result<(), Error> {
let batcher_arc = self
.batcher_ref
Expand Down Expand Up @@ -744,7 +763,7 @@ impl<'a> MaliciousDZKPValidator<'a> {
pub fn new(ctx: MaliciousContext<'a>, max_multiplications_per_gate: usize) -> Self {
let batcher = Batcher::new(
max_multiplications_per_gate,
ctx.total_records().count(),
ctx.total_records(),
Box::new(move |batch_index| {
Batch::new(
RecordId::from(batch_index * max_multiplications_per_gate),
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/context/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,19 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> {
/// If total records is not set.
#[must_use]
pub fn new(ctx: MaliciousContext<'a>) -> Self {
let Some(total_records) = ctx.total_records().count() else {
let TotalRecords::Specified(total_records) = ctx.total_records() else {
panic!("Total records must be specified before creating the validator");
};

// TODO: Right now we set the batch work to be equal to active_work,
// but it does not need to be. We can make this configurable if needed.
let records_per_batch = ctx.active_work().get().min(total_records);
let records_per_batch = ctx.active_work().get().min(total_records.get());

Self {
protocol_ctx: ctx.narrow(&Step::MaliciousProtocol),
batches_ref: Batcher::new(
records_per_batch,
Some(total_records),
total_records,
Box::new(move |batch_index| Malicious::new(ctx.clone(), batch_index)),
),
}
Expand Down
Loading

0 comments on commit 9676fbe

Please sign in to comment.