diff --git a/benches/oneshot/ipa.rs b/benches/oneshot/ipa.rs index 3056e7be8..739eae364 100644 --- a/benches/oneshot/ipa.rs +++ b/benches/oneshot/ipa.rs @@ -9,7 +9,7 @@ async fn main() -> Result<(), Error> { let mut config = TestWorldConfig::default(); config.gateway_config.send_buffer_config.items_in_batch = 1; config.gateway_config.send_buffer_config.batch_count = 1000; - let world = TestWorld::new_with(config); + let world = TestWorld::new_with(config).await; let mut rng = rand::thread_rng(); const BATCHSIZE: u64 = 100; diff --git a/benches/oneshot/sort.rs b/benches/oneshot/sort.rs index dd93f7503..64386aed0 100644 --- a/benches/oneshot/sort.rs +++ b/benches/oneshot/sort.rs @@ -13,7 +13,7 @@ async fn main() -> Result<(), Error> { let mut config = TestWorldConfig::default(); config.gateway_config.send_buffer_config.items_in_batch = 1; config.gateway_config.send_buffer_config.batch_count = 1000; - let world = TestWorld::new_with(config); + let world = TestWorld::new_with(config).await; let [ctx0, ctx1, ctx2] = world.contexts::(); let num_bits = 64; let mut rng = thread_rng(); diff --git a/src/bin/test_mpc.rs b/src/bin/test_mpc.rs index 07022f616..a2574b611 100644 --- a/src/bin/test_mpc.rs +++ b/src/bin/test_mpc.rs @@ -70,7 +70,7 @@ fn print_output(values: &[Vec; 3]) { ]); } - println!("{}", shares_table); + println!("{shares_table}"); } #[tokio::main] diff --git a/src/helpers/buffers/mod.rs b/src/helpers/buffers/mod.rs index 450efca82..663f52a2c 100644 --- a/src/helpers/buffers/mod.rs +++ b/src/helpers/buffers/mod.rs @@ -3,6 +3,7 @@ mod receive; mod send; pub use receive::ReceiveBuffer; +pub(super) use send::PushError; pub use {send::Config as SendBufferConfig, send::SendBuffer}; #[cfg(debug_assertions)] diff --git a/src/helpers/buffers/send.rs b/src/helpers/buffers/send.rs index a5336dc4f..1c0ace7ef 100644 --- a/src/helpers/buffers/send.rs +++ b/src/helpers/buffers/send.rs @@ -1,7 +1,6 @@ use crate::{ helpers::{ - buffers::fsv::FixedSizeByteVec, - network::{ChannelId, MessageEnvelope}, + buffers::fsv::FixedSizeByteVec, network::ChannelId, network::MessageEnvelope, MESSAGE_PAYLOAD_SIZE_BYTES, }, protocol::RecordId, @@ -169,12 +168,8 @@ impl From<&ByteBuf> for Range { #[cfg(all(test, not(feature = "shuttle")))] mod tests { - use crate::helpers::buffers::send::{ByteBuf, Config, PushError}; - use crate::helpers::buffers::SendBuffer; - use crate::helpers::network::{ChannelId, MessageEnvelope}; - use crate::helpers::Role; - use crate::protocol::{RecordId, Step}; - + use super::*; + use crate::{helpers::Role, protocol::Step}; use tinyvec::array_vec; impl Clone for MessageEnvelope { diff --git a/src/helpers/error.rs b/src/helpers/error.rs index b522b327c..372ea3466 100644 --- a/src/helpers/error.rs +++ b/src/helpers/error.rs @@ -1,10 +1,10 @@ -use crate::helpers::messaging::SendRequest; +use crate::helpers::TransportError; use crate::{ error::BoxError, helpers::{ - messaging::ReceiveRequest, + messaging::{ReceiveRequest, SendRequest}, network::{ChannelId, MessageChunks}, - Role, + HelperIdentity, Role, }, net::MpcHelperServerError, protocol::{RecordId, Step}, @@ -40,11 +40,12 @@ pub enum Error { #[source] inner: BoxError, }, - #[error("Failed to send data to the network")] - NetworkError { - #[from] - inner: BoxError, - }, + #[error("Encountered unknown identity {0:?}")] + UnknownIdentity(HelperIdentity), + #[error("identity had invalid format: {0}")] + InvalidIdentity(#[from] hyper::http::uri::InvalidUri), + #[error("Failed to send command on the transport: {0}")] + TransportError(#[from] TransportError), #[error("server encountered an error: {0}")] ServerError(#[from] MpcHelperServerError), } @@ -104,15 +105,10 @@ impl From> for Error { impl From> for Error { fn from(source: PollSendError) -> Self { - let err_msg = source.to_string(); + let inner = source.to_string().into(); match source.into_inner() { - Some(inner) => Self::SendError { - channel: inner.0, - inner: err_msg.into(), - }, - None => Self::PollSendError { - inner: err_msg.into(), - }, + Some((channel, _)) => Self::SendError { channel, inner }, + None => Self::PollSendError { inner }, } } } diff --git a/src/helpers/http/mod.rs b/src/helpers/http/mod.rs index 1cb489d39..23eb57050 100644 --- a/src/helpers/http/mod.rs +++ b/src/helpers/http/mod.rs @@ -21,7 +21,7 @@ use std::net::SocketAddr; pub struct HttpHelper<'p> { role: Role, peers: &'p [peer::Config; 3], - gateway_config: GatewayConfig, + _gateway_config: GatewayConfig, server: MpcHelperServer, } @@ -36,7 +36,7 @@ impl<'p> HttpHelper<'p> { Self { role, peers, - gateway_config, + _gateway_config: gateway_config, server: MpcHelperServer::new(MessageSendMap::default()), } } @@ -57,15 +57,17 @@ impl<'p> HttpHelper<'p> { /// adds a query to the running server so that it knows where to send arriving data /// # Errors /// if a query has been previously added - pub fn query(&self, query_id: QueryId) -> Result { - tracing::debug!("starting query {}", query_id.as_ref()); - let network = HttpNetwork::new(self.role, self.peers, query_id); - - let gateway = Gateway::new(self.role, &network, self.gateway_config); - // allow for server to forward requests to this network - // TODO: how to remove from map? - self.server.add_query(query_id, network)?; - Ok(gateway) + pub fn query(&self, _query_id: QueryId) -> Result { + // TODO: This requires `HttpNetwork` to implement Transport + unimplemented!(); + // tracing::debug!("starting query {}", query_id.as_ref()); + // let network = HttpNetwork::new(self.role, self.peers, query_id); + // + // let gateway = Gateway::new(self.role, network, self.gateway_config); + // // allow for server to forward requests to this network + // // TODO: how to remove from map? + // self.server.add_query(query_id, network)?; + // Ok(gateway) } /// establish the prss endpoint by exchanging public keys with the other helpers @@ -208,6 +210,7 @@ mod e2e_tests { } #[tokio::test] + #[ignore] // TODO (thurstonsand): enable after `HttpNetwork` implements `Transport` async fn prss_key_exchange() { logging::setup(); @@ -275,6 +278,7 @@ mod e2e_tests { } #[tokio::test] + #[ignore] // TODO (thurstonsand): enable after `HttpNetwork` implements `Transport` async fn basic_mul() { logging::setup(); diff --git a/src/helpers/http/network.rs b/src/helpers/http/network.rs index afabdfe3f..d25de2a64 100644 --- a/src/helpers/http/network.rs +++ b/src/helpers/http/network.rs @@ -1,8 +1,7 @@ +#[allow(deprecated)] +use crate::helpers::old_network::{Network, NetworkSink}; use crate::{ - helpers::{ - network::{MessageChunks, Network, NetworkSink}, - Role, - }, + helpers::{network::MessageChunks, Role}, net::{discovery::peer, HttpSendMessagesArgs, MpcHelperClient}, protocol::QueryId, sync::{Arc, Mutex}, @@ -99,6 +98,7 @@ impl HttpNetwork { } } +#[allow(deprecated)] impl Network for HttpNetwork { type Sink = NetworkSink; @@ -122,15 +122,16 @@ impl Network for HttpNetwork { #[cfg(test)] mod tests { use super::*; - use crate::test_fixture::net::localhost_config; use crate::{ helpers::{network::ChannelId, Direction, MESSAGE_PAYLOAD_SIZE_BYTES}, net::{discovery::PeerDiscovery, BindTarget, MessageSendMap, MpcHelperServer}, protocol::Step, + test_fixture::net::localhost_config, }; use futures::{Stream, StreamExt}; use futures_util::SinkExt; + #[allow(deprecated)] async fn setup() -> (Role, [peer::Config; 3], impl Stream) { // setup server let network = HttpNetwork::new_without_clients(QueryId, None); @@ -149,6 +150,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn send_multiple_messages() { const DATA_LEN: usize = 3; let (target_role, peers_conf, mut rx_stream) = setup().await; diff --git a/src/helpers/messaging.rs b/src/helpers/messaging.rs index 72d77b85d..6f2c90020 100644 --- a/src/helpers/messaging.rs +++ b/src/helpers/messaging.rs @@ -6,21 +6,17 @@ //! corresponding helper without needing to know the exact location - this is what this module //! enables MPC protocols to do. //! -use crate::ff::{Field, Int}; -use crate::helpers::buffers::{SendBuffer, SendBufferConfig}; -use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES}; -use crate::task::JoinHandle; -use crate::telemetry::labels::STEP; use crate::{ - helpers::buffers::ReceiveBuffer, - helpers::error::Error, - helpers::network::{ChannelId, MessageEnvelope, Network}, - helpers::Role, + ff::{Field, Int}, + helpers::{ + buffers::{ReceiveBuffer, SendBuffer, SendBufferConfig}, + network::ChannelId, + Error, MessagePayload, Role, MESSAGE_PAYLOAD_SIZE_BYTES, + }, protocol::{RecordId, Step}, + task::JoinHandle, + telemetry::{labels::STEP, metrics::RECORDS_SENT}, }; -use ::tokio::sync::{mpsc, oneshot}; -use ::tokio::time::Instant; -use futures::SinkExt; use futures::StreamExt; use std::fmt::{Debug, Formatter}; use std::time::Duration; @@ -28,7 +24,12 @@ use std::{io, panic}; use tinyvec::array_vec; use tracing::Instrument; -use crate::telemetry::metrics::RECORDS_SENT; +use crate::helpers::buffers::PushError; +use crate::helpers::network::{MessageEnvelope, Network}; +use crate::helpers::time::Timer; +use crate::helpers::transport::Transport; +use ::tokio::sync::{mpsc, oneshot}; +use futures_util::stream::FuturesUnordered; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; @@ -160,19 +161,18 @@ pub struct GatewayConfig { } impl Gateway { - pub fn new(role: Role, network: &N, config: GatewayConfig) -> Self { + pub async fn new(role: Role, network: Network, config: GatewayConfig) -> Self { let (recv_tx, mut recv_rx) = mpsc::channel::(config.recv_outstanding); let (send_tx, mut send_rx) = mpsc::channel::(config.send_outstanding); - let mut message_stream = network.recv_stream(); - let mut network_sink = network.sink(); + let mut message_stream = network.recv_stream().await; let control_handle = tokio::spawn(async move { const INTERVAL: Duration = Duration::from_secs(3); let mut receive_buf = ReceiveBuffer::default(); let mut send_buf = SendBuffer::new(config.send_buffer_config); - - let sleep = ::tokio::time::sleep(INTERVAL); + let mut pending_sends = FuturesUnordered::new(); + let sleep = Timer::new(INTERVAL); ::tokio::pin!(sleep); loop { @@ -189,9 +189,14 @@ impl Gateway { tracing::trace!("received {} bytes from {:?}", messages.len(), channel_id); receive_buf.receive_messages(&channel_id, &messages); } - Some(send_req) = send_rx.recv() => { - tracing::trace!("new SendRequest({:?})", send_req); - send_message::(&mut network_sink, &mut send_buf, send_req).await; + Some((channel_id, envelope)) = send_rx.recv(), if pending_sends.is_empty() => { + tracing::trace!("new SendRequest({:?})", (&channel_id, &envelope)); + metrics::increment_counter!(RECORDS_SENT, STEP => channel_id.step.as_ref().to_string()); + let data = send_buf.push(&channel_id, &envelope); + pending_sends.push(send_message(&network, channel_id, data)); + } + Some(_) = &mut pending_sends.next() => { + pending_sends.clear(); } _ = &mut sleep => { #[cfg(debug_assertions)] @@ -204,7 +209,7 @@ impl Gateway { } // reset the timer on every action - sleep.as_mut().reset(Instant::now() + INTERVAL); + sleep.as_mut().reset(); } }.instrument(tracing::info_span!("gateway_loop", role=role.as_static_str()).or_current())); @@ -287,19 +292,22 @@ impl Debug for ReceiveRequest { } } -async fn send_message(sink: &mut N::Sink, buf: &mut SendBuffer, req: SendRequest) { - let (channel_id, msg) = req; - metrics::increment_counter!(RECORDS_SENT, STEP => channel_id.step.as_ref().to_string()); - match buf.push(&channel_id, &msg) { +async fn send_message( + network: &Network, + channel_id: ChannelId, + data: Result>, PushError>, +) { + match data { Ok(Some(buf_to_send)) => { tracing::trace!("sending {} bytes to {:?}", buf_to_send.len(), &channel_id); - sink.send((channel_id, buf_to_send)) + network + .send((channel_id, buf_to_send)) .await .expect("Failed to send data to the network"); } Ok(None) => {} Err(err) => panic!("failed to send to the {channel_id:?}: {err}"), - }; + } } #[cfg(debug_assertions)] @@ -329,7 +337,7 @@ mod tests { config.gateway_config.send_buffer_config.items_in_batch = 1; // Send every record config.gateway_config.send_buffer_config.batch_count = 3; // keep 3 at a time - let world = Box::leak(Box::new(TestWorld::new_with(config))); + let world = Box::leak(Box::new(TestWorld::new_with(config).await)); let contexts = world.contexts::(); let sender_ctx = contexts[0].narrow("reordering-test"); let recv_ctx = contexts[1].narrow("reordering-test"); @@ -361,7 +369,7 @@ mod tests { #[tokio::test] #[should_panic(expected = "Record RecordId(1) has been received twice")] async fn duplicate_message() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let (v1, v2) = (Fp31::from(1u128), Fp31::from(2u128)); let peer = Role::H2; let record_id = 1.into(); diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index 7e54f12a4..1e4ef5dd9 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -1,13 +1,20 @@ pub mod http; pub mod messaging; pub mod network; +#[deprecated(note = "Use `Transport` instead")] +pub mod old_network; mod buffers; mod error; +mod time; +mod transport; pub use buffers::SendBufferConfig; pub use error::{Error, Result}; pub use messaging::GatewayConfig; +pub use transport::{ + CommandEnvelope, CommandOrigin, SubscriptionType, Transport, TransportCommand, TransportError, +}; use crate::helpers::{ Direction::{Left, Right}, @@ -19,6 +26,31 @@ use tinyvec::ArrayVec; pub const MESSAGE_PAYLOAD_SIZE_BYTES: usize = 8; type MessagePayload = ArrayVec<[u8; MESSAGE_PAYLOAD_SIZE_BYTES]>; +/// Represents an opaque identifier of the helper instance. Compare with a [`Role`], which +/// represents a helper's role within an MPC protocol, which may be different per protocol. +/// `HelperIdentity` will be established at startup and then never change. Components that want to +/// resolve this identifier into something (Uri, encryption keys, etc) must consult configuration +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct HelperIdentity { + id: u8, +} + +impl TryFrom for HelperIdentity { + type Error = String; + + fn try_from(value: usize) -> std::result::Result { + if value == 0 || value > 3 { + Err(format!( + "{value} must be within [1, 3] range to be a valid helper identity" + )) + } else { + Ok(Self { + id: u8::try_from(value).unwrap(), + }) + } + } +} + /// Represents a unique role of the helper inside the MPC circuit. Each helper may have different /// roles in queries it processes in parallel. For some queries it can be `H1` and for others it /// may be `H2` or `H3`. @@ -36,6 +68,11 @@ pub enum Role { H3 = 2, } +#[derive(Clone, Debug)] +pub struct RoleAssignment { + helper_roles: [HelperIdentity; 3], +} + #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Direction { Left, @@ -137,10 +174,38 @@ impl IndexMut for Vec { } } +impl RoleAssignment { + #[must_use] + pub fn new(helper_roles: [HelperIdentity; 3]) -> Self { + Self { helper_roles } + } + + /// Returns the assigned role for the given helper identity. + /// + /// ## Panics + /// Panics if there is no role assigned to it. + #[must_use] + pub fn role(&self, id: &HelperIdentity) -> Role { + for (idx, item) in self.helper_roles.iter().enumerate() { + if item == id { + return Role::all()[idx]; + } + } + + panic!("No role assignment for {id:?} found in {self:?}") + } + + #[must_use] + pub fn identity(&self, role: Role) -> &HelperIdentity { + &self.helper_roles[role] + } +} + #[cfg(all(test, not(feature = "shuttle")))] mod tests { + use super::*; mod role_tests { - use crate::helpers::{Direction, Role}; + use super::*; #[test] pub fn peer_works() { @@ -160,4 +225,81 @@ mod tests { assert_eq!(5, data[Role::H3]); } } + + mod role_assignment_tests { + use super::*; + + #[test] + fn basic() { + let identities = (1..=3) + .map(|v| HelperIdentity::try_from(v).unwrap()) + .collect::>() + .try_into() + .unwrap(); + let assignment = RoleAssignment::new(identities); + + assert_eq!( + Role::H1, + assignment.role(&HelperIdentity::try_from(1).unwrap()) + ); + assert_eq!( + Role::H2, + assignment.role(&HelperIdentity::try_from(2).unwrap()) + ); + assert_eq!( + Role::H3, + assignment.role(&HelperIdentity::try_from(3).unwrap()) + ); + + assert_eq!( + &HelperIdentity::try_from(1).unwrap(), + assignment.identity(Role::H1) + ); + assert_eq!( + &HelperIdentity::try_from(2).unwrap(), + assignment.identity(Role::H2) + ); + assert_eq!( + &HelperIdentity::try_from(3).unwrap(), + assignment.identity(Role::H3) + ); + } + + #[test] + fn reverse() { + let identities = (1..=3) + .rev() + .map(|v| HelperIdentity::try_from(v).unwrap()) + .collect::>() + .try_into() + .unwrap(); + let assignment = RoleAssignment::new(identities); + + assert_eq!( + Role::H3, + assignment.role(&HelperIdentity::try_from(1).unwrap()) + ); + assert_eq!( + Role::H2, + assignment.role(&HelperIdentity::try_from(2).unwrap()) + ); + assert_eq!( + Role::H1, + assignment.role(&HelperIdentity::try_from(3).unwrap()) + ); + + assert_eq!( + &HelperIdentity::try_from(3).unwrap(), + assignment.identity(Role::H1) + ); + assert_eq!( + &HelperIdentity::try_from(2).unwrap(), + assignment.identity(Role::H2) + ); + assert_eq!( + &HelperIdentity::try_from(1).unwrap(), + assignment.identity(Role::H3) + ); + } + } } diff --git a/src/helpers/network.rs b/src/helpers/network.rs index cdf19a0c0..7e3e37588 100644 --- a/src/helpers/network.rs +++ b/src/helpers/network.rs @@ -1,23 +1,17 @@ +#![allow(dead_code)] // will use these soon + +use crate::helpers::transport::CommandOrigin; +use crate::helpers::{MessagePayload, RoleAssignment}; +use crate::protocol::RecordId; use crate::{ - helpers::{error::Error, MessagePayload, Role}, - protocol::{RecordId, Step}, + helpers::{ + transport::{SubscriptionType, Transport, TransportCommand}, + Error, Role, + }, + protocol::{QueryId, Step}, }; -use async_trait::async_trait; -use futures::{ready, Stream}; -use pin_project::pin_project; +use futures::{Stream, StreamExt}; use std::fmt::{Debug, Formatter}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::sync::mpsc; -use tokio_util::sync::{PollSendError, PollSender}; - -/// Combination of helper role and step that uniquely identifies a single channel of communication -/// between two helpers. -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct ChannelId { - pub role: Role, - pub step: Step, -} #[derive(Debug, PartialEq, Eq)] pub struct MessageEnvelope { @@ -25,21 +19,12 @@ pub struct MessageEnvelope { pub payload: MessagePayload, } -pub type MessageChunks = (ChannelId, Vec); - -/// Network interface for components that require communication. -#[async_trait] -pub trait Network: Sync { - /// Type of the channel that is used to send messages to other helpers - type Sink: futures::Sink + Send + Unpin + 'static; - type MessageStream: Stream + Send + Unpin + 'static; - - /// Returns a sink that accepts data to be sent to other helper parties. - fn sink(&self) -> Self::Sink; - - /// Returns a stream to receive messages that have arrived from other helpers. Note that - /// some implementations may panic if this method is called more than once. - fn recv_stream(&self) -> Self::MessageStream; +/// Combination of helper role and step that uniquely identifies a single channel of communication +/// between two helpers. +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct ChannelId { + pub role: Role, + pub step: Step, } impl ChannelId { @@ -55,47 +40,72 @@ impl Debug for ChannelId { } } -/// Wrapper around a [`PollSender`] to modify the error message to match what the [`NetworkSink`] -/// requires. The only error that [`PollSender`] will generate is "channel closed", and thus is the -/// only error message forwarded from this [`NetworkSink`]. -#[pin_project] -pub struct NetworkSink { - #[pin] - inner: PollSender, +pub type MessageChunks = (ChannelId, Vec); + +/// Given any implementation of [`Transport`], a `Network` is able to send and receive +/// [`MessageChunks`] for a specific query id. The [`Transport`] will receive [`StepData`] +/// containing the `MessageChunks` +pub struct Network { + transport: T, + query_id: QueryId, + roles: RoleAssignment, } -impl NetworkSink { - #[must_use] - pub fn new(sender: mpsc::Sender) -> Self { +impl Network { + pub fn new(transport: T, query_id: QueryId, roles: RoleAssignment) -> Self { Self { - inner: PollSender::new(sender), + transport, + query_id, + roles, } } -} -impl futures::Sink for NetworkSink -where - Error: From>, -{ - type Error = Error; + /// sends a [`StepData`] containing [`MessageChunks`] on the underlying [`Transport`] + /// # Errors + /// if `message_chunks` fail to be delivered + /// # Panics + /// if `roles_to_helpers` does not have all 3 roles + pub async fn send(&self, message_chunks: MessageChunks) -> Result<(), Error> { + let (channel, payload) = message_chunks; + let destination = self.roles.identity(channel.role); - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.project().inner.poll_ready(cx)?); - Poll::Ready(Ok(())) + self.transport + .send( + destination, + TransportCommand::StepData(self.query_id, channel.step, payload), + ) + .await + .map_err(Error::from) } - fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - self.project().inner.start_send(item)?; - Ok(()) - } + /// returns a [`Stream`] of [`MessageChunks`]s from the underlying [`Transport`] + /// # Panics + /// if called more than once during the execution of a query. + pub async fn recv_stream(&self) -> impl Stream { + let self_query_id = self.query_id; + let query_command_stream = self + .transport + .subscribe(SubscriptionType::Query(self_query_id)) + .await; + let assignment = self.roles.clone(); // need to move it inside the closure - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.project().inner.poll_flush(cx)?); - Poll::Ready(Ok(())) - } + query_command_stream.map(move |envelope| match envelope.payload { + TransportCommand::StepData(query_id, step, payload) => { + debug_assert!(query_id == self_query_id); + + let CommandOrigin::Helper(identity) = &envelope.origin else { + panic!("Message origin is incorrect: expected it to be from a helper, got {:?}", &envelope.origin); + }; + let origin_role = assignment.role(identity); + let channel_id = ChannelId::new(origin_role, step); - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.project().inner.poll_close(cx))?; - Poll::Ready(Ok(())) + (channel_id, payload) + } + #[allow(unreachable_patterns)] // there will be more commands in the future + other_command => panic!( + "received unexpected command {other_command:?} for query id {}", + self_query_id.as_ref() + ), + }) } } diff --git a/src/helpers/old_network.rs b/src/helpers/old_network.rs new file mode 100644 index 000000000..a6ab2962c --- /dev/null +++ b/src/helpers/old_network.rs @@ -0,0 +1,71 @@ +use crate::helpers::error::Error; +/// The only usage of this module is in `net` module that is awaiting to be migrated to `Transport` +/// interface. +use crate::helpers::network::MessageChunks; +use async_trait::async_trait; +use futures::{ready, Stream}; +use pin_project::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; +use tokio_util::sync::{PollSendError, PollSender}; + +/// Network interface for components that require communication. +#[async_trait] +pub trait Network: Sync { + /// Type of the channel that is used to send messages to other helpers + type Sink: futures::Sink + Send + Unpin + 'static; + type MessageStream: Stream + Send + Unpin + 'static; + + /// Returns a sink that accepts data to be sent to other helper parties. + fn sink(&self) -> Self::Sink; + + /// Returns a stream to receive messages that have arrived from other helpers. Note that + /// some implementations may panic if this method is called more than once. + fn recv_stream(&self) -> Self::MessageStream; +} + +/// Wrapper around a [`PollSender`] to modify the error message to match what the [`NetworkSink`] +/// requires. The only error that [`PollSender`] will generate is "channel closed", and thus is the +/// only error message forwarded from this [`NetworkSink`]. +#[pin_project] +pub struct NetworkSink { + #[pin] + inner: PollSender, +} + +impl NetworkSink { + #[must_use] + pub fn new(sender: mpsc::Sender) -> Self { + Self { + inner: PollSender::new(sender), + } + } +} + +impl futures::Sink for NetworkSink +where + Error: From>, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.project().inner.poll_ready(cx)?); + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.project().inner.start_send(item)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.project().inner.poll_flush(cx)?); + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.project().inner.poll_close(cx))?; + Poll::Ready(Ok(())) + } +} diff --git a/src/helpers/time.rs b/src/helpers/time.rs new file mode 100644 index 000000000..6a1cf452e --- /dev/null +++ b/src/helpers/time.rs @@ -0,0 +1,58 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +/// Simple timer that only works in the presence of tokio runtime. Any other runtime will +/// make it a no-op. +#[cfg(not(all(test, feature = "shuttle")))] +#[pin_project::pin_project] +pub struct Timer { + interval: Duration, + #[pin] + timer: tokio::time::Sleep, +} + +#[cfg(all(test, feature = "shuttle"))] +pub struct Timer {} + +#[cfg(not(all(test, feature = "shuttle")))] +impl Timer { + pub fn new(interval: Duration) -> Self { + Self { + interval, + timer: tokio::time::sleep(interval), + } + } + + pub fn reset(self: Pin<&mut Self>) { + let this = self.project(); + this.timer + .reset(tokio::time::Instant::now() + *this.interval); + } +} + +#[cfg(all(test, feature = "shuttle"))] +impl Timer { + pub fn new(_: Duration) -> Self { + Self {} + } + + #[allow(clippy::unused_self)] + pub fn reset(&mut self) {} +} + +impl Future for Timer { + type Output = (); + + #[cfg(all(test, feature = "shuttle"))] + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + Poll::Pending + } + + #[cfg(not(all(test, feature = "shuttle")))] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.timer.poll(cx) + } +} diff --git a/src/helpers/transport/error.rs b/src/helpers/transport/error.rs new file mode 100644 index 000000000..f0f0b11ad --- /dev/null +++ b/src/helpers/transport/error.rs @@ -0,0 +1,22 @@ +use crate::error::BoxError; +use crate::helpers::TransportCommand; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Failed to send {command:?}: {inner:?}")] + SendFailed { + command: TransportCommand, + #[source] + inner: BoxError, + }, +} + +#[cfg(any(test, feature = "test-fixture"))] +impl From> for Error { + fn from(value: tokio::sync::mpsc::error::SendError) -> Self { + Self::SendFailed { + command: value.0, + inner: "channel closed".into(), + } + } +} diff --git a/src/helpers/transport/mod.rs b/src/helpers/transport/mod.rs new file mode 100644 index 000000000..17c76376c --- /dev/null +++ b/src/helpers/transport/mod.rs @@ -0,0 +1,66 @@ +mod error; + +pub use error::Error as TransportError; + +use crate::protocol::Step; +use crate::{helpers::HelperIdentity, protocol::QueryId}; +use async_trait::async_trait; +use futures::Stream; + +#[derive(Debug)] +pub enum TransportCommand { + // `Administration` Commands + // TODO: none for now + + // `Query` Commands + /// Query/step data received from a helper peer. + /// TODO: this is really bad for performance, once we have channel per step all the way + /// from gateway to network, this definition should be (QueryId, Step, Stream>) instead + StepData(QueryId, Step, Vec), +} + +/// Users of a [`Transport`] must subscribe to a specific type of command, and so must pass this +/// type as argument to the `subscribe` function +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum SubscriptionType { + /// Commands for managing queries + QueryManagement, + /// Commands intended for a running query + Query(QueryId), +} + +/// The source of the command, i.e. where it came from. Some may arrive from helper peers, others +/// may come directly from the clients +#[derive(Debug)] +pub enum CommandOrigin { + Helper(HelperIdentity), + Other, +} + +/// Wrapper around `TransportCommand` that indicates the origin of it. +#[derive(Debug)] +pub struct CommandEnvelope { + pub origin: CommandOrigin, + pub payload: TransportCommand, +} + +/// Represents the transport layer of the IPA network. Allows layers above to subscribe for events +/// arriving from helper peers or other parties (clients) and also reliably deliver messages using +/// `send` method. +#[async_trait] +pub trait Transport: Send + Sync + 'static { + type CommandStream: Stream + Send + Unpin; + + /// To be called by an entity which will handle the events as indicated by the + /// [`SubscriptionType`]. There should be only 1 subscriber per type. + /// # Panics + /// May panic if attempt to subscribe to the same [`SubscriptionType`] twice + async fn subscribe(&self, subscription: SubscriptionType) -> Self::CommandStream; + + /// To be called when an entity wants to send commands to the `Transport`. + async fn send( + &self, + destination: &HelperIdentity, + command: TransportCommand, + ) -> Result<(), TransportError>; +} diff --git a/src/lib.rs b/src/lib.rs index e8f36237d..f4d8232dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,8 @@ pub mod net; pub mod protocol; pub mod secret_sharing; pub mod telemetry; +#[cfg(feature = "enable-serde")] +pub mod uri; #[cfg(any(test, feature = "test-fixture"))] pub mod test_fixture; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index caac49c6f..802b59ad5 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -117,10 +117,12 @@ impl MpcHelperClient { #[cfg(all(test, not(feature = "shuttle")))] mod tests { use super::*; + #[allow(deprecated)] use crate::{ helpers::{ http::HttpNetwork, - network::{ChannelId, MessageChunks, Network}, + network::{ChannelId, MessageChunks}, + old_network::Network, Role, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::{server::MessageSendMap, BindTarget, MpcHelperServer}, @@ -128,6 +130,7 @@ mod tests { use futures::{Stream, StreamExt}; use hyper_tls::native_tls::TlsConnector; + #[allow(deprecated)] async fn setup_server(bind_target: BindTarget) -> (u16, impl Stream) { let network = HttpNetwork::new_without_clients(QueryId, None); let rx_stream = network.recv_stream(); diff --git a/src/net/discovery/mod.rs b/src/net/discovery/mod.rs index c48bb3449..4d7ed89cc 100644 --- a/src/net/discovery/mod.rs +++ b/src/net/discovery/mod.rs @@ -29,20 +29,11 @@ pub mod peer { #[derive(Clone, Debug)] #[cfg_attr(feature = "enable-serde", derive(serde::Deserialize))] pub struct Config { - #[cfg_attr(feature = "enable-serde", serde(deserialize_with = "uri_from_str"))] + #[cfg_attr(feature = "enable-serde", serde(with = "crate::uri"))] pub origin: Uri, pub tls: HttpConfig, } - #[cfg(feature = "enable-serde")] - fn uri_from_str<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s: String = Deserialize::deserialize(deserializer)?; - s.parse().map_err(D::Error::custom) - } - #[cfg(feature = "enable-serde")] fn pk_from_str<'de, D>(deserializer: D) -> Result where diff --git a/src/net/server/handlers/query.rs b/src/net/server/handlers/query.rs index 71810f3d6..ed6aa0163 100644 --- a/src/net/server/handlers/query.rs +++ b/src/net/server/handlers/query.rs @@ -123,8 +123,9 @@ pub async fn handler(mut req: Request) -> Result<(), MpcHelperServerError> #[cfg(all(test, not(feature = "shuttle")))] mod tests { use super::*; + #[allow(deprecated)] use crate::{ - helpers::{http::HttpNetwork, network::Network, MESSAGE_PAYLOAD_SIZE_BYTES}, + helpers::{http::HttpNetwork, old_network::Network, MESSAGE_PAYLOAD_SIZE_BYTES}, net::{ server::MessageSendMap, BindTarget, MpcHelperServer, CONTENT_LENGTH_HEADER_NAME, OFFSET_HEADER_NAME, @@ -143,6 +144,7 @@ mod tests { const DATA_LEN: usize = 3; + #[allow(deprecated)] async fn init_server() -> (u16, impl Stream) { let network = HttpNetwork::new_without_clients(QueryId, None); let rx_stream = network.recv_stream(); @@ -375,6 +377,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn backpressure_applied() { const QUEUE_DEPTH: usize = 8; let network = HttpNetwork::new_without_clients(QueryId, Some(QUEUE_DEPTH)); diff --git a/src/protocol/attribution/accumulate_credit.rs b/src/protocol/attribution/accumulate_credit.rs index 1f4b493e9..c40b5ebb5 100644 --- a/src/protocol/attribution/accumulate_credit.rs +++ b/src/protocol/attribution/accumulate_credit.rs @@ -353,7 +353,7 @@ pub(crate) mod tests { ]) }); - let world = TestWorld::new(); + let world = TestWorld::new().await; let result = world .semi_honest(input, |ctx, input| async move { accumulate_credit(ctx, &input).await.unwrap() @@ -380,7 +380,7 @@ pub(crate) mod tests { let mut rng = thread_rng(); let secret: [Fp31; 4] = [(); 4].map(|_| rng.gen::()); - let world = TestWorld::new(); + let world = TestWorld::new().await; for &role in Role::all() { let new_shares = world diff --git a/src/protocol/attribution/aggregate_credit.rs b/src/protocol/attribution/aggregate_credit.rs index 893f837f7..dbb9fb3c8 100644 --- a/src/protocol/attribution/aggregate_credit.rs +++ b/src/protocol/attribution/aggregate_credit.rs @@ -497,7 +497,7 @@ pub(crate) mod tests { ]) }); - let world = TestWorld::new(); + let world = TestWorld::new().await; let result = world .semi_honest(input, |ctx, share| async move { aggregate_credit(ctx, &share, 8).await.unwrap() @@ -603,7 +603,7 @@ pub(crate) mod tests { ]) }); - let world = TestWorld::new(); + let world = TestWorld::new().await; let result = world .semi_honest(input, |ctx, share| async move { sort_by_breakdown_key(ctx, &share, 8).await.unwrap() diff --git a/src/protocol/attribution/credit_capping.rs b/src/protocol/attribution/credit_capping.rs index 68654d586..fe3cb8b6b 100644 --- a/src/protocol/attribution/credit_capping.rs +++ b/src/protocol/attribution/credit_capping.rs @@ -350,7 +350,7 @@ mod tests { let expected = TEST_CASE.iter().map(|t| t[4]).collect::>(); //TODO: move to the new test framework - let world = TestWorld::new(); + let world = TestWorld::new().await; let context = world.contexts::(); let mut rng = StepRng::new(100, 1); diff --git a/src/protocol/basics/check_zero.rs b/src/protocol/basics/check_zero.rs index 8b1cbc3a9..0d7bf4ef2 100644 --- a/src/protocol/basics/check_zero.rs +++ b/src/protocol/basics/check_zero.rs @@ -92,7 +92,7 @@ mod tests { #[tokio::test] async fn basic() -> Result<(), Error> { - let world = TestWorld::new(); + let world = TestWorld::new().await; let context = world.contexts::(); let mut rng = thread_rng(); let mut counter = 0_u32; diff --git a/src/protocol/basics/mul/malicious.rs b/src/protocol/basics/mul/malicious.rs index ea20717c0..e5075fd8b 100644 --- a/src/protocol/basics/mul/malicious.rs +++ b/src/protocol/basics/mul/malicious.rs @@ -107,7 +107,7 @@ mod test { #[tokio::test] pub async fn simple() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a = rng.gen::(); diff --git a/src/protocol/basics/mul/semi_honest.rs b/src/protocol/basics/mul/semi_honest.rs index 6028629d8..101477b2f 100644 --- a/src/protocol/basics/mul/semi_honest.rs +++ b/src/protocol/basics/mul/semi_honest.rs @@ -88,7 +88,7 @@ mod test { #[tokio::test] async fn basic() { - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(30, multiply_sync::(&world, 6, 5).await); assert_eq!(25, multiply_sync::(&world, 5, 5).await); @@ -101,7 +101,7 @@ mod test { #[tokio::test] pub async fn simple() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a = rng.gen::(); @@ -123,7 +123,7 @@ mod test { #[tokio::test] pub async fn concurrent_mul() { const COUNT: usize = 10; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a: Vec<_> = (0..COUNT).map(|_| rng.gen::()).collect(); diff --git a/src/protocol/basics/mul/sparse.rs b/src/protocol/basics/mul/sparse.rs index b6d33048f..570147f41 100644 --- a/src/protocol/basics/mul/sparse.rs +++ b/src/protocol/basics/mul/sparse.rs @@ -364,7 +364,7 @@ pub(in crate::protocol) mod test { #[tokio::test] async fn check_output() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); for &a in ZeroPositions::all() { @@ -390,7 +390,7 @@ pub(in crate::protocol) mod test { #[tokio::test] async fn check_output_malicious() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); for &a in ZeroPositions::all() { diff --git a/src/protocol/basics/reshare.rs b/src/protocol/basics/reshare.rs index bddbb5f25..aa5a35f82 100644 --- a/src/protocol/basics/reshare.rs +++ b/src/protocol/basics/reshare.rs @@ -141,7 +141,7 @@ mod tests { /// Validates that reshare protocol actually generates new shares using PRSS. #[tokio::test] async fn generates_unique_shares() { - let world = TestWorld::new(); + let world = TestWorld::new().await; for &target in Role::all() { let secret = thread_rng().gen::(); @@ -174,7 +174,7 @@ mod tests { /// the input will pass this test. However `generates_unique_shares` will fail this implementation. #[tokio::test] async fn correct() { - let world = TestWorld::new(); + let world = TestWorld::new().await; for &role in Role::all() { let secret = thread_rng().gen::(); @@ -213,7 +213,7 @@ mod tests { /// it. #[tokio::test] async fn correct() { - let world = TestWorld::new(); + let world = TestWorld::new().await; for &role in Role::all() { let secret = thread_rng().gen::(); @@ -306,7 +306,7 @@ mod tests { #[tokio::test] async fn malicious_validation_fail() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a = rng.gen::(); diff --git a/src/protocol/basics/reveal.rs b/src/protocol/basics/reveal.rs index 57cf545ef..847a3749c 100644 --- a/src/protocol/basics/reveal.rs +++ b/src/protocol/basics/reveal.rs @@ -138,7 +138,7 @@ mod tests { #[tokio::test] pub async fn simple() -> Result<(), Error> { let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; let ctx = world.contexts::(); for i in 0..10_u32 { @@ -163,7 +163,7 @@ mod tests { #[tokio::test] pub async fn malicious() -> Result<(), Error> { let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; let sh_ctx = world.contexts::(); let v = sh_ctx.map(MaliciousValidator::new); @@ -193,7 +193,7 @@ mod tests { #[tokio::test] pub async fn malicious_validation_fail() -> Result<(), Error> { let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; let sh_ctx = world.contexts::(); let v = sh_ctx.map(MaliciousValidator::new); diff --git a/src/protocol/basics/sum_of_product/malicious.rs b/src/protocol/basics/sum_of_product/malicious.rs index e4d896b23..950f79afa 100644 --- a/src/protocol/basics/sum_of_product/malicious.rs +++ b/src/protocol/basics/sum_of_product/malicious.rs @@ -111,7 +111,7 @@ mod test { #[tokio::test] pub async fn simple() { const MULTI_BIT_LEN: usize = 10; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); diff --git a/src/protocol/basics/sum_of_product/semi_honest.rs b/src/protocol/basics/sum_of_product/semi_honest.rs index e3365db47..0b7d0d15c 100644 --- a/src/protocol/basics/sum_of_product/semi_honest.rs +++ b/src/protocol/basics/sum_of_product/semi_honest.rs @@ -79,7 +79,7 @@ mod test { #[tokio::test] async fn basic() { - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(11, sop_sync::(&world, &[7], &[6]).await); assert_eq!(3, sop_sync::(&world, &[6, 2], &[5, 2]).await); assert_eq!(28, sop_sync::(&world, &[5, 3], &[5, 1]).await); @@ -97,7 +97,7 @@ mod test { #[tokio::test] pub async fn simple() { const MULTI_BIT_LEN: usize = 10; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); diff --git a/src/protocol/boolean/bit_decomposition.rs b/src/protocol/boolean/bit_decomposition.rs index 30a9ed988..4dd937068 100644 --- a/src/protocol/boolean/bit_decomposition.rs +++ b/src/protocol/boolean/bit_decomposition.rs @@ -154,7 +154,7 @@ mod tests { // New BitwiseLessThan -> 0.56 secs * 5 cases = 2.8 #[tokio::test] pub async fn fp31() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let c = Fp31::from; assert_eq!(0, bits_to_value(&bit_decomposition(&world, c(0_u32)).await)); assert_eq!(1, bits_to_value(&bit_decomposition(&world, c(1)).await)); @@ -168,7 +168,7 @@ mod tests { #[ignore] #[tokio::test] pub async fn fp32_bit_prime() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let c = Fp32BitPrime::from; let u16_max: u32 = u16::MAX.into(); assert_eq!(0, bits_to_value(&bit_decomposition(&world, c(0_u32)).await)); diff --git a/src/protocol/boolean/bitwise_equal.rs b/src/protocol/boolean/bitwise_equal.rs index ef1e5272d..e4ac5d761 100644 --- a/src/protocol/boolean/bitwise_equal.rs +++ b/src/protocol/boolean/bitwise_equal.rs @@ -95,7 +95,7 @@ mod tests { } async fn run_bitwise_equal(a: u32, b: u32, num_bits: u32) -> u128 { - let world = TestWorld::new(); + let world = TestWorld::new().await; let a_fp31 = get_bits::(a, num_bits); let b_fp31 = get_bits::(b, num_bits); diff --git a/src/protocol/boolean/bitwise_less_than_prime.rs b/src/protocol/boolean/bitwise_less_than_prime.rs index c7178a87d..a698d1299 100644 --- a/src/protocol/boolean/bitwise_less_than_prime.rs +++ b/src/protocol/boolean/bitwise_less_than_prime.rs @@ -290,7 +290,7 @@ mod tests { F: Field + Sized, Standard: Distribution, { - let world = TestWorld::new(); + let world = TestWorld::new().await; let bits = get_bits::(a, num_bits); let result = world .semi_honest(bits.clone(), |ctx, x_share| async move { diff --git a/src/protocol/boolean/dumb_bitwise_lt.rs b/src/protocol/boolean/dumb_bitwise_lt.rs index 38d5d72e9..f96d8cd82 100644 --- a/src/protocol/boolean/dumb_bitwise_lt.rs +++ b/src/protocol/boolean/dumb_bitwise_lt.rs @@ -214,7 +214,7 @@ mod tests { let c = Fp31::from; let zero = Fp31::ZERO; let one = Fp31::ONE; - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(one, bitwise_lt(&world, zero, one).await); assert_eq!(zero, bitwise_lt(&world, one, zero).await); @@ -234,7 +234,7 @@ mod tests { let zero = Fp32BitPrime::ZERO; let one = Fp32BitPrime::ONE; let u16_max: u32 = u16::MAX.into(); - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(one, bitwise_lt(&world, zero, one).await); assert_eq!(zero, bitwise_lt(&world, one, zero).await); @@ -257,7 +257,7 @@ mod tests { #[tokio::test] pub async fn cmp_different_bit_lengths() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let input = ( get_bits::(3, 8), // 8-bit @@ -290,7 +290,7 @@ mod tests { #[ignore] #[tokio::test] pub async fn cmp_random_32_bit_prime_field_elements() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rand = thread_rng(); for _ in 0..1000 { let a = rand.gen::(); @@ -306,7 +306,7 @@ mod tests { #[ignore] #[tokio::test] pub async fn cmp_all_fp31() { - let world = TestWorld::new(); + let world = TestWorld::new().await; for a in 0..Fp31::PRIME { for b in 0..Fp31::PRIME { assert_eq!( diff --git a/src/protocol/boolean/or.rs b/src/protocol/boolean/or.rs index 474fa0c8f..956b7637e 100644 --- a/src/protocol/boolean/or.rs +++ b/src/protocol/boolean/or.rs @@ -55,7 +55,7 @@ mod tests { #[tokio::test] pub async fn all() { type F = Fp31; - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(F::ZERO, run(&world, F::ZERO, F::ZERO).await); assert_eq!(F::ONE, run(&world, F::ONE, F::ZERO).await); diff --git a/src/protocol/boolean/random_bits_generator.rs b/src/protocol/boolean/random_bits_generator.rs index 132d817fb..b0bc402c9 100644 --- a/src/protocol/boolean/random_bits_generator.rs +++ b/src/protocol/boolean/random_bits_generator.rs @@ -78,7 +78,7 @@ mod tests { #[tokio::test] pub async fn semi_honest() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let [c0, c1, c2] = world.contexts::(); let rbg0 = RandomBitsGenerator::new(c0); @@ -93,7 +93,7 @@ mod tests { #[tokio::test] pub async fn malicious() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let contexts = world.contexts::(); let validators = contexts.map(MaliciousValidator::new); diff --git a/src/protocol/boolean/solved_bits.rs b/src/protocol/boolean/solved_bits.rs index ae9d68d8d..03d42e6ae 100644 --- a/src/protocol/boolean/solved_bits.rs +++ b/src/protocol/boolean/solved_bits.rs @@ -203,7 +203,7 @@ mod tests { #[tokio::test] pub async fn fp31() -> Result<(), Error> { - let world = TestWorld::new(); + let world = TestWorld::new().await; let ctx = world.contexts::(); let [c0, c1, c2] = ctx; @@ -227,7 +227,7 @@ mod tests { #[tokio::test] pub async fn fp_32bit_prime() -> Result<(), Error> { - let world = TestWorld::new(); + let world = TestWorld::new().await; let ctx = world.contexts::(); let [c0, c1, c2] = ctx; @@ -251,7 +251,7 @@ mod tests { #[tokio::test] pub async fn malicious() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut success = 0; for _ in 0..4 { diff --git a/src/protocol/boolean/xor.rs b/src/protocol/boolean/xor.rs index c70ddd139..722eb344c 100644 --- a/src/protocol/boolean/xor.rs +++ b/src/protocol/boolean/xor.rs @@ -85,7 +85,7 @@ mod tests { pub async fn all_combinations() { type F = Fp32BitPrime; - let world = TestWorld::new(); + let world = TestWorld::new().await; assert_eq!(F::ZERO, run(&world, F::ZERO, F::ZERO).await); assert_eq!(F::ONE, run(&world, F::ONE, F::ZERO).await); @@ -124,7 +124,7 @@ mod tests { /// Run all XOR operations with all combinations of sparse inputs. #[tokio::test] pub async fn all_sparse() { - let world = TestWorld::new(); + let world = TestWorld::new().await; for &a in ZeroPositions::all() { for &b in ZeroPositions::all() { diff --git a/src/protocol/context/mod.rs b/src/protocol/context/mod.rs index 2b6f69943..4efda07bc 100644 --- a/src/protocol/context/mod.rs +++ b/src/protocol/context/mod.rs @@ -158,7 +158,7 @@ mod tests { #[tokio::test] async fn semi_honest_metrics() { - let world = TestWorld::new_with(*TestWorldConfig::default().enable_metrics()); + let world = TestWorld::new_with(*TestWorldConfig::default().enable_metrics()).await; let input = (0..10u128).map(Fp31::from).collect::>(); let result = world @@ -211,7 +211,7 @@ mod tests { #[tokio::test] async fn malicious_metrics() { - let world = TestWorld::new_with(*TestWorldConfig::default().enable_metrics()); + let world = TestWorld::new_with(*TestWorldConfig::default().enable_metrics()).await; let input = vec![Fp31::from(0u128), Fp31::from(1u128)]; let _result = world diff --git a/src/protocol/ipa/mod.rs b/src/protocol/ipa/mod.rs index 4b52f78fa..eba4feeb9 100644 --- a/src/protocol/ipa/mod.rs +++ b/src/protocol/ipa/mod.rs @@ -239,7 +239,7 @@ pub mod tests { const EXPECTED: &[[u128; 2]] = &[[0, 0], [1, 2], [2, 3]]; const MAX_BREAKDOWN_KEY: u128 = 3; - let world = TestWorld::new(); + let world = TestWorld::new().await; // match key, is_trigger, breakdown_key, trigger_value let records = [ @@ -306,7 +306,7 @@ pub mod tests { const MAX_TRIGGER_VALUE: u128 = 5; let max_match_key: u64 = BATCHSIZE / 10; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let mut records: Vec = Vec::new(); diff --git a/src/protocol/malicious.rs b/src/protocol/malicious.rs index 6c81f713f..e04cd23ce 100644 --- a/src/protocol/malicious.rs +++ b/src/protocol/malicious.rs @@ -276,7 +276,7 @@ mod tests { /// There is a small chance of failure which is `2 / |F|`, where `|F|` is the cardinality of the prime field. #[tokio::test] async fn simplest_circuit() -> Result<(), Error> { - let world = TestWorld::new(); + let world = TestWorld::new().await; let context = world.contexts::(); let mut rng = thread_rng(); @@ -324,7 +324,7 @@ mod tests { #[tokio::test] async fn upgrade_only() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a = rng.gen::(); @@ -341,7 +341,7 @@ mod tests { #[tokio::test] async fn upgrade_only_tweaked() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let a = rng.gen::(); @@ -388,7 +388,7 @@ mod tests { /// There is a small chance of failure which is `2 / |F|`, where `|F|` is the cardinality of the prime field. #[tokio::test] async fn complex_circuit() -> Result<(), Error> { - let world = TestWorld::new(); + let world = TestWorld::new().await; let context = world.contexts::(); let mut rng = thread_rng(); diff --git a/src/protocol/modulus_conversion/convert_shares.rs b/src/protocol/modulus_conversion/convert_shares.rs index 2ab5335d5..8fd91af98 100644 --- a/src/protocol/modulus_conversion/convert_shares.rs +++ b/src/protocol/modulus_conversion/convert_shares.rs @@ -201,7 +201,7 @@ mod tests { const BITNUM: u32 = 4; let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; let match_key = MaskedMatchKey::mask(rng.gen()); let result: [Replicated; 3] = world .semi_honest(match_key, |ctx, mk_share| async move { @@ -217,7 +217,7 @@ mod tests { const BITNUM: u32 = 4; let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; let match_key = MaskedMatchKey::mask(rng.gen()); let result: [Replicated; 3] = world .semi_honest(match_key, |ctx, mk_share| async move { @@ -277,7 +277,7 @@ mod tests { const BITNUM: u32 = 4; let mut rng = thread_rng(); - let world = TestWorld::new(); + let world = TestWorld::new().await; for tweak in TWEAKS { let match_key = MaskedMatchKey::mask(rng.gen()); world diff --git a/src/protocol/sort/apply_sort/mod.rs b/src/protocol/sort/apply_sort/mod.rs index f9464aa83..2f9b9f8a1 100644 --- a/src/protocol/sort/apply_sort/mod.rs +++ b/src/protocol/sort/apply_sort/mod.rs @@ -58,7 +58,7 @@ mod tests { pub async fn semi_honest() { const COUNT: usize = 5; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let mut match_keys = Vec::with_capacity(COUNT); diff --git a/src/protocol/sort/apply_sort/shuffle.rs b/src/protocol/sort/apply_sort/shuffle.rs index 7e25a1057..b5615e8f2 100644 --- a/src/protocol/sort/apply_sort/shuffle.rs +++ b/src/protocol/sort/apply_sort/shuffle.rs @@ -178,7 +178,7 @@ mod tests { #[tokio::test] async fn shuffle_attribution_input_row() { const BATCHSIZE: u8 = 25; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let mut input: Vec> = Vec::with_capacity(BATCHSIZE.into()); @@ -233,7 +233,7 @@ mod tests { ]; let some_numbers_as_bits = some_numbers.map(|x| get_bits::(x, BIT_LENGTH)); - let world = TestWorld::new(); + let world = TestWorld::new().await; let result = world .semi_honest( diff --git a/src/protocol/sort/bit_permutation.rs b/src/protocol/sort/bit_permutation.rs index 8f8e23ee9..3e1968187 100644 --- a/src/protocol/sort/bit_permutation.rs +++ b/src/protocol/sort/bit_permutation.rs @@ -83,7 +83,7 @@ mod tests { #[tokio::test] pub async fn semi_honest() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec<_> = INPUT.iter().map(|x| Fp31::from(*x)).collect(); let result = world @@ -97,7 +97,7 @@ mod tests { #[tokio::test] pub async fn malicious() { - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec<_> = INPUT.iter().map(|x| Fp31::from(*x)).collect(); let result = world diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 185808e6d..be3f22881 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -61,7 +61,7 @@ mod tests { #[tokio::test] pub async fn semi_honest() { const BATCHSIZE: u32 = 25; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng_sigma = thread_rng(); let mut rng_rho = thread_rng(); diff --git a/src/protocol/sort/generate_permutation.rs b/src/protocol/sort/generate_permutation.rs index 7820b8e84..36cfc0669 100644 --- a/src/protocol/sort/generate_permutation.rs +++ b/src/protocol/sort/generate_permutation.rs @@ -365,7 +365,7 @@ mod tests { pub async fn semi_honest() { const COUNT: usize = 5; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let mut match_keys = Vec::with_capacity(COUNT); @@ -412,7 +412,7 @@ mod tests { let mut permutation: Vec = (0..BATCHSIZE).collect(); permutation.shuffle(&mut rng); - let world = TestWorld::new(); + let world = TestWorld::new().await; let [ctx0, ctx1, ctx2] = world.contexts(); let permutation: Vec = permutation.iter().map(|x| u128::from(*x)).collect(); @@ -448,7 +448,7 @@ mod tests { pub async fn malicious_sort_in_semi_honest() { const COUNT: usize = 5; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = thread_rng(); let mut match_keys = Vec::with_capacity(COUNT); diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index beadb98be..b6a13df7c 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -69,7 +69,7 @@ mod tests { #[tokio::test] pub async fn semi_honest() { const BATCHSIZE: u32 = 25; - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut rng = rand::thread_rng(); let mut input = Vec::with_capacity(BATCHSIZE as usize); diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index 420d1a5b4..6c069fc53 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -221,7 +221,7 @@ mod tests { #[tokio::test] async fn semi_honest() { const BATCHSIZE: u8 = 25; - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec = (0..BATCHSIZE).collect(); let hashed_input: HashSet = input.clone().into_iter().collect(); @@ -261,7 +261,7 @@ mod tests { async fn shuffle_unshuffle() { const BATCHSIZE: usize = 5; - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec = (0..u128::try_from(BATCHSIZE).unwrap()).collect(); let result = world @@ -307,7 +307,7 @@ mod tests { #[tokio::test] async fn malicious() { const BATCHSIZE: u8 = 25; - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec = (0..BATCHSIZE).collect(); let hashed_input: HashSet = input.clone().into_iter().collect(); @@ -349,7 +349,7 @@ mod tests { async fn shuffle_unshuffle() { const BATCHSIZE: usize = 5; - let world = TestWorld::new(); + let world = TestWorld::new().await; let input: Vec = (0..u128::try_from(BATCHSIZE).unwrap()).collect(); let result = world diff --git a/src/test_fixture/circuit.rs b/src/test_fixture/circuit.rs index 7a050de55..46f2a8b3e 100644 --- a/src/test_fixture/circuit.rs +++ b/src/test_fixture/circuit.rs @@ -12,7 +12,7 @@ use futures_util::future::join_all; /// # Panics /// panics when circuits did not produce the expected value. pub async fn arithmetic(width: u32, depth: u8) { - let world = TestWorld::new(); + let world = TestWorld::new().await; let mut multiplications = Vec::new(); for record in 0..width { diff --git a/src/test_fixture/mod.rs b/src/test_fixture/mod.rs index ca36672bb..0d39419cc 100644 --- a/src/test_fixture/mod.rs +++ b/src/test_fixture/mod.rs @@ -6,7 +6,7 @@ pub mod ipa_input_row; pub mod logging; pub mod metrics; pub mod net; -pub mod network; +mod transport; use crate::ff::{Field, Fp31}; use crate::protocol::context::Context; diff --git a/src/test_fixture/network.rs b/src/test_fixture/network.rs deleted file mode 100644 index 05d8e4fc8..000000000 --- a/src/test_fixture/network.rs +++ /dev/null @@ -1,213 +0,0 @@ -use crate::sync::{Arc, Mutex, Weak}; -use crate::{ - helpers::{ - self, - network::{ChannelId, MessageChunks, Network, NetworkSink}, - Error, Role, - }, - protocol::Step, -}; -use ::tokio::sync::mpsc::{self, Receiver, Sender}; -use async_trait::async_trait; -use futures::StreamExt; -use futures_util::stream::{FuturesUnordered, SelectAll}; -use std::collections::{hash_map::Entry, HashMap}; -use std::fmt::{Debug, Formatter}; -use tokio_stream::wrappers::ReceiverStream; -use tracing::Instrument; - -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; - -/// Represents control messages sent between helpers to handle infrastructure requests. -pub(super) enum ControlMessage { - /// Connection for a step is requested by the peer. - ConnectionRequest(ChannelId, Receiver>), -} - -/// Container for all active helper endpoints -#[derive(Debug)] -pub struct InMemoryNetwork { - pub endpoints: [Arc; 3], -} - -/// Helper endpoint in memory. Capable of opening connections to other helpers and buffering -/// messages it receives from them until someone requests them. -#[derive(Debug)] -pub struct InMemoryEndpoint { - pub role: Role, - /// Channels that this endpoint is listening to. There are two helper peers for 3 party setting. - /// For each peer there are multiple channels open, one per query + step. - channels: Arc>>>, - tx: Sender, - rx: Arc>>>, - network: Weak, - chunks_sender: Sender, -} - -/// In memory channel is just a standard mpsc channel. -#[derive(Debug, Clone)] -pub struct InMemoryChannel { - tx: Sender>, -} - -impl InMemoryNetwork { - #[must_use] - pub fn new() -> Arc { - Arc::new_cyclic(|weak_ptr| { - let endpoints = Role::all().map(|i| InMemoryEndpoint::new(i, Weak::clone(weak_ptr))); - - Self { endpoints } - }) - } -} - -impl InMemoryEndpoint { - /// Creates new instance for a given helper role. - #[must_use] - #[allow(clippy::missing_panics_doc)] - pub fn new(id: Role, world: Weak) -> Arc { - let (tx, mut open_channel_rx) = mpsc::channel(1); - let (message_stream_tx, message_stream_rx) = mpsc::channel(1); - let (chunks_sender, mut chunks_receiver) = mpsc::channel(1); - - let this = Arc::new(Self { - role: id, - channels: Arc::new(Mutex::new(vec![ - HashMap::default(), - HashMap::default(), - HashMap::default(), - ])), - tx, - rx: Arc::new(Mutex::new(Some(message_stream_rx))), - network: world, - chunks_sender, - }); - - tokio::spawn({ - let this = Arc::clone(&this); - async move { - let mut peer_channels = SelectAll::new(); - let mut pending_sends = FuturesUnordered::new(); - let mut buf = HashMap::>::new(); - - loop { - ::tokio::select! { - // handle request to establish connection with a peer - Some(control_message) = open_channel_rx.recv() => { - match control_message { - ControlMessage::ConnectionRequest(channel_id, new_channel_rx) => { - peer_channels.push(ReceiverStream::new(new_channel_rx).map(move |msg| (channel_id.clone(), msg))); - } - } - } - // receive a batch of messages from the peer - Some((channel_id, msgs)) = peer_channels.next() => { - buf.entry(channel_id).or_default().extend(msgs); - } - // Handle request to send messages to a peer - Some(chunk) = chunks_receiver.recv() => { - pending_sends.push(this.send_chunk(chunk)); - } - // Drive pending sends to completion - Some(_) = pending_sends.next() => { } - // If there is nothing else to do, try to obtain a permit to move messages - // from the buffer to messaging layer. Potentially we might be thrashing - // on permits here. - Ok(permit) = message_stream_tx.reserve(), if !buf.is_empty() => { - let key = buf.keys().next().unwrap().clone(); - let msgs = buf.remove(&key).unwrap(); - - permit.send((key, msgs)); - } - else => { - break; - } - } - } - } - }.instrument(tracing::info_span!("in_memory_helper_event_loop", role=id.as_static_str()).or_current())); - - this - } -} - -impl InMemoryEndpoint { - async fn send_chunk(&self, chunk: MessageChunks) { - let conn = self.get_connection(&chunk.0).await; - conn.send(chunk.0, chunk.1).await.unwrap(); - } - - async fn get_connection(&self, addr: &ChannelId) -> InMemoryChannel { - let mut new_rx = None; - - let channel = { - let mut channels = self.channels.lock().unwrap(); - let peer_channel = &mut channels[addr.role]; - - match peer_channel.entry(addr.step.clone()) { - Entry::Occupied(entry) => entry.get().clone(), - Entry::Vacant(entry) => { - let (tx, rx) = mpsc::channel(1); - let tx = InMemoryChannel { tx }; - entry.insert(tx.clone()); - new_rx = Some(rx); - - tx - } - } - }; - - if let Some(rx) = new_rx { - self.network.upgrade().unwrap().endpoints[addr.role] - .tx - .send(ControlMessage::ConnectionRequest( - ChannelId::new(self.role, addr.step.clone()), - rx, - )) - .await - .unwrap(); - } - - channel - } -} - -#[async_trait] -impl Network for Arc { - type Sink = NetworkSink; - type MessageStream = ReceiverStream; - - fn sink(&self) -> Self::Sink { - let x = self.chunks_sender.clone(); - Self::Sink::new(x) - } - - fn recv_stream(&self) -> Self::MessageStream { - let mut rx = self.rx.lock().unwrap(); - if let Some(rx) = rx.take() { - ReceiverStream::new(rx) - } else { - panic!("Message stream has been consumed already"); - } - } -} - -impl InMemoryChannel { - async fn send(&self, id: ChannelId, msg: Vec) -> helpers::Result<()> { - self.tx - .send(msg) - .await - .map_err(|e| Error::send_error(id, e)) - } -} - -impl Debug for ControlMessage { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ControlMessage::ConnectionRequest(channel, step) => { - write!(f, "ConnectionRequest(from={channel:?}, step={step:?})") - } - } - } -} diff --git a/src/test_fixture/transport/mod.rs b/src/test_fixture/transport/mod.rs new file mode 100644 index 000000000..0e57084ef --- /dev/null +++ b/src/test_fixture/transport/mod.rs @@ -0,0 +1,105 @@ +pub mod network; +mod routing; + +use crate::helpers::{ + CommandEnvelope, HelperIdentity, SubscriptionType, Transport, TransportCommand, TransportError, +}; +use crate::sync::Weak; +use async_trait::async_trait; +use routing::Switch; +use std::collections::HashMap; +use tokio::sync::mpsc::{channel, Sender}; +use tokio_stream::wrappers::ReceiverStream; + +/// In memory transport setup includes creating resources +/// to create a connection to every other peer in the network. +/// To finalize the setup and obtain [`InMemoryTransport`] instance +/// call [`listen`] method. +pub struct Setup { + switch_setup: routing::Setup, + peer_connections: HashMap>, +} + +impl From for Setup { + fn from(id: HelperIdentity) -> Self { + Self { + switch_setup: routing::Setup::from(id), + peer_connections: HashMap::default(), + } + } +} + +impl Setup { + pub fn connect(&mut self, dest: &mut Self) { + let (tx, rx) = channel(1); + self.peer_connections + .insert(dest.switch_setup.identity.clone(), tx); + dest.switch_setup + .add_peer(self.switch_setup.identity.clone(), rx); + } + + pub fn listen(self) -> InMemoryTransport { + let switch = self.switch_setup.listen(); + + InMemoryTransport { + switch, + peer_connections: self.peer_connections, + } + } +} + +/// Implementation of `Transport` for in-memory testing. Uses tokio channels to exchange messages +/// with peers. +pub struct InMemoryTransport { + switch: Switch, + peer_connections: HashMap>, +} + +impl InMemoryTransport { + pub fn setup(id: HelperIdentity) -> Setup { + Setup::from(id) + } + + /// Establish bidirectional connection between two transports + pub fn link(a: &mut Setup, b: &mut Setup) { + a.connect(b); + b.connect(a); + } + + pub fn identity(&self) -> &HelperIdentity { + self.switch.identity() + } +} + +#[async_trait] +impl Transport for Weak { + type CommandStream = ReceiverStream; + + async fn subscribe(&self, subscription_type: SubscriptionType) -> Self::CommandStream { + let this = self + .upgrade() + .unwrap_or_else(|| panic!("In memory transport is destroyed")); + match subscription_type { + SubscriptionType::QueryManagement => { + unimplemented!() + } + SubscriptionType::Query(query_id) => this.switch.query_stream(query_id).await, + } + } + + async fn send( + &self, + destination: &HelperIdentity, + command: TransportCommand, + ) -> Result<(), TransportError> { + let this = self + .upgrade() + .unwrap_or_else(|| panic!("In memory transport is destroyed")); + Ok(this + .peer_connections + .get(destination) + .unwrap() + .send(command) + .await?) + } +} diff --git a/src/test_fixture/transport/network.rs b/src/test_fixture/transport/network.rs new file mode 100644 index 000000000..345ca4be6 --- /dev/null +++ b/src/test_fixture/transport/network.rs @@ -0,0 +1,39 @@ +use crate::helpers::HelperIdentity; +use crate::sync::Arc; +use crate::test_fixture::transport::InMemoryTransport; + +/// Container for all active transports +pub struct InMemoryNetwork { + pub transports: [Arc; 3], +} + +impl Default for InMemoryNetwork { + fn default() -> Self { + let [mut first, mut second, mut third] = [ + InMemoryTransport::setup(1.try_into().unwrap()), + InMemoryTransport::setup(2.try_into().unwrap()), + InMemoryTransport::setup(3.try_into().unwrap()), + ]; + + InMemoryTransport::link(&mut first, &mut second); + InMemoryTransport::link(&mut second, &mut third); + InMemoryTransport::link(&mut third, &mut first); + + Self { + transports: [first.listen(), second.listen(), third.listen()].map(Arc::new), + } + } +} + +impl InMemoryNetwork { + #[must_use] + #[allow(clippy::missing_panics_doc)] + pub fn helper_identities(&self) -> [HelperIdentity; 3] { + self.transports + .iter() + .map(|t| t.identity().clone()) + .collect::>() + .try_into() + .unwrap() + } +} diff --git a/src/test_fixture/transport/routing.rs b/src/test_fixture/transport/routing.rs new file mode 100644 index 000000000..735b6c841 --- /dev/null +++ b/src/test_fixture/transport/routing.rs @@ -0,0 +1,191 @@ +use crate::helpers::{ + CommandEnvelope, CommandOrigin, HelperIdentity, SubscriptionType, TransportCommand, +}; +use crate::protocol::{QueryId, Step}; +use crate::task::JoinHandle; +use ::tokio::sync::{mpsc, oneshot}; +use futures::StreamExt; +use futures_util::stream::SelectAll; +#[cfg(all(feature = "shuttle", test))] +use shuttle::future as tokio; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use tokio_stream::wrappers::ReceiverStream; +use tracing::Instrument; + +#[derive(Debug)] +enum SwitchCommand { + Subscribe(SubscribeRequest), +} + +struct SubscribeRequest { + subscription: SubscriptionType, + link: mpsc::Sender, + ack_tx: oneshot::Sender<()>, +} + +impl Debug for SubscribeRequest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Subscribe[{:?}]", self.subscription) + } +} + +impl SubscribeRequest { + pub fn new( + subscription: SubscriptionType, + link: mpsc::Sender, + ) -> (Self, oneshot::Receiver<()>) { + let (ack_tx, ack_rx) = oneshot::channel(); + ( + Self { + subscription, + link, + ack_tx, + }, + ack_rx, + ) + } + + pub fn acknowledge(self) { + self.ack_tx.send(()).unwrap(); + } + + pub fn subscription(&self) -> SubscriptionType { + self.subscription + } + + pub fn sender(&self) -> mpsc::Sender { + self.link.clone() + } +} + +#[derive(Debug)] +pub struct Setup { + pub identity: HelperIdentity, + peers: HashMap>, +} + +impl From for Setup { + fn from(identity: HelperIdentity) -> Self { + Self { + identity, + peers: HashMap::default(), + } + } +} + +impl Setup { + pub fn add_peer(&mut self, peer_id: HelperIdentity, peer_rx: mpsc::Receiver) { + assert!(self.peers.insert(peer_id, peer_rx).is_none()); + } + + pub(super) fn listen(self) -> Switch { + Switch::new(self) + } +} + +/// Takes care of forwarding commands received from multiple links (one link per peer) +/// to the subscribers +pub(super) struct Switch { + identity: HelperIdentity, + tx: mpsc::Sender, + handle: JoinHandle<()>, +} + +impl Switch { + fn new(setup: Setup) -> Self { + let (tx, mut rx) = mpsc::channel(1); + + let mut peer_links = SelectAll::new(); + for (addr, link) in setup.peers { + peer_links.push(ReceiverStream::new(link).map(move |command| (addr.clone(), command))); + } + + let handle = tokio::spawn(async move { + let mut query_router = QueryCommandRouter::default(); + loop { + ::tokio::select! { + Some(command) = rx.recv() => { + match command { + SwitchCommand::Subscribe(subscribe_command) => { + match subscribe_command.subscription() { + SubscriptionType::Query(query_id) => { + tracing::trace!("Subscribed to receive commands for query {query_id:?}"); + query_router.subscribe(query_id, subscribe_command.sender()); + subscribe_command.acknowledge(); + }, + SubscriptionType::QueryManagement => { + unimplemented!() + } + } + } + } + } + Some((origin, command)) = peer_links.next() => { + match command { + TransportCommand::StepData(query, step, payload) => query_router.route(origin, query, step, payload).await + } + } + else => { + tracing::debug!("All channels are closed and switch is terminated"); + break; + } + } + } + }.instrument(tracing::info_span!("transport_loop", id=?setup.identity).or_current())); + + Self { + identity: setup.identity, + handle, + tx, + } + } + + pub async fn query_stream(&self, query_id: QueryId) -> ReceiverStream { + let (tx, rx) = mpsc::channel(1); + let (command, ack_rx) = SubscribeRequest::new(SubscriptionType::Query(query_id), tx); + self.tx + .send(SwitchCommand::Subscribe(command)) + .await + .unwrap(); + ack_rx.await.unwrap(); + + ReceiverStream::new(rx) + } + + pub fn identity(&self) -> &HelperIdentity { + &self.identity + } +} + +impl Drop for Switch { + fn drop(&mut self) { + self.handle.abort(); + } +} + +#[derive(Default)] +struct QueryCommandRouter { + routes: HashMap>, +} + +impl QueryCommandRouter { + async fn route(&self, origin: HelperIdentity, query_id: QueryId, step: Step, payload: Vec) { + let sender = self + .routes + .get(&query_id) + .unwrap_or_else(|| panic!("No subscribers for {query_id:?}")); + + sender + .send(CommandEnvelope { + origin: CommandOrigin::Helper(origin), + payload: TransportCommand::StepData(query_id, step, payload), + }) + .await + .unwrap(); + } + + fn subscribe(&mut self, query_id: QueryId, sender: mpsc::Sender) { + assert!(self.routes.insert(query_id, sender).is_none()); + } +} diff --git a/src/test_fixture/world.rs b/src/test_fixture/world.rs index 53192dde4..a2da261cb 100644 --- a/src/test_fixture/world.rs +++ b/src/test_fixture/world.rs @@ -17,7 +17,7 @@ use crate::{ prss::Endpoint as PrssEndpoint, }, secret_sharing::DowngradeMalicious, - test_fixture::{logging, make_participants, network::InMemoryNetwork}, + test_fixture::{logging, make_participants}, }; use std::io::stdout; @@ -26,10 +26,13 @@ use std::mem::ManuallyDrop; use std::sync::atomic::AtomicBool; use std::{fmt::Debug, iter::zip, sync::Arc}; -use crate::protocol::Substep; +use crate::helpers::network::Network; +use crate::helpers::RoleAssignment; +use crate::protocol::{QueryId, Substep}; use crate::secret_sharing::IntoShares; use crate::telemetry::stats::Metrics; use crate::telemetry::StepStatsCsvExporter; +use crate::test_fixture::transport::network::InMemoryNetwork; use tracing::Level; use super::{ @@ -47,7 +50,7 @@ pub struct TestWorld { executions: AtomicUsize, metrics_handle: MetricsHandle, joined: AtomicBool, - _network: Arc, + _network: InMemoryNetwork, } #[derive(Copy, Clone)] @@ -98,21 +101,26 @@ impl TestWorld { /// Creates a new `TestWorld` instance using the provided `config`. /// # Panics /// Never. - #[must_use] - pub fn new_with(config: TestWorldConfig) -> TestWorld { + pub async fn new_with(config: TestWorldConfig) -> TestWorld { logging::setup(); let metrics_handle = MetricsHandle::new(config.metrics_level); let participants = make_participants(); - let network = InMemoryNetwork::new(); - - let gateways = network - .endpoints - .iter() - .map(|endpoint| Gateway::new(endpoint.role, endpoint, config.gateway_config)) - .collect::>() - .try_into() - .unwrap(); + let network = InMemoryNetwork::default(); + let role_assignment = RoleAssignment::new(network.helper_identities()); + + let gateways = join_all(network.transports.iter().enumerate().map(|(i, transport)| { + let role_assignment = role_assignment.clone(); + async move { + // simple role assignment, based on transport index + let role = Role::all()[i]; + let network = Network::new(Arc::downgrade(transport), QueryId, role_assignment); + Gateway::new(role, network, config.gateway_config).await + } + })) + .await + .try_into() + .unwrap(); TestWorld { gateways: ManuallyDrop::new(gateways), @@ -126,10 +134,9 @@ impl TestWorld { /// # Panics /// Never. - #[must_use] - pub fn new() -> TestWorld { + pub async fn new() -> TestWorld { let config = TestWorldConfig::default(); - Self::new_with(config) + Self::new_with(config).await } /// Creates protocol contexts for 3 helpers @@ -190,12 +197,6 @@ impl Drop for TestWorld { } } -impl Default for TestWorld { - fn default() -> Self { - Self::new() - } -} - #[async_trait] pub trait Runner { async fn semi_honest<'a, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] diff --git a/src/tests/infra.rs b/src/tests/infra.rs index d0c1f5ebb..927e296bc 100644 --- a/src/tests/infra.rs +++ b/src/tests/infra.rs @@ -14,7 +14,7 @@ mod randomized { shuttle::check_random( || { shuttle::future::block_on(async { - let world = TestWorld::new(); + let world = TestWorld::new().await; let input = (0u32..100).map(Fp32BitPrime::from).collect::>(); let output = world @@ -72,7 +72,7 @@ mod randomized { shuttle::check_random( || { shuttle::future::block_on(async { - let world = TestWorld::new(); + let world = TestWorld::new().await; let input = (0u32..10).map(Fp32BitPrime::from).collect::>(); let output = world diff --git a/src/uri.rs b/src/uri.rs new file mode 100644 index 000000000..bf02fb613 --- /dev/null +++ b/src/uri.rs @@ -0,0 +1,16 @@ +use hyper::Uri; +use serde::de::Error; +use serde::{Deserialize, Deserializer, Serializer}; + +/// # Errors +/// if serializing to string fails +pub fn serialize(uri: &Uri, serializer: S) -> Result { + serializer.serialize_str(&uri.to_string()) +} + +/// # Errors +/// if deserializing from string fails, or if string is not a [`Uri`] +pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result { + let s: String = Deserialize::deserialize(deserializer)?; + s.parse().map_err(D::Error::custom) +}