diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 883b415250eb9..1faaa16ce5628 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -47,6 +47,7 @@ import org.opensearch.indices.replication.common.ReplicationLuceneIndex; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.telemetry.tracing.noop.NoopTracer; +import org.opensearch.test.junit.annotations.TestLogging; import org.opensearch.test.transport.CapturingTransport; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -123,11 +124,7 @@ public void setUp() throws Exception { ); testThreadPool = new TestThreadPool("test", Settings.EMPTY); - localNode = new DiscoveryNode( - primaryShard.getReplicationGroup().getRoutingTable().primaryShard().currentNodeId(), - buildNewFakeTransportAddress(), - Version.CURRENT - ); + localNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT); CapturingTransport transport = new CapturingTransport(); transportService = transport.createTransportService( Settings.EMPTY, @@ -264,9 +261,13 @@ public void testAlreadyOnNewCheckpoint() { verify(spy, times(1)).updateVisibleCheckpoint(NO_OPS_PERFORMED, replicaShard); } - @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/8928") - public void testShardAlreadyReplicating() { + @TestLogging(reason = "Getting trace logs from replication package", value = "org.opensearch.indices.replication:TRACE") + public void testShardAlreadyReplicating() throws InterruptedException { + // in this case shard is already replicating and we receive an ahead checkpoint with same pterm. + // ongoing replication is not cancelled and new one does not start. CountDownLatch blockGetCheckpointMetadata = new CountDownLatch(1); + CountDownLatch continueGetCheckpointMetadata = new CountDownLatch(1); + CountDownLatch replicationCompleteLatch = new CountDownLatch(1); SegmentReplicationSource source = new TestReplicationSource() { @Override public void getCheckpointMetadata( @@ -275,11 +276,13 @@ public void getCheckpointMetadata( ActionListener listener ) { try { - blockGetCheckpointMetadata.await(); - final CopyState copyState = new CopyState(primaryShard); - listener.onResponse( - new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) - ); + blockGetCheckpointMetadata.countDown(); + continueGetCheckpointMetadata.await(); + try (final CopyState copyState = new CopyState(primaryShard)) { + listener.onResponse( + new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) + ); + } } catch (InterruptedException | IOException e) { throw new RuntimeException(e); } @@ -300,24 +303,73 @@ public void getSegmentFiles( final SegmentReplicationTarget target = spy( new SegmentReplicationTarget( replicaShard, - primaryShard.getLatestReplicationCheckpoint(), + initialCheckpoint, source, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) + new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + replicationCompleteLatch.countDown(); + } + + @Override + public void onReplicationFailure( + SegmentReplicationState state, + ReplicationFailedException e, + boolean sendShardFailure + ) { + Assert.fail("Replication should not fail"); + } + } ) ); final SegmentReplicationTargetService spy = spy(sut); - doReturn(false).when(spy).processLatestReceivedCheckpoint(eq(replicaShard), any()); // Start first round of segment replication. spy.startReplication(target); + // wait until we are at getCheckpointMetadata stage + blockGetCheckpointMetadata.await(5, TimeUnit.MINUTES); - // Start second round of segment replication, this should fail to start as first round is still in-progress - spy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); - verify(spy, times(1)).processLatestReceivedCheckpoint(eq(replicaShard), any()); - blockGetCheckpointMetadata.countDown(); + // try and insert a new target directly - it should fail immediately and alert listener + spy.startReplication( + new SegmentReplicationTarget( + replicaShard, + aheadCheckpoint, + source, + new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + Assert.fail("Should not succeed"); + } + + @Override + public void onReplicationFailure( + SegmentReplicationState state, + ReplicationFailedException e, + boolean sendShardFailure + ) { + assertFalse(sendShardFailure); + assertEquals("Shard " + replicaShard.shardId() + " is already replicating", e.getMessage()); + } + } + ) + ); + + // Start second round of segment replication through onNewCheckpoint, this should fail to start as first round is still in-progress + // aheadCheckpoint is of same pterm but higher version + assertTrue(replicaShard.shouldProcessCheckpoint(aheadCheckpoint)); + spy.onNewCheckpoint(aheadCheckpoint, replicaShard); + verify(spy, times(0)).processLatestReceivedCheckpoint(eq(replicaShard), any()); + // start replication is not invoked with aheadCheckpoint + verify(spy, times(0)).startReplication( + eq(replicaShard), + eq(aheadCheckpoint), + any(SegmentReplicationTargetService.SegmentReplicationListener.class) + ); + continueGetCheckpointMetadata.countDown(); + replicationCompleteLatch.await(5, TimeUnit.MINUTES); } - public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws InterruptedException { + public void testShardAlreadyReplicating_HigherPrimaryTermReceived() throws InterruptedException { // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. SegmentReplicationTargetService serviceSpy = spy(sut); doNothing().when(serviceSpy).updateVisibleCheckpoint(anyLong(), any());