Skip to content

Commit

Permalink
addressed more comments + refactored few codes
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Jan 18, 2025
1 parent 5bf1e00 commit e3eb699
Show file tree
Hide file tree
Showing 14 changed files with 647 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,18 @@ default ActionFuture<DeleteResponse> deleteModel(String modelId) {
* @param modelId id of the model
* @param listener action listener
*/
void deleteModel(String modelId, ActionListener<DeleteResponse> listener);
default void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
deleteModel(modelId, null, listener);
}

/**
* Delete MLModel
* For more info on delete model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-model
* @param modelId id of the model
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener action listener
*/
void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Delete the task with taskId.
Expand Down Expand Up @@ -334,19 +345,10 @@ default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
return actionFuture;
}

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId, String tenantId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, tenantId, actionFuture);
return actionFuture;
default void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
deleteConnector(connectorId, null, listener);
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
}
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();

client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
}

@Override
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build();
Expand All @@ -185,8 +178,8 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build();

client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}
Expand Down Expand Up @@ -266,17 +259,6 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client
.execute(
MLConnectorDeleteAction.INSTANCE,
connectorDeleteRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
);
}

@Override
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ public class MachineLearningClientTest {
@Mock
ActionListener<MLOutput> dataFrameActionListener;

@Mock
ActionListener<MLModel> mlModelActionListener;

@Mock
DeleteResponse deleteResponse;

Expand Down Expand Up @@ -176,6 +179,11 @@ public void deleteModel(String modelId, ActionListener<DeleteResponse> listener)
listener.onResponse(deleteResponse);
}

@Override
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
Expand Down Expand Up @@ -357,6 +365,22 @@ public void getModel() {
assertEquals(mlModel, machineLearningClient.getModel("modelId").actionGet());
}

@Test
public void getModelActionListener() {
ArgumentCaptor<MLModel> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLModel.class);
machineLearningClient.getModel("modelId", mlModelActionListener);
verify(mlModelActionListener).onResponse(dataFrameArgumentCaptor.capture());
assertEquals(mlModel, dataFrameArgumentCaptor.getValue());
assertEquals(mlModel.getTenantId(), dataFrameArgumentCaptor.getValue().getTenantId());
}

@Test
public void undeploy_WithSpecificNodes() {
String[] modelIds = new String[] { "model1", "model2" };
String[] nodeIds = new String[] { "node1", "node2" };
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(modelIds, nodeIds).actionGet());
}

@Test
public void deleteModel() {
assertEquals(deleteResponse, machineLearningClient.deleteModel("modelId").actionGet());
Expand All @@ -367,6 +391,11 @@ public void searchModel() {
assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet());
}

@Test
public void deleteConnector_WithTenantId() {
assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet());
}

@Test
public void registerModelGroup() {
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -325,6 +326,64 @@ public void train() {
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}

@Test
public void getModel_withTenantId() {
String modelContent = "test content";
String tenantId = "tenantId";
doAnswer(invocation -> {
ActionListener<MLModelGetResponse> actionListener = invocation.getArgument(2);
MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build();
MLModelGetResponse output = MLModelGetResponse.builder().mlModel(mlModel).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLModel> argumentCaptor = ArgumentCaptor.forClass(MLModel.class);
machineLearningNodeClient.getModel("modelId", tenantId, getModelActionListener);

verify(client).execute(eq(MLModelGetAction.INSTANCE), isA(MLModelGetRequest.class), any());
verify(getModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(FunctionName.KMEANS, argumentCaptor.getValue().getAlgorithm());
assertEquals(modelContent, argumentCaptor.getValue().getContent());
}

@Test
public void undeployModels_withNullNodeIds() {
doAnswer(invocation -> {
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
MLUndeployModelsResponse output = new MLUndeployModelsResponse(
new MLUndeployModelNodesResponse(ClusterName.DEFAULT, Collections.emptyList(), Collections.emptyList())
);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any());

machineLearningNodeClient.undeploy(new String[] { "model1" }, null, undeployModelsActionListener);
verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any());
}

@Test
public void createConnector_withValidInput() {
doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
MLCreateConnectorResponse output = new MLCreateConnectorResponse("connectorId");
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any());

MLCreateConnectorInput input = MLCreateConnectorInput
.builder()
.name("testConnector")
.protocol("http")
.version("1")
.credential(Map.of("TEST_CREDENTIAL_KEY", "TEST_CREDENTIAL_VALUE"))
.parameters(Map.of("endpoint", "https://example.com"))
.build();

machineLearningNodeClient.createConnector(input, createConnectorActionListener);
verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any());
}

@Test
public void registerModelGroup_withValidInput() {
doAnswer(invocation -> {
Expand All @@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() {
verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any());
}

@Test
public void listTools_withValidRequest() {
doAnswer(invocation -> {
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
MLToolsListResponse output = MLToolsListResponse
.builder()
.toolMetadata(
Arrays
.asList(
ToolMetadata.builder().name("tool1").description("description1").build(),
ToolMetadata.builder().name("tool2").description("description2").build()
)
)
.build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());

machineLearningNodeClient.listTools(listToolsActionListener);
verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
}

@Test
public void listTools_withEmptyResponse() {
doAnswer(invocation -> {
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(Collections.emptyList()).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());

ArgumentCaptor<List<ToolMetadata>> argumentCaptor = ArgumentCaptor.forClass(List.class);
machineLearningNodeClient.listTools(listToolsActionListener);

verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
verify(listToolsActionListener).onResponse(argumentCaptor.capture());

List<ToolMetadata> capturedTools = argumentCaptor.getValue();
assertTrue(capturedTools.isEmpty());
}

@Test
public void getTool_withValidToolName() {
doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
MLToolGetResponse output = MLToolGetResponse
.builder()
.toolMetadata(ToolMetadata.builder().name("tool1").description("description1").build())
.build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

machineLearningNodeClient.getTool("tool1", getToolActionListener);
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
}

@Test
public void getTool_withValidRequest() {
ToolMetadata toolMetadata = ToolMetadata
.builder()
.name("MathTool")
.description("Use this tool to calculate any math problem.")
.build();

doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

ArgumentCaptor<ToolMetadata> argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class);
machineLearningNodeClient.getTool("MathTool", getToolActionListener);

verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
verify(getToolActionListener).onResponse(argumentCaptor.capture());

ToolMetadata capturedTool = argumentCaptor.getValue();
assertEquals("MathTool", capturedTool.getName());
assertEquals("Use this tool to calculate any math problem.", capturedTool.getDescription());
}

@Test
public void getTool_withFailureResponse() {
doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("Test exception"));
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

machineLearningNodeClient.getTool("MathTool", new ActionListener<>() {
@Override
public void onResponse(ToolMetadata toolMetadata) {
fail("Expected failure but got response");
}

@Override
public void onFailure(Exception e) {
assertEquals("Test exception", e.getMessage());
}
});

verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
}

@Test
public void train_withAsync() {
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLTrainingOutput output = MLTrainingOutput.builder().status("InProgress").modelId("modelId").build();
actionListener.onResponse(MLTaskResponse.builder().output(output).build());
return null;
}).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
machineLearningNodeClient.train(mlInput, true, trainingActionListener);
verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any());
}

@Test
public void deleteModel_withTenantId() {
String modelId = "testModelId";
String tenantId = "tenantId";
doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, modelId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteModel(modelId, tenantId, deleteModelActionListener);

verify(client).execute(eq(MLModelDeleteAction.INSTANCE), isA(MLModelDeleteRequest.class), any());
verify(deleteModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(modelId, argumentCaptor.getValue().getId());
}

@Test
public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public MLModelDeleteRequest(StreamInput input) throws IOException {
this.modelId = input.readString();
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
this.tenantId = input.readOptionalString();
} else {
this.tenantId = null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,4 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx
}
return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles, tenantId);
}

}
Loading

0 comments on commit e3eb699

Please sign in to comment.