Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WLM] add wlm support for scroll API #16981

Merged
merged 6 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix multi-value sort for unsigned long ([#16732](https://github.com/opensearch-project/OpenSearch/pull/16732))
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
- [WLM] Add WLM support for search scroll API ([#16981](https://github.com/opensearch-project/OpenSearch/pull/16981))
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))

### Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.wlm.QueryGroupTask;

/**
* Perform the search scroll
Expand All @@ -51,24 +53,32 @@ public class TransportSearchScrollAction extends HandledTransportAction<SearchSc
private final ClusterService clusterService;
private final SearchTransportService searchTransportService;
private final SearchPhaseController searchPhaseController;
private final ThreadPool threadPool;

@Inject
public TransportSearchScrollAction(
TransportService transportService,
ClusterService clusterService,
ActionFilters actionFilters,
SearchTransportService searchTransportService,
SearchPhaseController searchPhaseController
SearchPhaseController searchPhaseController,
ThreadPool threadPool
) {
super(SearchScrollAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchScrollRequest>) SearchScrollRequest::new);
this.clusterService = clusterService;
this.searchTransportService = searchTransportService;
this.searchPhaseController = searchPhaseController;
this.threadPool = threadPool;
}

@Override
protected void doExecute(Task task, SearchScrollRequest request, ActionListener<SearchResponse> listener) {
try {

if (task instanceof QueryGroupTask) {
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
}

ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId());
Runnable action;
switch (scrollId.getType()) {
Expand Down
6 changes: 6 additions & 0 deletions server/src/main/java/org/opensearch/wlm/QueryGroupTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class QueryGroupTask extends CancellableTask {
public static final Supplier<String> DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP";
private final LongSupplier nanoTimeSupplier;
private String queryGroupId;
private boolean isQueryGroupSet = false;

public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT, System::nanoTime);
Expand Down Expand Up @@ -81,6 +82,7 @@ public final String getQueryGroupId() {
* @param threadContext current threadContext
*/
public final void setQueryGroupId(final ThreadContext threadContext) {
isQueryGroupSet = true;
if (threadContext != null && threadContext.getHeader(QUERY_GROUP_ID_HEADER) != null) {
this.queryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
} else {
Expand All @@ -92,6 +94,10 @@ public long getElapsedTime() {
return nanoTimeSupplier.getAsLong() - getStartTimeNanos();
}

public boolean isQueryGroupSet() {
return isQueryGroupSet;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private Map<String, List<QueryGroupTask>> getTasksGroupedByQueryGroup() {
.stream()
.filter(QueryGroupTask.class::isInstance)
.map(QueryGroupTask.class::cast)
.filter(QueryGroupTask::isQueryGroupSet)
.collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> task, Collectors.toList())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm.tracker;

import org.opensearch.action.search.SearchTask;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
import org.opensearch.wlm.QueryGroupTask;

import java.util.HashMap;
import java.util.Map;

public class QueryGroupTaskResourceTrackingTests extends OpenSearchTestCase {
ThreadPool threadPool;
QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService;
TaskResourceTrackingService taskResourceTrackingService;

@Override
public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool("workload-management-tracking-thread-pool");
taskResourceTrackingService = new TaskResourceTrackingService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
threadPool
);
queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(taskResourceTrackingService);
}

public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdownNow();
}

public void testValidQueryGroupTasksCase() {
taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
QueryGroupTask task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>());
taskResourceTrackingService.startTracking(task);

// since the query group id is not set we should not track this task
Map<String, QueryGroupLevelResourceUsageView> resourceUsageViewMap = queryGroupResourceUsageTrackerService
.constructQueryGroupLevelUsageViews();
assertTrue(resourceUsageViewMap.isEmpty());

// Now since this task has a valid queryGroupId header it should be tracked
try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
threadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, "testHeader");
task.setQueryGroupId(threadPool.getThreadContext());
resourceUsageViewMap = queryGroupResourceUsageTrackerService.constructQueryGroupLevelUsageViews();
assertFalse(resourceUsageViewMap.isEmpty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ private <T extends QueryGroupTask> T createMockTask(Class<T> type, long cpuUsage
when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage);
when(task.getStartTimeNanos()).thenReturn((long) 0);
when(task.getElapsedTime()).thenReturn(clock.getTime());
when(task.isQueryGroupSet()).thenReturn(true);

AtomicBoolean isCancelled = new AtomicBoolean(false);
doAnswer(invocation -> {
Expand Down
Loading