Skip to content

Commit

Permalink
Add GetTool API and ListTools API (#1818)
Browse files Browse the repository at this point in the history
* Add GetTool API and ListTools API

Signed-off-by: Jackie Han <[email protected]>

* rename externalTools parameter as toolMetadataList

Signed-off-by: Jackie Han <[email protected]>

* spotless apply

Signed-off-by: Jackie Han <[email protected]>

* add more unit tests

Signed-off-by: Jackie Han <[email protected]>

* tweak unit test cases

Signed-off-by: Jackie Han <[email protected]>

* increase test coverage

Signed-off-by: Jackie Han <[email protected]>

* increase test coverage

Signed-off-by: Jackie Han <[email protected]>

* add more tests

Signed-off-by: Jackie Han <[email protected]>

* Include Type and Version in GetTool and ListTools API responses

Signed-off-by: Jackie Han <[email protected]>

* tweak ListTools result format

Signed-off-by: Jackie Han <[email protected]>

* change term no version found to undefined

Signed-off-by: Jackie Han <[email protected]>

---------

Signed-off-by: Jackie Han <[email protected]>
  • Loading branch information
jackiehanyang authored Jan 9, 2024
1 parent 032f4f6 commit deb51f6
Show file tree
Hide file tree
Showing 30 changed files with 1,843 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.client;

import java.util.List;
import java.util.Map;

import org.opensearch.action.delete.DeleteResponse;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -390,4 +392,40 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);

/**
* Get a list of ToolMetadata and return ActionFuture.
* For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools
* @return ActionFuture of a list of tool metadata
*/
default ActionFuture<List<ToolMetadata>> listTools() {
PlainActionFuture<List<ToolMetadata>> actionFuture = PlainActionFuture.newFuture();
listTools(actionFuture);
return actionFuture;
}

/**
* List ToolMetadata and return a list of ToolMetadata in listener
* For more info on get tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools
* @param listener action listener
*/
void listTools(ActionListener<List<ToolMetadata>> listener);

/**
* Get ToolMetadata and return ActionFuture.
* For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool
* @return ActionFuture of tool metadata
*/
default ActionFuture<ToolMetadata> getTool(String toolName) {
PlainActionFuture<ToolMetadata> actionFuture = PlainActionFuture.newFuture();
getTool(toolName, actionFuture);
return actionFuture;
}

/**
* Get ToolMetadata and return ToolMetadata in listener
* For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool
* @param listener action listener
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.ml.common.input.InputHelper.getAction;
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

Expand All @@ -26,6 +27,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -71,6 +73,12 @@
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.tools.MLGetToolAction;
import org.opensearch.ml.common.transport.tools.MLListToolsAction;
import org.opensearch.ml.common.transport.tools.MLToolGetRequest;
import org.opensearch.ml.common.transport.tools.MLToolGetResponse;
import org.opensearch.ml.common.transport.tools.MLToolsListRequest;
import org.opensearch.ml.common.transport.tools.MLToolsListResponse;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
Expand Down Expand Up @@ -287,6 +295,42 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
}, listener::onFailure));
}

@Override
public void listTools(ActionListener<List<ToolMetadata>> listener) {
MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().build();

client.execute(MLListToolsAction.INSTANCE, mlToolsListRequest, getMlListToolsResponseActionListener(listener));
}

@Override
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).build();

client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
}, listener::onFailure);
ActionListener<MLToolsListResponse> actionListener = wrapActionListener(internalListener, res -> {
MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
}

private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(ActionListener<ToolMetadata> listener) {
ActionListener<MLToolGetResponse> internalListener = ActionListener.wrap(mlModelGetResponse -> {
listener.onResponse(mlModelGetResponse.getToolMetadata());
}, listener::onFailure);
ActionListener<MLToolGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
Expand Down Expand Up @@ -100,6 +101,8 @@ public class MachineLearningClientTest {
private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
private ToolMetadata toolMetadata;
private List<ToolMetadata> toolsList = new ArrayList<>();

@Before
public void setUp() {
Expand All @@ -111,6 +114,15 @@ public void setUp() {
String modelContent = "test content";
mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build();

toolMetadata = ToolMetadata
.builder()
.name("MathTool")
.description("Use this tool to calculate any math problem.")
.type("MathTool")
.version(null)
.build();
toolsList.add(toolMetadata);

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand Down Expand Up @@ -192,6 +204,16 @@ public void deleteConnector(String connectorId, ActionListener<DeleteResponse> l
listener.onResponse(deleteResponse);
}

@Override
public void listTools(ActionListener<List<ToolMetadata>> listener) {
listener.onResponse(toolsList);
}

@Override
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
listener.onResponse(toolMetadata);
}

public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
Expand Down Expand Up @@ -470,4 +492,14 @@ public void testRegisterAgent() {
public void deleteAgent() {
assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet());
}

@Test
public void getTool() {
assertEquals(toolMetadata, machineLearningClient.getTool("MathTool").actionGet());
}

@Test
public void listTools() {
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -116,6 +117,12 @@
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.tools.MLGetToolAction;
import org.opensearch.ml.common.transport.tools.MLListToolsAction;
import org.opensearch.ml.common.transport.tools.MLToolGetRequest;
import org.opensearch.ml.common.transport.tools.MLToolGetResponse;
import org.opensearch.ml.common.transport.tools.MLToolsListRequest;
import org.opensearch.ml.common.transport.tools.MLToolsListResponse;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
Expand Down Expand Up @@ -192,6 +199,12 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<DeleteResponse> deleteAgentActionListener;

@Mock
ActionListener<List<ToolMetadata>> listToolsActionListener;

@Mock
ActionListener<ToolMetadata> getToolActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -887,6 +900,56 @@ public void deleteAgent() {
assertEquals(agentId, (argumentCaptor.getValue()).getId());
}

@Test
public void getTool() {
ToolMetadata toolMetadata = ToolMetadata
.builder()
.name("MathTool")
.description("Use this tool to calculate any math problem.")
.build();

doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

ArgumentCaptor<ToolMetadata> argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class);
machineLearningNodeClient.getTool("MathTool", getToolActionListener);

verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
verify(getToolActionListener).onResponse(argumentCaptor.capture());
assertEquals("MathTool", argumentCaptor.getValue().getName());
assertEquals("Use this tool to calculate any math problem.", argumentCaptor.getValue().getDescription());
}

@Test
public void listTools() {
List<ToolMetadata> toolMetadataList = new ArrayList<>();
ToolMetadata wikipediaTool = ToolMetadata
.builder()
.name("WikipediaTool")
.description("Use this tool to search general knowledge on wikipedia.")
.build();
toolMetadataList.add(wikipediaTool);

doAnswer(invocation -> {
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(toolMetadataList).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());

ArgumentCaptor<List<ToolMetadata>> argumentCaptor = ArgumentCaptor.forClass(List.class);
machineLearningNodeClient.listTools(listToolsActionListener);

verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
verify(listToolsActionListener).onResponse(argumentCaptor.capture());
assertEquals("WikipediaTool", argumentCaptor.getValue().get(0).getName());
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Loading

0 comments on commit deb51f6

Please sign in to comment.