Skip to content

Commit

Permalink
[WLM] Add wlm support for scroll API (opensearch-project#16981)
Browse files Browse the repository at this point in the history
* add wlm support for scroll API

Signed-off-by: Kaushal Kumar <[email protected]>

* add CHANGELOG entry

Signed-off-by: Kaushal Kumar <[email protected]>

* remove untagged tasks from WLM tracking

Signed-off-by: Kaushal Kumar <[email protected]>

* add UTs for invalid tasks

Signed-off-by: Kaushal Kumar <[email protected]>

* fix UT failures

Signed-off-by: Kaushal Kumar <[email protected]>

* rename a field in QueryGroupTask

Signed-off-by: Kaushal Kumar <[email protected]>

---------

Signed-off-by: Kaushal Kumar <[email protected]>
  • Loading branch information
kaushalmahi12 committed Jan 28, 2025
1 parent 1c7f719 commit 8251a83
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005))
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
- [WLM] Add WLM support for search scroll API ([#16981](https://github.com/opensearch-project/OpenSearch/pull/16981))
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))
- Use OpenSearch version to deserialize remote custom metadata([#16494](https://github.com/opensearch-project/OpenSearch/pull/16494))

### Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.wlm.QueryGroupTask;

/**
* Perform the search scroll
Expand All @@ -51,24 +53,32 @@ public class TransportSearchScrollAction extends HandledTransportAction<SearchSc
private final ClusterService clusterService;
private final SearchTransportService searchTransportService;
private final SearchPhaseController searchPhaseController;
private final ThreadPool threadPool;

@Inject
public TransportSearchScrollAction(
TransportService transportService,
ClusterService clusterService,
ActionFilters actionFilters,
SearchTransportService searchTransportService,
SearchPhaseController searchPhaseController
SearchPhaseController searchPhaseController,
ThreadPool threadPool
) {
super(SearchScrollAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchScrollRequest>) SearchScrollRequest::new);
this.clusterService = clusterService;
this.searchTransportService = searchTransportService;
this.searchPhaseController = searchPhaseController;
this.threadPool = threadPool;
}

@Override
protected void doExecute(Task task, SearchScrollRequest request, ActionListener<SearchResponse> listener) {
try {

if (task instanceof QueryGroupTask) {
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
}

ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId());
Runnable action;
switch (scrollId.getType()) {
Expand Down
6 changes: 6 additions & 0 deletions server/src/main/java/org/opensearch/wlm/QueryGroupTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class QueryGroupTask extends CancellableTask {
public static final Supplier<String> DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP";
private final LongSupplier nanoTimeSupplier;
private String queryGroupId;
private boolean isQueryGroupSet = false;

public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT, System::nanoTime);
Expand Down Expand Up @@ -81,6 +82,7 @@ public final String getQueryGroupId() {
* @param threadContext current threadContext
*/
public final void setQueryGroupId(final ThreadContext threadContext) {
isQueryGroupSet = true;
if (threadContext != null && threadContext.getHeader(QUERY_GROUP_ID_HEADER) != null) {
this.queryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
} else {
Expand All @@ -92,6 +94,10 @@ public long getElapsedTime() {
return nanoTimeSupplier.getAsLong() - getStartTimeNanos();
}

public boolean isQueryGroupSet() {
return isQueryGroupSet;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private Map<String, List<QueryGroupTask>> getTasksGroupedByQueryGroup() {
.stream()
.filter(QueryGroupTask.class::isInstance)
.map(QueryGroupTask.class::cast)
.filter(QueryGroupTask::isQueryGroupSet)
.collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> task, Collectors.toList())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm.tracker;

import org.opensearch.action.search.SearchTask;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
import org.opensearch.wlm.QueryGroupTask;

import java.util.HashMap;
import java.util.Map;

public class QueryGroupTaskResourceTrackingTests extends OpenSearchTestCase {
ThreadPool threadPool;
QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService;
TaskResourceTrackingService taskResourceTrackingService;

@Override
public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool("workload-management-tracking-thread-pool");
taskResourceTrackingService = new TaskResourceTrackingService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
threadPool
);
queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(taskResourceTrackingService);
}

public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdownNow();
}

public void testValidQueryGroupTasksCase() {
taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
QueryGroupTask task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>());
taskResourceTrackingService.startTracking(task);

// since the query group id is not set we should not track this task
Map<String, QueryGroupLevelResourceUsageView> resourceUsageViewMap = queryGroupResourceUsageTrackerService
.constructQueryGroupLevelUsageViews();
assertTrue(resourceUsageViewMap.isEmpty());

// Now since this task has a valid queryGroupId header it should be tracked
try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, "testHeader");
task.setQueryGroupId(threadPool.getThreadContext());
resourceUsageViewMap = queryGroupResourceUsageTrackerService.constructQueryGroupLevelUsageViews();
assertFalse(resourceUsageViewMap.isEmpty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ private <T extends QueryGroupTask> T createMockTask(Class<T> type, long cpuUsage
when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage);
when(task.getStartTimeNanos()).thenReturn((long) 0);
when(task.getElapsedTime()).thenReturn(clock.getTime());
when(task.isQueryGroupSet()).thenReturn(true);

AtomicBoolean isCancelled = new AtomicBoolean(false);
doAnswer(invocation -> {
Expand Down

0 comments on commit 8251a83

Please sign in to comment.