Skip to content

Commit

Permalink
Fix Connector Update, Search, and Delete tenant awareness
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Aug 30, 2024
1 parent 3cace94 commit 10b79de
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(ML_MODEL_INDEX)
.tenantId(tenantId)
.searchSourceBuilder(sourceBuilder)
.build();
sdkClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,25 +104,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
return;
}
String connectorId = mlUpdateConnectorAction.getConnectorId();
String tenantId = mlCreateConnectorInput.getTenantId();
FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY);
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
.builder()
.index(ML_CONNECTOR_INDEX)
.id(connectorId)
.tenantId(tenantId)
.fetchSourceContext(fetchSourceContext)
.build();

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> {
// context is already restored here
if (TenantAwareHelper
.validateTenantResource(
mlFeatureEnabledSetting,
mlCreateConnectorInput.getTenantId(),
connector.getTenantId(),
listener
)) {
if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, connector.getTenantId(), listener)) {
boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector);
if (hasPermission) {
connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt);
Expand All @@ -131,6 +126,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
.builder()
.index(ML_CONNECTOR_INDEX)
.id(connectorId)
.tenantId(tenantId)
.dataObject(connector)
.build();
try (ThreadContext.StoredContext innerContext = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -173,6 +169,7 @@ private void updateUndeployedConnector(
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(ML_MODEL_INDEX)
.tenantId(updateDataObjectRequest.tenantId())
.searchSourceBuilder(sourceBuilder)
.build();
sdkClient
Expand Down Expand Up @@ -211,7 +208,6 @@ private void updateUndeployedConnector(
}
} else {
Exception cause = SdkClientUtils.unwrapAndConvertToException(st);
log.error("Failed to update ML connector: " + connectorId, cause);
if (cause instanceof IndexNotFoundException) {
sdkClient
.updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
Expand All @@ -220,6 +216,7 @@ private void updateUndeployedConnector(
});
return;
} else {
log.error("Failed to update ML connector: " + connectorId, cause);
listener.onFailure(cause);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import java.io.IOException;
import java.util.Map;
import java.util.Optional;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Response;
Expand All @@ -21,8 +20,11 @@
public class RestMLConnectorTenantAwareIT extends MLCommonsTenantAwareRestTestCase {

public void testConnectorCRUD() throws IOException, InterruptedException {
boolean multiTenancyEnabled = Optional.ofNullable(System.getProperty("multitenancy")).map("true"::equalsIgnoreCase).orElse(false);
testConnectorCRUDMultitenancyEnabled(true);
testConnectorCRUDMultitenancyEnabled(false);
}

public void testConnectorCRUDMultitenancyEnabled(boolean multiTenancyEnabled) throws IOException, InterruptedException {
enableMultiTenancy(multiTenancyEnabled);

/*
Expand Down Expand Up @@ -175,8 +177,7 @@ public void testConnectorCRUD() throws IOException, InterruptedException {
assertOK(response);
SearchResponse searchResponse = searchResponseFromResponse(response);
if (multiTenancyEnabled) {
// TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged
assertEquals(2, searchResponse.getHits().getTotalHits().value);
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals(tenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID));
} else {
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Expand All @@ -189,10 +190,8 @@ public void testConnectorCRUD() throws IOException, InterruptedException {
assertOK(response);
searchResponse = searchResponseFromResponse(response);
if (multiTenancyEnabled) {
// TODO Change to 1 when https://github.com/opensearch-project/ml-commons/pull/2803 is merged
assertEquals(2, searchResponse.getHits().getTotalHits().value);
// TODO change [1] to [0]
assertEquals(otherTenantId, searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID));
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals(otherTenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID));
} else {
assertEquals(2, searchResponse.getHits().getTotalHits().value);
assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID));
Expand Down

0 comments on commit 10b79de

Please sign in to comment.