Skip to content

Commit

Permalink
Add Search model group implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jun 30, 2024
1 parent 3e600a7 commit 09660b8
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
Expand All @@ -18,10 +20,14 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.sdk.SdkClient;
import org.opensearch.sdk.SdkClientUtils;
import org.opensearch.sdk.SearchDataObjectRequest;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -30,6 +36,7 @@
@Log4j2
public class SearchModelGroupTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
Client client;
SdkClient sdkClient;
ClusterService clusterService;

ModelAccessControlHelper modelAccessControlHelper;
Expand All @@ -39,11 +46,13 @@ public SearchModelGroupTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
SdkClient sdkClient,
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper
) {
super(MLModelGroupSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.client = client;
this.sdkClient = sdkClient;
this.clusterService = clusterService;
this.modelAccessControlHelper = modelAccessControlHelper;
}
Expand All @@ -63,14 +72,36 @@ private void preProcessRoleAndPerformSearch(SearchRequest request, User user, Ac
final ActionListener<SearchResponse> doubleWrappedListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));

if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, doubleWrappedListener);
} else {
if (!modelAccessControlHelper.skipModelAccessControl(user)) {
// Security is enabled, filter is enabled and user isn't admin
modelAccessControlHelper.addUserBackendRolesFilter(user, request.source());
log.debug("Filtering result by " + user.getBackendRoles());
client.search(request, doubleWrappedListener);
}
SearchDataObjectRequest searchDataObjecRequest = new SearchDataObjectRequest.Builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.build();
sdkClient
.searchDataObjectAsync(searchDataObjecRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Model group search complete: {}", searchResponse.getHits().getTotalHits());
doubleWrappedListener.onResponse(searchResponse);
} catch (Exception e) {
log.error("Failed to parse search response", e);
doubleWrappedListener
.onFailure(
new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)
);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable);
log.error(e.getMessage(), e);
doubleWrappedListener.onFailure(e);
}
});
} catch (Exception e) {
log.error("Failed to search", e);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,73 @@

package org.opensearch.ml.action.model_group;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.lucene.search.TotalHits;
import org.junit.AfterClass;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponse.Clusters;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.sdkclient.LocalClusterIndicesClient;
import org.opensearch.sdk.SdkClient;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ScalingExecutorBuilder;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class SearchModelGroupTransportActionTests extends OpenSearchTestCase {

private static TestThreadPool testThreadPool = new TestThreadPool(
SearchModelGroupTransportActionTests.class.getName(),
new ScalingExecutorBuilder(
GENERAL_THREAD_POOL,
1,
Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1),
TimeValue.timeValueMinutes(1),
ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL
)
);

@Mock
Client client;

SdkClient sdkClient;

@Mock
NamedXContentRegistry namedXContentRegistry;

Expand All @@ -41,9 +81,15 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase {
@Mock
ActionFilters actionFilters;

@Mock
SearchRequest searchRequest;

SearchResponse searchResponse;

SearchSourceBuilder searchSourceBuilder;

@Mock
FetchSourceContext fetchSourceContext;

@Mock
ActionListener<SearchResponse> actionListener;

Expand All @@ -61,10 +107,13 @@ public class SearchModelGroupTransportActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = new LocalClusterIndicesClient(client, namedXContentRegistry);
searchModelGroupTransportAction = new SearchModelGroupTransportAction(
transportService,
actionFilters,
client,
sdkClient,
clusterService,
modelAccessControlHelper
);
Expand All @@ -74,21 +123,81 @@ public void setup() {
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations");
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL));

searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.fetchSource(fetchSourceContext);
searchRequest = new SearchRequest(new String[0], searchSourceBuilder);
when(fetchSourceContext.includes()).thenReturn(new String[] {});
when(fetchSourceContext.excludes()).thenReturn(new String[] {});

SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN);
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(
searchHits,
InternalAggregations.EMPTY,
null,
null,
false,
null,
0
);
searchResponse = new SearchResponse(
internalSearchResponse,
null,
0,
0,
0,
1,
ShardSearchFailure.EMPTY_ARRAY,
mock(Clusters.class),
null
);

PlainActionFuture<SearchResponse> future = PlainActionFuture.newFuture();
future.onResponse(searchResponse);
when(client.search(any(SearchRequest.class))).thenReturn(future);
}

@AfterClass
public static void cleanup() {
ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS);
}

public void test_DoExecute() {
public void test_DoExecute() throws InterruptedException {
when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false);
CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<SearchResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch);
searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener);
latch.await(500, TimeUnit.MILLISECONDS);

verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any());
verify(client).search(any());
}

public void test_DoExecute_Exception() throws InterruptedException {
PlainActionFuture<SearchResponse> future = PlainActionFuture.newFuture();
future.onFailure(new RuntimeException("search failed"));
when(client.search(any(SearchRequest.class))).thenReturn(future);

when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false);
searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener);
CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<SearchResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch);
searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener);
latch.await(500, TimeUnit.MILLISECONDS);

verify(modelAccessControlHelper).addUserBackendRolesFilter(any(), any());
verify(client).search(any(), any());
verify(client).search(any());
verify(actionListener).onFailure(any(RuntimeException.class));
}

public void test_skipModelAccessControlTrue() {
public void test_skipModelAccessControlTrue() throws InterruptedException {
when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true);
searchModelGroupTransportAction.doExecute(null, searchRequest, actionListener);
CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<SearchResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch);
searchModelGroupTransportAction.doExecute(null, searchRequest, latchedActionListener);
latch.await(500, TimeUnit.MILLISECONDS);

verify(client).search(any(), any());
verify(client).search(any());
}

public void test_ThreadContextError() {
Expand Down

0 comments on commit 09660b8

Please sign in to comment.