Skip to content

Commit

Permalink
Merge pull request #363 from akoshelev/infra-network
Browse files Browse the repository at this point in the history
`Transport` trait definition and in-memory implementation
  • Loading branch information
akoshelev authored Dec 23, 2022
2 parents dac25f7 + e5bdcd1 commit 82f7a48
Show file tree
Hide file tree
Showing 58 changed files with 966 additions and 453 deletions.
2 changes: 1 addition & 1 deletion benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async fn main() -> Result<(), Error> {
let mut config = TestWorldConfig::default();
config.gateway_config.send_buffer_config.items_in_batch = 1;
config.gateway_config.send_buffer_config.batch_count = 1000;
let world = TestWorld::new_with(config);
let world = TestWorld::new_with(config).await;
let mut rng = rand::thread_rng();

const BATCHSIZE: u64 = 100;
Expand Down
2 changes: 1 addition & 1 deletion benches/oneshot/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Error> {
let mut config = TestWorldConfig::default();
config.gateway_config.send_buffer_config.items_in_batch = 1;
config.gateway_config.send_buffer_config.batch_count = 1000;
let world = TestWorld::new_with(config);
let world = TestWorld::new_with(config).await;
let [ctx0, ctx1, ctx2] = world.contexts::<Fp32BitPrime>();
let num_bits = 64;
let mut rng = thread_rng();
Expand Down
2 changes: 1 addition & 1 deletion src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn print_output<O: Debug>(values: &[Vec<O>; 3]) {
]);
}

println!("{}", shares_table);
println!("{shares_table}");
}

#[tokio::main]
Expand Down
1 change: 1 addition & 0 deletions src/helpers/buffers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod receive;
mod send;

pub use receive::ReceiveBuffer;
pub(super) use send::PushError;
pub use {send::Config as SendBufferConfig, send::SendBuffer};

#[cfg(debug_assertions)]
Expand Down
11 changes: 3 additions & 8 deletions src/helpers/buffers/send.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
helpers::{
buffers::fsv::FixedSizeByteVec,
network::{ChannelId, MessageEnvelope},
buffers::fsv::FixedSizeByteVec, network::ChannelId, network::MessageEnvelope,
MESSAGE_PAYLOAD_SIZE_BYTES,
},
protocol::RecordId,
Expand Down Expand Up @@ -169,12 +168,8 @@ impl From<&ByteBuf> for Range<RecordId> {

#[cfg(all(test, not(feature = "shuttle")))]
mod tests {
use crate::helpers::buffers::send::{ByteBuf, Config, PushError};
use crate::helpers::buffers::SendBuffer;
use crate::helpers::network::{ChannelId, MessageEnvelope};
use crate::helpers::Role;
use crate::protocol::{RecordId, Step};

use super::*;
use crate::{helpers::Role, protocol::Step};
use tinyvec::array_vec;

impl Clone for MessageEnvelope {
Expand Down
28 changes: 12 additions & 16 deletions src/helpers/error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::helpers::messaging::SendRequest;
use crate::helpers::TransportError;
use crate::{
error::BoxError,
helpers::{
messaging::ReceiveRequest,
messaging::{ReceiveRequest, SendRequest},
network::{ChannelId, MessageChunks},
Role,
HelperIdentity, Role,
},
net::MpcHelperServerError,
protocol::{RecordId, Step},
Expand Down Expand Up @@ -40,11 +40,12 @@ pub enum Error {
#[source]
inner: BoxError,
},
#[error("Failed to send data to the network")]
NetworkError {
#[from]
inner: BoxError,
},
#[error("Encountered unknown identity {0:?}")]
UnknownIdentity(HelperIdentity),
#[error("identity had invalid format: {0}")]
InvalidIdentity(#[from] hyper::http::uri::InvalidUri),
#[error("Failed to send command on the transport: {0}")]
TransportError(#[from] TransportError),
#[error("server encountered an error: {0}")]
ServerError(#[from] MpcHelperServerError),
}
Expand Down Expand Up @@ -104,15 +105,10 @@ impl From<SendError<SendRequest>> for Error {

impl From<PollSendError<MessageChunks>> for Error {
fn from(source: PollSendError<MessageChunks>) -> Self {
let err_msg = source.to_string();
let inner = source.to_string().into();
match source.into_inner() {
Some(inner) => Self::SendError {
channel: inner.0,
inner: err_msg.into(),
},
None => Self::PollSendError {
inner: err_msg.into(),
},
Some((channel, _)) => Self::SendError { channel, inner },
None => Self::PollSendError { inner },
}
}
}
Expand Down
26 changes: 15 additions & 11 deletions src/helpers/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::net::SocketAddr;
pub struct HttpHelper<'p> {
role: Role,
peers: &'p [peer::Config; 3],
gateway_config: GatewayConfig,
_gateway_config: GatewayConfig,
server: MpcHelperServer,
}

Expand All @@ -36,7 +36,7 @@ impl<'p> HttpHelper<'p> {
Self {
role,
peers,
gateway_config,
_gateway_config: gateway_config,
server: MpcHelperServer::new(MessageSendMap::default()),
}
}
Expand All @@ -57,15 +57,17 @@ impl<'p> HttpHelper<'p> {
/// adds a query to the running server so that it knows where to send arriving data
/// # Errors
/// if a query has been previously added
pub fn query(&self, query_id: QueryId) -> Result<Gateway, Error> {
tracing::debug!("starting query {}", query_id.as_ref());
let network = HttpNetwork::new(self.role, self.peers, query_id);

let gateway = Gateway::new(self.role, &network, self.gateway_config);
// allow for server to forward requests to this network
// TODO: how to remove from map?
self.server.add_query(query_id, network)?;
Ok(gateway)
pub fn query(&self, _query_id: QueryId) -> Result<Gateway, Error> {
// TODO: This requires `HttpNetwork` to implement Transport
unimplemented!();
// tracing::debug!("starting query {}", query_id.as_ref());
// let network = HttpNetwork::new(self.role, self.peers, query_id);
//
// let gateway = Gateway::new(self.role, network, self.gateway_config);
// // allow for server to forward requests to this network
// // TODO: how to remove from map?
// self.server.add_query(query_id, network)?;
// Ok(gateway)
}

/// establish the prss endpoint by exchanging public keys with the other helpers
Expand Down Expand Up @@ -208,6 +210,7 @@ mod e2e_tests {
}

#[tokio::test]
#[ignore] // TODO (thurstonsand): enable after `HttpNetwork` implements `Transport`
async fn prss_key_exchange() {
logging::setup();

Expand Down Expand Up @@ -275,6 +278,7 @@ mod e2e_tests {
}

#[tokio::test]
#[ignore] // TODO (thurstonsand): enable after `HttpNetwork` implements `Transport`
async fn basic_mul() {
logging::setup();

Expand Down
12 changes: 7 additions & 5 deletions src/helpers/http/network.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#[allow(deprecated)]
use crate::helpers::old_network::{Network, NetworkSink};
use crate::{
helpers::{
network::{MessageChunks, Network, NetworkSink},
Role,
},
helpers::{network::MessageChunks, Role},
net::{discovery::peer, HttpSendMessagesArgs, MpcHelperClient},
protocol::QueryId,
sync::{Arc, Mutex},
Expand Down Expand Up @@ -99,6 +98,7 @@ impl HttpNetwork {
}
}

#[allow(deprecated)]
impl Network for HttpNetwork {
type Sink = NetworkSink<MessageChunks>;

Expand All @@ -122,15 +122,16 @@ impl Network for HttpNetwork {
#[cfg(test)]
mod tests {
use super::*;
use crate::test_fixture::net::localhost_config;
use crate::{
helpers::{network::ChannelId, Direction, MESSAGE_PAYLOAD_SIZE_BYTES},
net::{discovery::PeerDiscovery, BindTarget, MessageSendMap, MpcHelperServer},
protocol::Step,
test_fixture::net::localhost_config,
};
use futures::{Stream, StreamExt};
use futures_util::SinkExt;

#[allow(deprecated)]
async fn setup() -> (Role, [peer::Config; 3], impl Stream<Item = MessageChunks>) {
// setup server
let network = HttpNetwork::new_without_clients(QueryId, None);
Expand All @@ -149,6 +150,7 @@ mod tests {
}

#[tokio::test]
#[allow(deprecated)]
async fn send_multiple_messages() {
const DATA_LEN: usize = 3;
let (target_role, peers_conf, mut rx_stream) = setup().await;
Expand Down
68 changes: 38 additions & 30 deletions src/helpers/messaging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@
//! corresponding helper without needing to know the exact location - this is what this module
//! enables MPC protocols to do.
//!
use crate::ff::{Field, Int};
use crate::helpers::buffers::{SendBuffer, SendBufferConfig};
use crate::helpers::{MessagePayload, MESSAGE_PAYLOAD_SIZE_BYTES};
use crate::task::JoinHandle;
use crate::telemetry::labels::STEP;
use crate::{
helpers::buffers::ReceiveBuffer,
helpers::error::Error,
helpers::network::{ChannelId, MessageEnvelope, Network},
helpers::Role,
ff::{Field, Int},
helpers::{
buffers::{ReceiveBuffer, SendBuffer, SendBufferConfig},
network::ChannelId,
Error, MessagePayload, Role, MESSAGE_PAYLOAD_SIZE_BYTES,
},
protocol::{RecordId, Step},
task::JoinHandle,
telemetry::{labels::STEP, metrics::RECORDS_SENT},
};
use ::tokio::sync::{mpsc, oneshot};
use ::tokio::time::Instant;
use futures::SinkExt;
use futures::StreamExt;
use std::fmt::{Debug, Formatter};
use std::time::Duration;
use std::{io, panic};
use tinyvec::array_vec;
use tracing::Instrument;

use crate::telemetry::metrics::RECORDS_SENT;
use crate::helpers::buffers::PushError;
use crate::helpers::network::{MessageEnvelope, Network};
use crate::helpers::time::Timer;
use crate::helpers::transport::Transport;
use ::tokio::sync::{mpsc, oneshot};
use futures_util::stream::FuturesUnordered;
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;

Expand Down Expand Up @@ -160,19 +161,18 @@ pub struct GatewayConfig {
}

impl Gateway {
pub fn new<N: Network>(role: Role, network: &N, config: GatewayConfig) -> Self {
pub async fn new<T: Transport>(role: Role, network: Network<T>, config: GatewayConfig) -> Self {
let (recv_tx, mut recv_rx) = mpsc::channel::<ReceiveRequest>(config.recv_outstanding);
let (send_tx, mut send_rx) = mpsc::channel::<SendRequest>(config.send_outstanding);
let mut message_stream = network.recv_stream();
let mut network_sink = network.sink();
let mut message_stream = network.recv_stream().await;

let control_handle = tokio::spawn(async move {
const INTERVAL: Duration = Duration::from_secs(3);

let mut receive_buf = ReceiveBuffer::default();
let mut send_buf = SendBuffer::new(config.send_buffer_config);

let sleep = ::tokio::time::sleep(INTERVAL);
let mut pending_sends = FuturesUnordered::new();
let sleep = Timer::new(INTERVAL);
::tokio::pin!(sleep);

loop {
Expand All @@ -189,9 +189,14 @@ impl Gateway {
tracing::trace!("received {} bytes from {:?}", messages.len(), channel_id);
receive_buf.receive_messages(&channel_id, &messages);
}
Some(send_req) = send_rx.recv() => {
tracing::trace!("new SendRequest({:?})", send_req);
send_message::<N>(&mut network_sink, &mut send_buf, send_req).await;
Some((channel_id, envelope)) = send_rx.recv(), if pending_sends.is_empty() => {
tracing::trace!("new SendRequest({:?})", (&channel_id, &envelope));
metrics::increment_counter!(RECORDS_SENT, STEP => channel_id.step.as_ref().to_string());
let data = send_buf.push(&channel_id, &envelope);
pending_sends.push(send_message(&network, channel_id, data));
}
Some(_) = &mut pending_sends.next() => {
pending_sends.clear();
}
_ = &mut sleep => {
#[cfg(debug_assertions)]
Expand All @@ -204,7 +209,7 @@ impl Gateway {
}

// reset the timer on every action
sleep.as_mut().reset(Instant::now() + INTERVAL);
sleep.as_mut().reset();
}
}.instrument(tracing::info_span!("gateway_loop", role=role.as_static_str()).or_current()));

Expand Down Expand Up @@ -287,19 +292,22 @@ impl Debug for ReceiveRequest {
}
}

async fn send_message<N: Network>(sink: &mut N::Sink, buf: &mut SendBuffer, req: SendRequest) {
let (channel_id, msg) = req;
metrics::increment_counter!(RECORDS_SENT, STEP => channel_id.step.as_ref().to_string());
match buf.push(&channel_id, &msg) {
async fn send_message<T: Transport>(
network: &Network<T>,
channel_id: ChannelId,
data: Result<Option<Vec<u8>>, PushError>,
) {
match data {
Ok(Some(buf_to_send)) => {
tracing::trace!("sending {} bytes to {:?}", buf_to_send.len(), &channel_id);
sink.send((channel_id, buf_to_send))
network
.send((channel_id, buf_to_send))
.await
.expect("Failed to send data to the network");
}
Ok(None) => {}
Err(err) => panic!("failed to send to the {channel_id:?}: {err}"),
};
}
}

#[cfg(debug_assertions)]
Expand Down Expand Up @@ -329,7 +337,7 @@ mod tests {
config.gateway_config.send_buffer_config.items_in_batch = 1; // Send every record
config.gateway_config.send_buffer_config.batch_count = 3; // keep 3 at a time

let world = Box::leak(Box::new(TestWorld::new_with(config)));
let world = Box::leak(Box::new(TestWorld::new_with(config).await));
let contexts = world.contexts::<Fp31>();
let sender_ctx = contexts[0].narrow("reordering-test");
let recv_ctx = contexts[1].narrow("reordering-test");
Expand Down Expand Up @@ -361,7 +369,7 @@ mod tests {
#[tokio::test]
#[should_panic(expected = "Record RecordId(1) has been received twice")]
async fn duplicate_message() {
let world = TestWorld::new();
let world = TestWorld::new().await;
let (v1, v2) = (Fp31::from(1u128), Fp31::from(2u128));
let peer = Role::H2;
let record_id = 1.into();
Expand Down
Loading

0 comments on commit 82f7a48

Please sign in to comment.