diff --git a/Cargo.lock b/Cargo.lock index 4d76988dd..92c1a698b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1629,9 +1629,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.3" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813" +checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" dependencies = [ "clap_builder", "clap_derive", @@ -1649,9 +1649,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" dependencies = [ "anstream", "anstyle", @@ -1662,9 +1662,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.3" +version = "4.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f" +checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" dependencies = [ "heck 0.5.0", "proc-macro2", diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index b792b38f5..52a59f67a 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -74,6 +74,7 @@ tracing = { workspace = true } tracing-opentelemetry = { workspace = true } [dev-dependencies] +restate-bifrost = { workspace = true, features = ["test-util"] } restate-core = { workspace = true, features = ["test-util"] } restate-rocksdb = { workspace = true, features = ["test-util"] } restate-schema-api = { workspace = true, features = ["test-util"] } diff --git a/crates/worker/src/partition/shuffle.rs b/crates/worker/src/partition/shuffle.rs index 50ad5192e..584a0605f 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -8,9 +8,12 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use crate::partition::shuffle::state_machine::StateMachine; -use crate::partition::types::OutboxMessageExt; +use std::future::Future; + use async_channel::{TryRecvError, TrySendError}; +use tokio::sync::mpsc; +use tracing::debug; + use restate_bifrost::Bifrost; use restate_core::cancellation_watcher; use restate_storage_api::deduplication_table::DedupInformation; @@ -19,9 +22,9 @@ use restate_types::identifiers::{LeaderEpoch, PartitionId, PartitionKey, WithPar use restate_types::message::{AckKind, MessageIndex}; use restate_types::NodeId; use restate_wal_protocol::{append_envelope_to_bifrost, Destination, Envelope, Header, Source}; -use std::future::Future; -use tokio::sync::mpsc; -use tracing::debug; + +use crate::partition::shuffle::state_machine::StateMachine; +use crate::partition::types::OutboxMessageExt; #[derive(Debug)] pub(crate) struct NewOutboxMessage { @@ -152,7 +155,7 @@ impl HintSender { } } -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] pub(crate) struct ShuffleMetadata { partition_id: PartitionId, leader_epoch: LeaderEpoch, @@ -265,20 +268,23 @@ where } mod state_machine { - use crate::partition::shuffle; - use crate::partition::shuffle::{ - wrap_outbox_message_in_envelope, NewOutboxMessage, OutboxReaderError, ShuffleMetadata, - }; - use pin_project::pin_project; - use restate_storage_api::outbox_table::OutboxMessage; - use restate_types::message::MessageIndex; - use restate_wal_protocol::Envelope; use std::cmp::Ordering; use std::future::Future; use std::pin::Pin; + + use pin_project::pin_project; use tokio_util::sync::ReusableBoxFuture; use tracing::trace; + use restate_storage_api::outbox_table::OutboxMessage; + use restate_types::message::MessageIndex; + use restate_wal_protocol::Envelope; + + use crate::partition::shuffle; + use crate::partition::shuffle::{ + wrap_outbox_message_in_envelope, NewOutboxMessage, OutboxReaderError, ShuffleMetadata, + }; + type ReadFuture = ReusableBoxFuture< 'static, ( @@ -306,20 +312,6 @@ mod state_machine { state: State, } - async fn get_message( - mut outbox_reader: OutboxReader, - sequence_number: MessageIndex, - ) -> ( - Result, OutboxReaderError>, - OutboxReader, - ) { - let result = outbox_reader.get_message(sequence_number).await; - ( - result.map(|opt| opt.map(|m| (sequence_number, m))), - outbox_reader, - ) - } - async fn get_next_message( mut outbox_reader: OutboxReader, sequence_number: MessageIndex, @@ -432,7 +424,7 @@ mod state_machine { let successfully_shuffled_sequence_number = *this.current_sequence_number; *this.current_sequence_number += 1; - this.read_future.set(get_message( + this.read_future.set(get_next_message( this.outbox_reader .take() .expect("outbox reader should be available"), @@ -447,3 +439,421 @@ mod state_machine { } } } + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use assert2::let_assert; + use futures::{Stream, StreamExt}; + use std::iter; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use test_log::test; + use tokio::sync::mpsc; + + use restate_bifrost::{Bifrost, LogRecord, Record}; + use restate_core::{MockNetworkSender, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; + use restate_storage_api::outbox_table::OutboxMessage; + use restate_storage_api::StorageError; + use restate_types::identifiers::{InvocationId, LeaderEpoch, PartitionId}; + use restate_types::invocation::ServiceInvocation; + use restate_types::logs::{LogId, Lsn, SequenceNumber}; + use restate_types::message::MessageIndex; + use restate_types::partition_table::FixedPartitionTable; + use restate_types::storage::StorageCodec; + use restate_types::{NodeId, Version}; + use restate_wal_protocol::{Command, Envelope}; + + use crate::partition::shuffle::{OutboxReader, OutboxReaderError, Shuffle, ShuffleMetadata}; + + struct MockOutboxReader { + base_offset: MessageIndex, + // there can be holes in our records + records: Vec>, + } + + impl MockOutboxReader { + fn new(base_offset: MessageIndex, records: Vec>) -> Self { + Self { + base_offset, + records, + } + } + + fn subslice_from_index( + &self, + starting_index: MessageIndex, + ) -> &[Option] { + if starting_index < self.base_offset { + <&[Option]>::default() + } else { + self.records + .get((starting_index - self.base_offset) as usize..) + .unwrap_or_default() + } + } + } + + impl OutboxReader for MockOutboxReader { + async fn get_next_message( + &mut self, + next_sequence_number: MessageIndex, + ) -> Result, OutboxReaderError> { + let next_sequence_number = next_sequence_number.max(self.base_offset); + let records = self.subslice_from_index(next_sequence_number); + let next_some_index = records.iter().position(|m| m.is_some()); + + Ok(next_some_index.map(|index| { + ( + next_sequence_number + u64::try_from(index).expect("usize fits in u64"), + OutboxMessage::ServiceInvocation( + records + .get(index) + .expect("subslice entry should exist") + .clone() + .expect("message should exist"), + ), + ) + })) + } + + async fn get_message( + &mut self, + next_sequence_number: MessageIndex, + ) -> Result, OutboxReaderError> { + Ok(self + .subslice_from_index(next_sequence_number) + .first() + .and_then(|x| { + x.clone().map(|service_invocation| { + OutboxMessage::ServiceInvocation(service_invocation) + }) + })) + } + } + + /// Outbox reader which is used to let the shuffler fail in a controlled manner so that we + /// can simulate restarts. + struct FailingOutboxReader { + records: Vec>, + fail_index: MessageIndex, + } + + impl FailingOutboxReader { + fn new(records: Vec>, fail_index: MessageIndex) -> Self { + Self { + records, + fail_index, + } + } + + fn check_fail(&self, next_sequence_number: MessageIndex) -> Result<(), OutboxReaderError> { + if next_sequence_number >= self.fail_index { + return Err(OutboxReaderError::Storage(StorageError::Generic(anyhow!( + "test error" + )))); + } + + Ok(()) + } + } + + impl OutboxReader for Arc { + async fn get_next_message( + &mut self, + next_sequence_number: MessageIndex, + ) -> Result, OutboxReaderError> { + let next_sequence_number = next_sequence_number as usize; + let offset_records = self.records.get(next_sequence_number..).unwrap_or_default(); + let next_some_index = offset_records + .iter() + .position(|record| record.is_some()) + .unwrap_or(offset_records.len()) + + next_sequence_number; + + self.check_fail(u64::try_from(next_some_index).expect("usize fits in u64"))?; + + Ok(self.records.get(next_some_index).map(|record| { + ( + u64::try_from(next_some_index).expect("usize fits in u64"), + OutboxMessage::ServiceInvocation(record.clone().expect("record must exist")), + ) + })) + } + + async fn get_message( + &mut self, + next_sequence_number: MessageIndex, + ) -> Result, OutboxReaderError> { + self.check_fail(next_sequence_number)?; + + Ok(self + .records + .get(next_sequence_number as usize) + .and_then(|msg| { + msg.clone().map(|service_invocation| { + OutboxMessage::ServiceInvocation(service_invocation) + }) + })) + } + } + + async fn collect_invoke_commands_until( + stream: impl Stream>, + last_invocation_id: InvocationId, + ) -> anyhow::Result> { + let mut messages = Vec::new(); + let mut stream = std::pin::pin!(stream); + + while let Some(record) = stream.next().await { + let record = record?; + + if let Record::Data(data) = record.record { + let mut body = data.into_body(); + let envelope = StorageCodec::decode::(&mut body)?; + + let_assert!(Command::Invoke(service_invocation) = envelope.command); + let invocation_id = service_invocation.invocation_id; + messages.push(service_invocation); + + if last_invocation_id == invocation_id { + break; + } + } + } + + Ok(messages) + } + + fn assert_received_invoke_commands( + received_invokes: Vec, + expected_invokes: Vec>, + ) { + // remove Nones + let expected_messages = expected_invokes.iter().flatten(); + + // received_messages can theoretically contain duplicate messages + let mut received_messages = received_invokes.iter(); + + for expected_message in expected_messages { + let mut message_found = false; + for received_message in received_messages.by_ref() { + if received_message == expected_message { + message_found = true; + break; + } + } + + assert!( + message_found, + "Expected message {:?} was not found in received messages", + expected_message + ); + } + } + + struct ShuffleEnv { + env: TestCoreEnv, + bifrost: Bifrost, + shuffle: Shuffle, + } + + async fn create_shuffle_env( + outbox_reader: OR, + ) -> ShuffleEnv { + // set numbers of partitions to 1 to easily find all sent messages by the shuffle + let env = TestCoreEnvBuilder::new_with_mock_network() + .with_partition_table(FixedPartitionTable::new(Version::MIN, 1)) + .build() + .await; + let tc = &env.tc; + let metadata = ShuffleMetadata::new( + PartitionId::from(0), + LeaderEpoch::from(0), + NodeId::new(0, Some(0)), + ); + + let (truncation_tx, _truncation_rx) = mpsc::channel(1); + + let bifrost = tc.run_in_scope("init bifrost", None, Bifrost::init()).await; + let shuffle = Shuffle::new(metadata, outbox_reader, truncation_tx, 1, bifrost.clone()); + + ShuffleEnv { + env, + bifrost, + shuffle, + } + } + + #[test(tokio::test)] + async fn shuffle_consecutive_outbox() -> anyhow::Result<()> { + let expected_messages = iter::repeat_with(|| Some(ServiceInvocation::mock())) + .take(10) + .collect::>(); + + let last_invocation_id = expected_messages + .last() + .and_then(|msg| { + msg.as_ref() + .map(|service_invocation| service_invocation.invocation_id) + }) + .expect("service invocation should be present"); + + let outbox_reader = MockOutboxReader::new(42, expected_messages.clone()); + let shuffle_env = create_shuffle_env(outbox_reader).await; + let tc = shuffle_env.env.tc.clone(); + + tc.run_in_scope("test", None, async { + let partition_id = shuffle_env.shuffle.metadata.partition_id; + tc.spawn_child( + TaskKind::Shuffle, + "shuffle", + None, + shuffle_env.shuffle.run(), + )?; + let reader = shuffle_env + .bifrost + .create_reader(LogId::from(partition_id), Lsn::INVALID, Lsn::MAX) + .await?; + + let messages = collect_invoke_commands_until(reader, last_invocation_id).await?; + + assert_received_invoke_commands(messages, expected_messages); + + Ok::<(), anyhow::Error>(()) + }) + .await + } + + #[test(tokio::test)] + async fn shuffle_holey_outbox() -> anyhow::Result<()> { + let expected_messages = vec![ + Some(ServiceInvocation::mock()), + None, + None, + Some(ServiceInvocation::mock()), + Some(ServiceInvocation::mock()), + ]; + + let last_invocation_id = expected_messages + .last() + .and_then(|msg| { + msg.as_ref() + .map(|service_invocation| service_invocation.invocation_id) + }) + .expect("service invocation should be present"); + + let outbox_reader = MockOutboxReader::new(42, expected_messages.clone()); + let shuffle_env = create_shuffle_env(outbox_reader).await; + let tc = shuffle_env.env.tc.clone(); + + tc.run_in_scope("test", None, async { + let partition_id = shuffle_env.shuffle.metadata.partition_id; + tc.spawn_child( + TaskKind::Shuffle, + "shuffle", + None, + shuffle_env.shuffle.run(), + )?; + let reader = shuffle_env + .bifrost + .create_reader(LogId::from(partition_id), Lsn::INVALID, Lsn::MAX) + .await?; + + let messages = collect_invoke_commands_until(reader, last_invocation_id).await?; + + assert_received_invoke_commands(messages, expected_messages); + + Ok::<(), anyhow::Error>(()) + }) + .await + } + + #[test(tokio::test)] + async fn shuffle_with_restarts() -> anyhow::Result<()> { + let expected_messages: Vec<_> = iter::repeat_with(|| Some(ServiceInvocation::mock())) + .take(100) + .collect(); + + let last_invocation_id = expected_messages + .last() + .and_then(|msg| { + msg.as_ref() + .map(|service_invocation| service_invocation.invocation_id) + }) + .expect("service invocation should be present"); + + let mut outbox_reader = Arc::new(FailingOutboxReader::new(expected_messages.clone(), 10)); + let shuffle_env = create_shuffle_env(Arc::clone(&outbox_reader)).await; + let tc = shuffle_env.env.tc.clone(); + let total_restarts = Arc::new(AtomicUsize::new(0)); + + let shuffle_task_id = tc + .run_in_scope("test", None, async { + let partition_id = shuffle_env.shuffle.metadata.partition_id; + let reader = shuffle_env + .bifrost + .create_reader(LogId::from(partition_id), Lsn::INVALID, Lsn::MAX) + .await?; + let total_restarts = Arc::clone(&total_restarts); + + let shuffle_task = + tc.spawn_child(TaskKind::Shuffle, "shuffle", None, async move { + let mut shuffle = shuffle_env.shuffle; + let metadata = shuffle.metadata; + let truncation_tx = shuffle.truncation_tx.clone(); + let mut processed_range = 0; + let mut num_restarts = 0; + + // restart shuffle on failures and update failing outbox reader + while shuffle.run().await.is_err() { + num_restarts += 1; + // update the failing outbox reader to make a bit more progress and delete some of the delivered records + { + let outbox_reader = Arc::get_mut(&mut outbox_reader) + .expect("only one reference should exist"); + + // leave the first entry to generate some holes + for idx in (processed_range + 1)..outbox_reader.fail_index { + outbox_reader.records[usize::try_from(idx) + .expect("index should fit in usize")] = None; + } + + processed_range = outbox_reader.fail_index; + outbox_reader.fail_index += 10; + } + + shuffle = Shuffle::new( + metadata, + Arc::clone(&outbox_reader), + truncation_tx.clone(), + 1, + shuffle_env.bifrost.clone(), + ); + } + + total_restarts.store(num_restarts, Ordering::Relaxed); + + Ok(()) + })?; + + let messages = collect_invoke_commands_until(reader, last_invocation_id).await?; + + assert_received_invoke_commands(messages, expected_messages); + + Ok::<_, anyhow::Error>(shuffle_task) + }) + .await?; + + let shuffle_task = tc.cancel_task(shuffle_task_id).expect("should exist"); + shuffle_task.await?; + + // make sure that we have restarted the shuffle + assert!( + total_restarts.load(Ordering::Relaxed) > 0, + "expecting the shuffle to be restarted a couple of times" + ); + + Ok(()) + } +}