Skip to content

Commit

Permalink
Implement SdkClient in TransportRegisterModelAction
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jul 2, 2024
1 parent d0c4a16 commit 2b3db56
Show file tree
Hide file tree
Showing 17 changed files with 428 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
package org.opensearch.sdk;

import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

public class UpdateDataObjectRequest {

Expand Down Expand Up @@ -118,6 +122,20 @@ public Builder dataObject(ToXContentObject dataObject) {
this.dataObject = dataObject;
return this;
}

/**
* Add a data object as a map to this builder
* @param dataObjectMap the data object as a map of fields
* @return the updated builder
*/
public Builder dataObject(Map<String, Object> dataObjectMap) {
this.dataObject = new ToXContentObject() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(dataObjectMap);
}};
return this;
}

/**
* Builds the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
public class DeleteDataObjectRequestTests {
private String testIndex;
private String testId;
private String testTenantId;

@Before
public void setUp() {
testIndex = "test-index";
testId = "test-id";
testTenantId = "test-tenant-id";
}

@Test
public void testDeleteDataObjectRequest() {
DeleteDataObjectRequest request = new DeleteDataObjectRequest.Builder().index(testIndex).id(testId).build();
DeleteDataObjectRequest request = new DeleteDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ public class GetDataObjectRequestTests {

private String testIndex;
private String testId;
private String testTenantId;
private FetchSourceContext testFetchSourceContext;

@Before
public void setUp() {
testIndex = "test-index";
testId = "test-id";
testTenantId = "test-tenant-id";
testFetchSourceContext = mock(FetchSourceContext.class);
}

Expand All @@ -33,11 +35,13 @@ public void testGetDataObjectRequest() {
GetDataObjectRequest request = new GetDataObjectRequest.Builder()
.index(testIndex)
.id(testId)
.tenantId(testTenantId)
.fetchSourceContext(testFetchSourceContext)
.build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testFetchSourceContext, request.fetchSourceContext());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@
public class PutDataObjectRequestTests {

private String testIndex;
private String testTenantId;
private ToXContentObject testDataObject;

@Before
public void setUp() {
testIndex = "test-index";
testTenantId = "test-tenant-id";
testDataObject = mock(ToXContentObject.class);
}

@Test
public void testPutDataObjectRequest() {
PutDataObjectRequest request = new PutDataObjectRequest.Builder().index(testIndex).dataObject(testDataObject).build();
PutDataObjectRequest request = new PutDataObjectRequest.Builder().index(testIndex).tenantId(testTenantId).dataObject(testDataObject).build();

assertEquals(testIndex, request.index());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,26 @@
public class SearchDataObjectRequestTests {

private String[] testIndices;
private String testTenantId;
private SearchSourceBuilder testSearchSourceBuilder;

@Before
public void setUp() {
testIndices = new String[] {"test-index"};
testTenantId = "test-tenant-id";
testSearchSourceBuilder = new SearchSourceBuilder();
}

@Test
public void testGetDataObjectRequest() {
SearchDataObjectRequest request = new SearchDataObjectRequest.Builder()
.indices(testIndices)
.tenantId(testTenantId)
.searchSourceBuilder(testSearchSourceBuilder)
.build();

assertArrayEquals(testIndices, request.indices());
assertEquals(testTenantId, request.tenantId());
assertEquals(testSearchSourceBuilder, request.searchSourceBuilder());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.sdk;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.ToXContentObject;

import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;

public class UpdateDataObjectRequestTests {

private String testIndex;
private String testId;
private String testTenantId;
private ToXContentObject testDataObject;
private Map<String, Object> testDataObjectMap;

@Before
public void setUp() {
testIndex = "test-index";
testId = "test-id";
testTenantId = "test-tenant-id";
testDataObject = mock(ToXContentObject.class);
testDataObjectMap = Map.of("foo", "bar");
}

@Test
public void testUpdateDataObjectRequest() {
UpdateDataObjectRequest request = new UpdateDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObject).build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
}

@Test
public void testUpdateDataObjectMapRequest() {
UpdateDataObjectRequest request = new UpdateDataObjectRequest.Builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObjectMap).build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObjectMap, XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(XContentType.JSON, request.dataObject()), false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ private void updateModelWithRegisteringToAnotherModelGroup(
modelAccessControlHelper
.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> {
if (hasNewModelGroupPermission) {
mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> {
mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> {
buildUpdateRequest(
modelId,
newModelGroupId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ private void checkUserAccess(
) {
User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> {
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, sdkClient, ActionListener.wrap(access -> {
if (access) {
doRegister(registerModelInput, listener);
return;
Expand Down Expand Up @@ -351,7 +351,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener);
mlModelManager.registerMLRemoteModel(sdkClient, registerModelInput, mlTask, listener);
}, e -> {
logException("Failed to register model", e, log);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import java.util.HashSet;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
Expand Down Expand Up @@ -257,25 +256,6 @@ public void validateUniqueModelGroupName(String name, ActionListener<SearchRespo
}
}

// TODO Remove when all calls migrated to SDKClient version
/**
* Get model group from model group index.
*
* @param modelGroupId model group id
* @param listener action listener
*/
public void getModelGroupResponse(String modelGroupId, ActionListener<GetResponse> listener) {
GetRequest getRequest = new GetRequest();
getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
listener.onResponse(r);
} else {
listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId));
}
}, listener::onFailure));
}

/**
* Get model group from model group index.
*
Expand Down
Loading

0 comments on commit 2b3db56

Please sign in to comment.