From 3eb07e7934d25fa94c4fe05a055ebad6709dd95a Mon Sep 17 00:00:00 2001 From: Kaushal Kumar Date: Wed, 24 Jul 2024 12:17:37 -0700 Subject: [PATCH] add QueryGroupTask tests Signed-off-by: Kaushal Kumar --- .../action/search/SearchShardTask.java | 1 - .../opensearch/action/search/SearchTask.java | 1 - .../action/search/TransportSearchAction.java | 6 +- .../main/java/org/opensearch/node/Node.java | 11 +++- .../org/opensearch/wlm/QueryGroupTask.java | 19 ++++-- .../wlm/SearchWorkloadTransportHandler.java | 46 ------------- .../SearchWorkloadTransportInterceptor.java | 37 ----------- ...orkloadManagementTransportInterceptor.java | 64 +++++++++++++++++++ .../opensearch/wlm/QueryGroupTaskTests.java | 45 +++++++++++++ ...kloadManagementTransportHandlerTests.java} | 31 ++++----- ...dManagementTransportInterceptorTests.java} | 9 +-- 11 files changed, 152 insertions(+), 118 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java delete mode 100644 server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java create mode 100644 server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java create mode 100644 server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java rename server/src/test/java/org/opensearch/wlm/{SearchWorkloadTransportHandlerTests.java => WorkloadManagementTransportHandlerTests.java} (69%) rename server/src/test/java/org/opensearch/wlm/{SearchWorkloadTransportInterceptorTests.java => WorkloadManagementTransportInterceptorTests.java} (70%) diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index 183d7155069d3..ed2943db94420 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -37,7 +37,6 @@ import org.opensearch.core.tasks.TaskId; import org.opensearch.search.fetch.ShardFetchSearchRequest; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.SearchBackpressureTask; import org.opensearch.wlm.QueryGroupTask; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index 9b0723a7391c7..2a1a961e7607b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -35,7 +35,6 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.tasks.TaskId; -import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.SearchBackpressureTask; import org.opensearch.wlm.QueryGroupTask; diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 6a241beaf9041..88bf7ebea8e52 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -101,7 +101,6 @@ import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; -import org.opensearch.wlm.QueryGroupConstants; import org.opensearch.wlm.QueryGroupTask; import java.util.ArrayList; @@ -446,8 +445,9 @@ private void executeRequest( // At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter // or HTTP header (HTTP header will be deprecated once ActionFilter is implemented) - - ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + if (task instanceof QueryGroupTask) { + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } PipelinedRequest searchRequest; ActionListener listener; diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 99d59766d787d..8684b1b383cab 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -263,7 +263,7 @@ import org.opensearch.transport.TransportService; import org.opensearch.usage.UsageService; import org.opensearch.watcher.ResourceWatcherService; -import org.opensearch.wlm.SearchWorkloadTransportInterceptor; +import org.opensearch.wlm.WorkloadManagementTransportInterceptor; import javax.net.ssl.SNIHostName; @@ -1048,7 +1048,9 @@ protected Node( admissionControlService ); - SearchWorkloadTransportInterceptor searchWorkloadTransportInterceptor = new SearchWorkloadTransportInterceptor(threadPool); + WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor( + threadPool + ); final Collection secureSettingsFactories = pluginsService.filterPlugins(Plugin.class) .stream() @@ -1057,7 +1059,10 @@ protected Node( .map(Optional::get) .collect(Collectors.toList()); - List transportInterceptors = List.of(admissionControlTransportInterceptor, searchWorkloadTransportInterceptor); + List transportInterceptors = List.of( + admissionControlTransportInterceptor, + workloadManagementTransportInterceptor + ); final NetworkModule networkModule = new NetworkModule( settings, pluginsService.filterPlugins(NetworkPlugin.class), diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java index b4b4e9057bf5e..ae0a0a61f4388 100644 --- a/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java @@ -13,7 +13,6 @@ import org.opensearch.core.tasks.TaskId; import org.opensearch.tasks.CancellableTask; - import java.util.Map; import static org.opensearch.search.SearchService.NO_TIMEOUT; @@ -29,15 +28,26 @@ public QueryGroupTask(long id, String type, String action, String description, T this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT); } - public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers, TimeValue cancelAfterTimeInterval) { + public QueryGroupTask( + long id, + String type, + String action, + String description, + TaskId parentTaskId, + Map headers, + TimeValue cancelAfterTimeInterval + ) { super(id, type, action, description, parentTaskId, headers, cancelAfterTimeInterval); } /** - * + * This method should always be called after calling setQueryGroupId at least once on this object * @return task queryGroupId */ public String getQueryGroupId() { + if (queryGroupId == null) { + throw new IllegalStateException("queryGroupId is not set, queryGroup has to be set for the object"); + } return queryGroupId; } @@ -49,8 +59,7 @@ public String getQueryGroupId() { public void setQueryGroupId(final ThreadContext threadContext) { this.queryGroupId = QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(); - if (threadContext != null - && threadContext.getHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER) != null) { + if (threadContext != null && threadContext.getHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER) != null) { this.queryGroupId = threadContext.getHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER); } } diff --git a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java deleted file mode 100644 index b603519398567..0000000000000 --- a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.wlm; - -import org.opensearch.search.fetch.ShardFetchRequest; -import org.opensearch.search.internal.InternalScrollSearchRequest; -import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.query.QuerySearchRequest; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportChannel; -import org.opensearch.transport.TransportRequest; -import org.opensearch.transport.TransportRequestHandler; - -/** - * This class is mainly used to populate the queryGroupId header - * @param T is Search related request - */ -public class SearchWorkloadTransportHandler implements TransportRequestHandler { - - private final ThreadPool threadPool; - TransportRequestHandler actualHandler; - - public SearchWorkloadTransportHandler(ThreadPool threadPool, TransportRequestHandler actualHandler) { - this.threadPool = threadPool; - this.actualHandler = actualHandler; - } - - @Override - public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { - if (isSearchWorkloadRequest(task)) { - ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); - } - actualHandler.messageReceived(request, channel, task); - } - - private boolean isSearchWorkloadRequest(Task task) { - return task instanceof QueryGroupTask; - } -} diff --git a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java deleted file mode 100644 index 2583158a98113..0000000000000 --- a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.wlm; - -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportInterceptor; -import org.opensearch.transport.TransportRequest; -import org.opensearch.transport.TransportRequestHandler; - -/** - * This class is used to intercept search traffic requests and populate the queryGroupId header in task headers - * TODO: We still need to add this interceptor in {@link org.opensearch.node.Node} class to enable, - * leaving it until the feature is tested and done. - */ -public class SearchWorkloadTransportInterceptor implements TransportInterceptor { - private final ThreadPool threadPool; - - public SearchWorkloadTransportInterceptor(ThreadPool threadPool) { - this.threadPool = threadPool; - } - - @Override - public TransportRequestHandler interceptHandler( - String action, - String executor, - boolean forceExecution, - TransportRequestHandler actualHandler - ) { - return new SearchWorkloadTransportHandler(threadPool, actualHandler); - } -} diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java new file mode 100644 index 0000000000000..ef97f81cd7c7b --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +/** + * This class is used to intercept search traffic requests and populate the queryGroupId header in task headers + */ +public class WorkloadManagementTransportInterceptor implements TransportInterceptor { + private final ThreadPool threadPool; + + public WorkloadManagementTransportInterceptor(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + @Override + public TransportRequestHandler interceptHandler( + String action, + String executor, + boolean forceExecution, + TransportRequestHandler actualHandler + ) { + return new WorkloadManagementTransportHandler(threadPool, actualHandler); + } + + /** + * This class is mainly used to populate the queryGroupId header + * @param T is Search related request + */ + public static class WorkloadManagementTransportHandler implements TransportRequestHandler { + + private final ThreadPool threadPool; + TransportRequestHandler actualHandler; + + public WorkloadManagementTransportHandler(ThreadPool threadPool, TransportRequestHandler actualHandler) { + this.threadPool = threadPool; + this.actualHandler = actualHandler; + } + + @Override + public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + if (isSearchWorkloadRequest(task)) { + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } + actualHandler.messageReceived(request, channel, task); + } + + boolean isSearchWorkloadRequest(Task task) { + return task instanceof QueryGroupTask; + } + } +} diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java new file mode 100644 index 0000000000000..9d4907153cb80 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.Collections; + +public class QueryGroupTaskTests extends OpenSearchTestCase { + private ThreadPool threadPool; + private QueryGroupTask sut; + + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + sut = new QueryGroupTask(123, "transport", "Search", "test task", null, Collections.emptyMap()); + } + + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + } + + public void testSuccessfulSetQueryGroupId() { + sut.setQueryGroupId(threadPool.getThreadContext()); + assertEquals(QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(), sut.getQueryGroupId()); + + threadPool.getThreadContext().putHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER, "akfanglkaglknag2332"); + + sut.setQueryGroupId(threadPool.getThreadContext()); + assertEquals("akfanglkaglknag2332", sut.getQueryGroupId()); + } + + public void testUnsuccessfulSetGroupId() { + assertThrows(IllegalStateException.class, () -> sut.getQueryGroupId()); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportHandlerTests.java similarity index 69% rename from server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java rename to server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportHandlerTests.java index 9e3cf020ccd61..c8edc8e199b65 100644 --- a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportHandlerTests.java @@ -17,27 +17,27 @@ import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestHandler; +import org.opensearch.wlm.WorkloadManagementTransportInterceptor.WorkloadManagementTransportHandler; import java.util.Collections; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -public class SearchWorkloadTransportHandlerTests extends OpenSearchTestCase { - private SearchWorkloadTransportHandler sut; +public class WorkloadManagementTransportHandlerTests extends OpenSearchTestCase { + private WorkloadManagementTransportHandler sut; private ThreadPool threadPool; - private TransportRequestHandler actualHandler; + private TestTransportRequestHandler actualHandler; public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(getTestName()); actualHandler = new TestTransportRequestHandler<>(); - sut = new SearchWorkloadTransportHandler<>(threadPool, actualHandler); + sut = new WorkloadManagementTransportHandler<>(threadPool, actualHandler); } public void tearDown() throws Exception { @@ -47,27 +47,23 @@ public void tearDown() throws Exception { public void testMessageReceivedForSearchWorkload() throws Exception { ShardSearchRequest request = mock(ShardSearchRequest.class); - Task spyTask = getSpyTask(); + QueryGroupTask spyTask = getSpyTask(); sut.messageReceived(request, mock(TransportChannel.class), spyTask); - verify(spyTask, times(1)).addHeader( - QueryGroupConstants.QUERY_GROUP_ID_HEADER, - threadPool.getThreadContext(), - QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER - ); + verify(spyTask, times(1)).setQueryGroupId(threadPool.getThreadContext()); } public void testMessageReceivedForNonSearchWorkload() throws Exception { IndexRequest indexRequest = mock(IndexRequest.class); - Task spyTask = getSpyTask(); - sut.messageReceived(indexRequest, mock(TransportChannel.class), spyTask); - - verify(spyTask, times(0)).addHeader(any(), any(), any()); + Task task = mock(Task.class); + sut.messageReceived(indexRequest, mock(TransportChannel.class), task); + assertFalse(sut.isSearchWorkloadRequest(task)); + assertEquals(1, actualHandler.invokeCount); } - private static Task getSpyTask() { - final Task task = new Task(123, "transport", "Search", "test task", null, Collections.emptyMap()); + private static QueryGroupTask getSpyTask() { + final QueryGroupTask task = new QueryGroupTask(123, "transport", "Search", "test task", null, Collections.emptyMap()); return spy(task); } @@ -79,6 +75,5 @@ private static class TestTransportRequestHandler imp public void messageReceived(TransportRequest request, TransportChannel channel, Task task) throws Exception { invokeCount += 1; } - }; } diff --git a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java similarity index 70% rename from server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java rename to server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java index 0dbb3e9f88b4b..22d6d839ce8f6 100644 --- a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java @@ -13,17 +13,18 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestHandler; +import org.opensearch.wlm.WorkloadManagementTransportInterceptor.WorkloadManagementTransportHandler; import static org.opensearch.threadpool.ThreadPool.Names.SAME; -public class SearchWorkloadTransportInterceptorTests extends OpenSearchTestCase { +public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestCase { private ThreadPool threadPool; - private SearchWorkloadTransportInterceptor sut; + private WorkloadManagementTransportInterceptor sut; public void setUp() throws Exception { threadPool = new TestThreadPool(getTestName()); - sut = new SearchWorkloadTransportInterceptor(threadPool); + sut = new WorkloadManagementTransportInterceptor(threadPool); } public void tearDown() throws Exception { @@ -32,6 +33,6 @@ public void tearDown() throws Exception { public void testInterceptHandler() { TransportRequestHandler requestHandler = sut.interceptHandler("Search", SAME, false, null); - assertTrue(requestHandler instanceof SearchWorkloadTransportHandler); + assertTrue(requestHandler instanceof WorkloadManagementTransportHandler); } }