diff --git a/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java b/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java index b15d83de4b3b7..38a5b1c05c8cf 100644 --- a/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsAction.java @@ -21,7 +21,6 @@ import org.opensearch.core.action.support.DefaultShardOperationFailedException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.index.IndexService; import org.opensearch.index.SegmentReplicationPerGroupStats; import org.opensearch.index.SegmentReplicationPressureService; import org.opensearch.index.SegmentReplicationShardStats; @@ -31,13 +30,13 @@ import org.opensearch.indices.IndicesService; import org.opensearch.indices.replication.SegmentReplicationState; import org.opensearch.indices.replication.SegmentReplicationTargetService; -import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -166,9 +165,10 @@ protected SegmentReplicationStatsRequest readRequestFrom(StreamInput in) throws @Override protected SegmentReplicationShardStatsResponse shardOperation(SegmentReplicationStatsRequest request, ShardRouting shardRouting) { - IndexService indexService = indicesService.indexServiceSafe(shardRouting.shardId().getIndex()); - IndexShard indexShard = indexService.getShard(shardRouting.shardId().id()); ShardId shardId = shardRouting.shardId(); + IndexShard indexShard = indicesService + .indexServiceSafe(shardId.getIndex()) + .getShard(shardId.id()); if (indexShard.indexSettings().isSegRepEnabledOrRemoteNode() == false) { return null; @@ -177,12 +177,7 @@ protected SegmentReplicationShardStatsResponse shardOperation(SegmentReplication if (shardRouting.primary()) { return new SegmentReplicationShardStatsResponse(pressureService.getStatsForShard(indexShard)); } else if (shardRouting.isSearchOnly()) { - SegmentReplicationShardStats segmentReplicationShardStats = calcualteSegmentReplicationShardStats( - shardRouting, - indexShard, - shardId, - request.activeOnly() - ); + SegmentReplicationShardStats segmentReplicationShardStats = calcualteSegmentReplicationShardStats(shardRouting); return new SegmentReplicationShardStatsResponse(segmentReplicationShardStats); } else { return new SegmentReplicationShardStatsResponse(getSegmentReplicationState(shardId, request.activeOnly())); @@ -208,31 +203,22 @@ protected ClusterBlockException checkRequestBlock( return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices); } - private SegmentReplicationShardStats calcualteSegmentReplicationShardStats( - ShardRouting shardRouting, - IndexShard indexShard, - ShardId shardId, - boolean isActiveOnly - ) { - ReplicationCheckpoint indexReplicationCheckpoint = indexShard.getLatestReplicationCheckpoint(); - SegmentReplicationState segmentReplicationState = getSegmentReplicationState(shardId, isActiveOnly); - if (segmentReplicationState != null) { - ReplicationCheckpoint latestReplicationCheckpointReceived = segmentReplicationState.getLatestReplicationCheckpoint(); - - SegmentReplicationShardStats segmentReplicationShardStats = new SegmentReplicationShardStats( - shardRouting.allocationId().getId(), - calculateCheckpointsBehind(indexReplicationCheckpoint, latestReplicationCheckpointReceived), - calculateBytesBehind(indexReplicationCheckpoint, latestReplicationCheckpointReceived), - 0, - calculateCurrentReplicationLag(shardId), - getLastCompletedReplicationLag(shardId) - ); + private SegmentReplicationShardStats calcualteSegmentReplicationShardStats(ShardRouting shardRouting) { + ShardId shardId = shardRouting.shardId(); + SegmentReplicationState completedSegmentReplicationState = targetService.getlatestCompletedEventSegmentReplicationState(shardId); + SegmentReplicationState ongoingSegmentReplicationState = targetService.getOngoingEventSegmentReplicationState(shardId); - segmentReplicationShardStats.setCurrentReplicationState(segmentReplicationState); - return segmentReplicationShardStats; - } else { - return new SegmentReplicationShardStats(shardRouting.allocationId().getId(), 0, 0, 0, 0, 0); - } + SegmentReplicationShardStats segmentReplicationShardStats = new SegmentReplicationShardStats( + shardRouting.allocationId().getId(), + calculateCheckpointsBehind(completedSegmentReplicationState, ongoingSegmentReplicationState), + calculateBytesBehind(completedSegmentReplicationState, ongoingSegmentReplicationState), + 0, + getCurrentReplicationLag(ongoingSegmentReplicationState), + getLastCompletedReplicationLag(completedSegmentReplicationState) + ); + + segmentReplicationShardStats.setCurrentReplicationState(targetService.getSegmentReplicationState(shardId)); + return segmentReplicationShardStats; } private SegmentReplicationState getSegmentReplicationState(ShardId shardId, boolean isActiveOnly) { @@ -244,38 +230,54 @@ private SegmentReplicationState getSegmentReplicationState(ShardId shardId, bool } private long calculateCheckpointsBehind( - ReplicationCheckpoint indexReplicationCheckpoint, - ReplicationCheckpoint latestReplicationCheckpointReceived + SegmentReplicationState completedSegmentReplicationState, + SegmentReplicationState ongoingSegmentReplicationState ) { - if (latestReplicationCheckpointReceived != null) { - return latestReplicationCheckpointReceived.getSegmentInfosVersion() - indexReplicationCheckpoint.getSegmentInfosVersion(); + if (ongoingSegmentReplicationState == null || ongoingSegmentReplicationState.getReplicationCheckpoint() == null) { + return 0; } - return 0; + + if(completedSegmentReplicationState == null || + completedSegmentReplicationState.getReplicationCheckpoint() == null) { + return ongoingSegmentReplicationState + .getReplicationCheckpoint() + .getSegmentInfosVersion(); + } + + return ongoingSegmentReplicationState.getReplicationCheckpoint().getSegmentInfosVersion() - + completedSegmentReplicationState.getReplicationCheckpoint().getSegmentInfosVersion(); } private long calculateBytesBehind( - ReplicationCheckpoint indexReplicationCheckpoint, - ReplicationCheckpoint latestReplicationCheckpointReceived + SegmentReplicationState completedSegmentReplicationState, + SegmentReplicationState ongoingSegmentReplicationState ) { - if (latestReplicationCheckpointReceived != null) { + if (ongoingSegmentReplicationState == null || + ongoingSegmentReplicationState.getReplicationCheckpoint() == null) { + return 0; + } + + if (completedSegmentReplicationState == null || + completedSegmentReplicationState.getReplicationCheckpoint() == null) { Store.RecoveryDiff diff = Store.segmentReplicationDiff( - latestReplicationCheckpointReceived.getMetadataMap(), - indexReplicationCheckpoint.getMetadataMap() + ongoingSegmentReplicationState.getReplicationCheckpoint().getMetadataMap(), + Collections.emptyMap() ); return diff.missing.stream().mapToLong(StoreFileMetadata::length).sum(); } - return 0; + + Store.RecoveryDiff diff = Store.segmentReplicationDiff( + ongoingSegmentReplicationState.getReplicationCheckpoint().getMetadataMap(), + completedSegmentReplicationState.getReplicationCheckpoint().getMetadataMap() + ); + return diff.missing.stream().mapToLong(StoreFileMetadata::length).sum(); } - private long calculateCurrentReplicationLag(ShardId shardId) { - SegmentReplicationState ongoingEventSegmentReplicationState = targetService.getOngoingEventSegmentReplicationState(shardId); - return ongoingEventSegmentReplicationState != null ? ongoingEventSegmentReplicationState.getTimer().time() : 0; + private long getCurrentReplicationLag(SegmentReplicationState ongoingSegmentReplicationState) { + return ongoingSegmentReplicationState != null ? ongoingSegmentReplicationState.getTimer().time() : 0; } - private long getLastCompletedReplicationLag(ShardId shardId) { - SegmentReplicationState lastCompletedSegmentReplicationState = targetService.getlatestCompletedEventSegmentReplicationState( - shardId - ); - return lastCompletedSegmentReplicationState != null ? lastCompletedSegmentReplicationState.getTimer().time() : 0; + private long getLastCompletedReplicationLag(SegmentReplicationState completedSegmentReplicationState) { + return completedSegmentReplicationState != null ? completedSegmentReplicationState.getTimer().time() : 0; } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java index 9e712b981d30a..29130b18ffc7b 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java @@ -89,7 +89,7 @@ public static Stage fromId(byte id) { private String sourceDescription; private DiscoveryNode targetNode; - private ReplicationCheckpoint latestReplicationCheckpoint; + private ReplicationCheckpoint replicationCheckpoint; public ShardRouting getShardRouting() { return shardRouting; @@ -151,8 +151,8 @@ public TimeValue getFinalizeReplicationStageTime() { return new TimeValue(time); } - public ReplicationCheckpoint getLatestReplicationCheckpoint() { - return this.latestReplicationCheckpoint; + public ReplicationCheckpoint getReplicationCheckpoint() { + return this.replicationCheckpoint; } public SegmentReplicationState( @@ -259,8 +259,8 @@ public void setStage(Stage stage) { } } - public void setLatestReplicationCheckpoint(ReplicationCheckpoint latestReplicationCheckpoint) { - this.latestReplicationCheckpoint = latestReplicationCheckpoint; + public void setReplicationCheckpoint(ReplicationCheckpoint replicationCheckpoint) { + this.replicationCheckpoint = replicationCheckpoint; } @Override diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index bf86e316db8ec..8223f5d504700 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -177,7 +177,7 @@ public void startReplication(ActionListener listener) { source.getCheckpointMetadata(getId(), checkpoint, checkpointInfoListener); checkpointInfoListener.whenComplete(checkpointInfo -> { - state.setLatestReplicationCheckpoint(checkpointInfo.getCheckpoint()); + state.setReplicationCheckpoint(checkpointInfo.getCheckpoint()); final List filesToFetch = getFiles(checkpointInfo); state.setStage(SegmentReplicationState.Stage.GET_FILES); cancellableThreads.checkForCancel(); diff --git a/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java b/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java new file mode 100644 index 0000000000000..cbb737d92972f --- /dev/null +++ b/server/src/test/java/org/opensearch/action/admin/indices/replication/TransportSegmentReplicationStatsActionTests.java @@ -0,0 +1,495 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.indices.replication; + +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlock; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.AllocationId; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.routing.ShardIterator; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexService; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.SegmentReplicationPerGroupStats; +import org.opensearch.index.SegmentReplicationPressureService; +import org.opensearch.index.SegmentReplicationShardStats; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.store.StoreFileMetadata; +import org.opensearch.indices.IndicesService; +import org.opensearch.indices.replication.SegmentReplicationState; +import org.opensearch.indices.replication.SegmentReplicationTargetService; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.indices.replication.common.ReplicationTimer; +import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Collections.EMPTY_LIST; +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TransportSegmentReplicationStatsActionTests extends OpenSearchTestCase { + + private static final String TEST_INDEX = "test-index"; + + private SegmentReplicationPerGroupStats segmentReplicationPerGroupStats; + + private IndexShard indexShard; + + private SegmentReplicationState completedSegmentReplicationState; + private SegmentReplicationState onGoingSegmentReplicationState; + + private SegmentReplicationTargetService targetService; + + private ShardId shardId; + + + TransportSegmentReplicationStatsAction action; + + private final ClusterBlock writeClusterBlock = new ClusterBlock( + 1, + "uuid", + "", + true, + true, + true, + RestStatus.OK, + EnumSet.of(ClusterBlockLevel.METADATA_WRITE) + ); + + private final ClusterBlock readClusterBlock = new ClusterBlock( + 1, + "uuid", + "", + true, + true, + true, + RestStatus.OK, + EnumSet.of(ClusterBlockLevel.METADATA_READ) + ); + + @Before + public void setUp() throws Exception { + super.setUp(); + Index index = new Index(TEST_INDEX, "_na_"); + shardId = new ShardId(TEST_INDEX, "_na_", 0); + + IndicesService indicesService = mock(IndicesService.class); + IndexService indexService = mock(IndexService.class); + + indexShard = mock(IndexShard.class); + SegmentReplicationPressureService pressureService = mock(SegmentReplicationPressureService.class); + segmentReplicationPerGroupStats = mock(SegmentReplicationPerGroupStats.class); + targetService = mock(SegmentReplicationTargetService.class); + completedSegmentReplicationState = mock(SegmentReplicationState.class); + onGoingSegmentReplicationState = mock(SegmentReplicationState.class); + ReplicationCheckpoint completedCheckpoint = mock(ReplicationCheckpoint.class); + ReplicationCheckpoint onGoingCheckpoint = mock(ReplicationCheckpoint.class); + + ReplicationTimer replicationTimerCompleted = mock(ReplicationTimer.class); + ReplicationTimer replicationTimerOngoing = mock(ReplicationTimer.class); + + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 2) + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder(TEST_INDEX).settings(settings).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, Settings.EMPTY); + + when(indicesService.indexServiceSafe(index)).thenReturn(indexService); + when(indexService.getShard(shardId.id())).thenReturn(indexShard); + when(pressureService.getStatsForShard(indexShard)).thenReturn(segmentReplicationPerGroupStats); + when(indexShard.indexSettings()).thenReturn(indexSettings); + + when(completedSegmentReplicationState.getTimer()).thenReturn(replicationTimerCompleted); + when(onGoingSegmentReplicationState.getTimer()).thenReturn(replicationTimerOngoing); + when(onGoingSegmentReplicationState.getReplicationCheckpoint()).thenReturn(onGoingCheckpoint); + when(completedSegmentReplicationState.getReplicationCheckpoint()).thenReturn(completedCheckpoint); + + long segmentInfoCompleted = 5; + long segmentInfoOngoing = 9; + when(onGoingCheckpoint.getSegmentInfosVersion()).thenReturn(segmentInfoOngoing); + when(completedCheckpoint.getSegmentInfosVersion()).thenReturn(segmentInfoCompleted); + + final StoreFileMetadata segment_1 = new StoreFileMetadata("segment_1", 1L, "abcd", org.apache.lucene.util.Version.LATEST); + final StoreFileMetadata segment_2 = new StoreFileMetadata("segment_2", 50L, "abcd", org.apache.lucene.util.Version.LATEST); + + when(onGoingCheckpoint.getMetadataMap()).thenReturn(Map.of("segment_1", segment_1, "segment_2", segment_2)); + when(completedCheckpoint.getMetadataMap()).thenReturn(Map.of("segment_1", segment_1)); + + long time1 = 10; + long time2 = 15; + when(replicationTimerOngoing.time()).thenReturn(time1); + when(replicationTimerCompleted.time()).thenReturn(time2); + + action = new TransportSegmentReplicationStatsAction( + mock(ClusterService.class), + mock(TransportService.class), + indicesService, + targetService, + new ActionFilters(new HashSet<>()), + mock(IndexNameExpressionResolver.class), + pressureService + ); + } + + DiscoveryNode newNode(int nodeId) { + return new DiscoveryNode("node_" + nodeId, buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + } + + public void testShardReturnsAllTheShardsForTheIndex1() { + SegmentReplicationStatsRequest segmentReplicationStatsRequest = mock(SegmentReplicationStatsRequest.class); + String[] concreteIndices = new String[] { TEST_INDEX }; + ClusterState clusterState = mock(ClusterState.class); + RoutingTable routingTables = mock(RoutingTable.class); + ShardsIterator shardsIterator = mock(ShardIterator.class); + + when(clusterState.routingTable()).thenReturn(routingTables); + when(routingTables.allShardsIncludingRelocationTargets(any())).thenReturn(shardsIterator); + assertEquals(shardsIterator, action.shards(clusterState, segmentReplicationStatsRequest, concreteIndices)); + } + + public void testGlobalBlockCheck() { + ClusterBlocks.Builder builder = ClusterBlocks.builder(); + builder.addGlobalBlock(writeClusterBlock); + ClusterState metadataWriteBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build(); + assertNull(action.checkGlobalBlock(metadataWriteBlockedState, new SegmentReplicationStatsRequest())); + + builder = ClusterBlocks.builder(); + builder.addGlobalBlock(readClusterBlock); + ClusterState metadataReadBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build(); + assertNotNull(action.checkGlobalBlock(metadataReadBlockedState, new SegmentReplicationStatsRequest())); + } + + public void testIndexBlockCheck() { + String indexName = "test"; + ClusterBlocks.Builder builder = ClusterBlocks.builder(); + builder.addIndexBlock(indexName, writeClusterBlock); + ClusterState metadataWriteBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build(); + assertNull(action.checkRequestBlock(metadataWriteBlockedState, new SegmentReplicationStatsRequest(), new String[] { indexName })); + + builder = ClusterBlocks.builder(); + builder.addIndexBlock(indexName, readClusterBlock); + ClusterState metadataReadBlockedState = ClusterState.builder(ClusterState.EMPTY_STATE).blocks(builder).build(); + assertNotNull(action.checkRequestBlock(metadataReadBlockedState, new SegmentReplicationStatsRequest(), new String[] { indexName })); + } + + public void testShardOperationWhenReplicationIsNotSegRep() { + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 2) + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.DOCUMENT) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder(TEST_INDEX).settings(settings).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, Settings.EMPTY); + + when(indexShard.indexSettings()).thenReturn(indexSettings); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + ShardRouting shardRouting = TestShardRouting.newShardRouting( + TEST_INDEX, + shardId.getId(), + node.getId(), + true, + ShardRoutingState.STARTED + ); + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, shardRouting); + assertNull(response); + } + + public void testShardOperationOnPrimaryShard() { + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + ShardRouting shardRouting = TestShardRouting.newShardRouting( + TEST_INDEX, + shardId.getId(), + node.getId(), + true, + ShardRoutingState.STARTED + ); + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, shardRouting); + + assertEquals(segmentReplicationPerGroupStats, response.getPrimaryStats()); + assertNull(response.getReplicaStats()); + assertNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnReplicaShard() { + when(targetService.getSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + ShardRouting shardRouting = TestShardRouting.newShardRouting( + TEST_INDEX, + shardId.getId(), + node.getId(), + false, + ShardRoutingState.STARTED + ); + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + segmentReplicationStatsRequest.activeOnly(false); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, shardRouting); + + assertEquals(completedSegmentReplicationState, response.getReplicaStats()); + assertNull(response.getPrimaryStats()); + assertNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnReplicaShardWhenActiveOnlyIsSet() { + when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + ShardRouting shardRouting = TestShardRouting.newShardRouting( + TEST_INDEX, + shardId.getId(), + node.getId(), + false, + ShardRoutingState.STARTED + ); + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + segmentReplicationStatsRequest.activeOnly(true); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, shardRouting); + + assertEquals(onGoingSegmentReplicationState, response.getReplicaStats()); + assertNull(response.getPrimaryStats()); + assertNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnSearchReplicaWhenCompletedAndOngoingSegRepStateNotNull() { + when(targetService.getlatestCompletedEventSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState); + when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState); + when(targetService.getSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + + ShardRouting searchShardRouting = TestShardRouting.newShardRouting( + shardId, + node.getId(), + null, + false, + true, + ShardRoutingState.STARTED, + null + ); + + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, searchShardRouting); + + assertNull(response.getPrimaryStats()); + assertNull(response.getReplicaStats()); + assertNotNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnSearchReplicaWhenCompletedSegRepStateIsNull() { + when(targetService.getOngoingEventSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState); + when(targetService.getSegmentReplicationState(shardId)).thenReturn(onGoingSegmentReplicationState); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + + ShardRouting searchShardRouting = TestShardRouting.newShardRouting( + shardId, + node.getId(), + null, + false, + true, + ShardRoutingState.STARTED, + null + ); + + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, searchShardRouting); + + assertNull(response.getPrimaryStats()); + assertNull(response.getReplicaStats()); + assertNotNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnSearchReplicaWhenOngoingSegRepStateIsNull() { + when(targetService.getlatestCompletedEventSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState); + when(targetService.getSegmentReplicationState(shardId)).thenReturn(completedSegmentReplicationState); + + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + + ShardRouting searchShardRouting = TestShardRouting.newShardRouting( + shardId, + node.getId(), + null, + false, + true, + ShardRoutingState.STARTED, + null + ); + + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, searchShardRouting); + + assertNull(response.getPrimaryStats()); + assertNull(response.getReplicaStats()); + assertNotNull(response.getSegmentReplicationShardStats()); + } + + public void testShardOperationOnSearchReplicaWhenCompletedAndOngoingSegRepStateIsNull() { + final DiscoveryNode node = newNode(0); + final ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + + ShardRouting searchShardRouting = TestShardRouting.newShardRouting( + shardId, + node.getId(), + null, + false, + true, + ShardRoutingState.STARTED, + null + ); + + SegmentReplicationStatsRequest segmentReplicationStatsRequest = new SegmentReplicationStatsRequest(); + SegmentReplicationShardStatsResponse response = action.shardOperation(segmentReplicationStatsRequest, searchShardRouting); + + assertNull(response.getPrimaryStats()); + assertNull(response.getReplicaStats()); + assertNotNull(response.getSegmentReplicationShardStats()); + } + + public void testNewResponseWhenAllReplicasReturnResponseCombinesTheResults() { + SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest(); + String[] shards = {"1", "2", "3"}; + request.shards(shards); + + int totalShards = 3; + int successfulShards = 3; + int failedShard = 0; + String allocIdOne = "allocIdOne"; + String allocIdTwo = "allocIdTwo"; + ShardId shardIdOne = mock(ShardId.class); + ShardId shardIdTwo = mock(ShardId.class); + ShardId shardIdThree = mock(ShardId.class); + ShardRouting shardRoutingOne = mock(ShardRouting.class); + ShardRouting shardRoutingTwo = mock(ShardRouting.class); + ShardRouting shardRoutingThree = mock(ShardRouting.class); + when(shardIdOne.getId()).thenReturn(1); + when(shardIdTwo.getId()).thenReturn(2); + when(shardIdThree.getId()).thenReturn(3); + when(shardRoutingOne.shardId()).thenReturn(shardIdOne); + when(shardRoutingTwo.shardId()).thenReturn(shardIdTwo); + when(shardRoutingThree.shardId()).thenReturn(shardIdThree); + AllocationId allocationId = mock(AllocationId.class); + when(allocationId.getId()).thenReturn(allocIdOne); + when(shardRoutingTwo.allocationId()).thenReturn(allocationId); + when(shardIdOne.getIndexName()).thenReturn(TEST_INDEX); + + Set segmentReplicationShardStats = new HashSet<>(); + SegmentReplicationShardStats segmentReplicationShardStatsOfReplica = new SegmentReplicationShardStats(allocIdOne, 0, 0, 0, 0, 0); + segmentReplicationShardStats.add(segmentReplicationShardStatsOfReplica); + SegmentReplicationPerGroupStats segmentReplicationPerGroupStats = new SegmentReplicationPerGroupStats(shardIdOne, segmentReplicationShardStats, 0); + + SegmentReplicationState segmentReplicationState = mock(SegmentReplicationState.class); + SegmentReplicationShardStats segmentReplicationShardStatsFromSearchReplica = mock(SegmentReplicationShardStats.class); + when(segmentReplicationShardStatsFromSearchReplica.getAllocationId()).thenReturn("alloc2"); + when(segmentReplicationState.getShardRouting()).thenReturn(shardRoutingTwo); + + List responses = List.of( + new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStats), + new SegmentReplicationShardStatsResponse(segmentReplicationState), + new SegmentReplicationShardStatsResponse(segmentReplicationShardStatsFromSearchReplica) + ); + + SegmentReplicationStatsResponse response = action.newResponse( + request, totalShards, successfulShards, failedShard, responses, EMPTY_LIST, ClusterState.EMPTY_STATE); + + List responseStats = response.getReplicationStats().get(TEST_INDEX); + SegmentReplicationPerGroupStats primStats = responseStats.get(0); + Set segRpShardStatsSet = primStats.getReplicaStats(); + + for (SegmentReplicationShardStats segRpShardStats: segRpShardStatsSet) { + if(segRpShardStats.getAllocationId().equals(allocIdOne)) { + assertEquals(segmentReplicationState, segRpShardStats.getCurrentReplicationState()); + } + + if (segRpShardStats.getAllocationId().equals(allocIdTwo)) { + assertEquals(segmentReplicationShardStatsFromSearchReplica, segRpShardStats); + } + } + } + + public void testNewResponseWhenTwoPrimaryShardsForSameIndex() { + SegmentReplicationStatsRequest request = new SegmentReplicationStatsRequest(); + String[] shards = {"1", "2"}; + request.shards(shards); + int totalShards = 3; + int successfulShards = 3; + int failedShard = 0; + + SegmentReplicationPerGroupStats segmentReplicationPerGroupStatsOne = mock(SegmentReplicationPerGroupStats.class); + SegmentReplicationPerGroupStats segmentReplicationPerGroupStatsTwo = mock(SegmentReplicationPerGroupStats.class); + + ShardId shardIdOne = mock(ShardId.class); + ShardId shardIdTwo = mock(ShardId.class); + when(segmentReplicationPerGroupStatsOne.getShardId()).thenReturn(shardIdOne); + when(segmentReplicationPerGroupStatsTwo.getShardId()).thenReturn(shardIdTwo); + when(shardIdOne.getIndexName()).thenReturn(TEST_INDEX); + when(shardIdTwo.getIndexName()).thenReturn(TEST_INDEX); + when(shardIdOne.getId()).thenReturn(1); + when(shardIdTwo.getId()).thenReturn(2); + + List responses = List.of( + new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStatsOne), + new SegmentReplicationShardStatsResponse(segmentReplicationPerGroupStatsTwo) + ); + + SegmentReplicationStatsResponse response = action.newResponse( + request, totalShards, successfulShards, failedShard, responses, EMPTY_LIST, ClusterState.EMPTY_STATE); + + List responseStats = response.getReplicationStats().get(TEST_INDEX); + + for (SegmentReplicationPerGroupStats primStat: responseStats) { + if(primStat.getShardId().equals(shardIdOne)) { + assertEquals(segmentReplicationPerGroupStatsOne, primStat); + } + + if(primStat.getShardId().equals(shardIdTwo)) { + assertEquals(segmentReplicationPerGroupStatsTwo, primStat); + } + } + } +}