diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 33734fe85def5..248c0e07d456c 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -23,6 +23,8 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -34,6 +36,9 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { private SegmentReplicationSource replicationSource; private SegmentReplicationTargetService sut; + private ReplicationCheckpoint cp; + private ReplicationCheckpoint newCheckpoint; + @Override public void setUp() throws Exception { super.setUp(); @@ -48,6 +53,14 @@ public void setUp() throws Exception { when(replicationSourceFactory.get(indexShard)).thenReturn(replicationSource); sut = new SegmentReplicationTargetService(threadPool, recoverySettings, transportService, replicationSourceFactory); + cp = indexShard.getLatestReplicationCheckpoint(); + newCheckpoint = new ReplicationCheckpoint( + cp.getShardId(), + cp.getPrimaryTerm(), + cp.getSegmentsGen(), + cp.getSeqNo(), + cp.getSegmentInfosVersion() + 1 + ); } @Override @@ -121,8 +134,8 @@ public void testAlreadyOnNewCheckpoint() { verify(spy, times(0)).startReplication(any(), any(), any()); } - public void testShardAlreadyReplicating() { - SegmentReplicationTargetService spy = spy(sut); + public void testShardAlreadyReplicating() throws InterruptedException { + SegmentReplicationTargetService serviceSpy = spy(sut); // Create a separate target and start it so the shard is already replicating. final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, @@ -130,13 +143,24 @@ public void testShardAlreadyReplicating() { replicationSource, mock(SegmentReplicationTargetService.SegmentReplicationListener.class) ); - final SegmentReplicationTarget spyTarget = Mockito.spy(target); - spy.startReplication(spyTarget); + final SegmentReplicationTarget targetSpy = Mockito.spy(target); + CountDownLatch latch = new CountDownLatch(1); + doAnswer(invocation -> { + final ActionListener listener = invocation.getArgument(0); + // a new checkpoint arrives before we've completed. + serviceSpy.onNewCheckpoint(newCheckpoint, indexShard); + listener.onResponse(null); + latch.countDown(); + return null; + }).when(targetSpy).startReplication(any()); + doNothing().when(targetSpy).onDone(); - // a new checkpoint comes in for the same IndexShard. - spy.onNewCheckpoint(checkpoint, indexShard); - verify(spy, times(0)).startReplication(any(), any(), any()); - spyTarget.markAsDone(); + // start replication of this shard the first time. + serviceSpy.startReplication(targetSpy); + + // wait for the new checkpoint to arrive, before the listener completes. + latch.await(30, TimeUnit.SECONDS); + verify(serviceSpy, times(0)).startReplication(eq(newCheckpoint), eq(indexShard), any()); } public void testNewCheckpointBehindCurrentCheckpoint() { @@ -157,14 +181,6 @@ public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOExc allowShardFailures(); SegmentReplicationTargetService spy = spy(sut); IndexShard spyShard = spy(indexShard); - ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint(); - ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint( - cp.getShardId(), - cp.getPrimaryTerm(), - cp.getSegmentsGen(), - cp.getSeqNo(), - cp.getSegmentInfosVersion() + 1 - ); ArgumentCaptor captor = ArgumentCaptor.forClass( SegmentReplicationTargetService.SegmentReplicationListener.class );