Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent 126b3f2 commit d02f38a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public void testConnectorTool_NullConnectorId() {
}).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any());

Exception exception = assertThrows(
IllegalArgumentException.class,
() -> ConnectorTool.Factory.getInstance().create(Map.of("connector_action", "execute"))
IllegalArgumentException.class,
() -> ConnectorTool.Factory.getInstance().create(Map.of("connector_action", "execute"))
);
MatcherAssert.assertThat(exception.getMessage(), containsString("connector_id can't be null"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ public class ExecuteConnectorTransportAction extends HandledTransportAction<Acti

@Inject
public ExecuteConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ClusterService clusterService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
ConnectorAccessControlHelper connectorAccessControlHelper,
EncryptorImpl encryptor
TransportService transportService,
ActionFilters actionFilters,
Client client,
ClusterService clusterService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
ConnectorAccessControlHelper connectorAccessControlHelper,
EncryptorImpl encryptor
) {
super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new);
this.client = client;
Expand All @@ -74,15 +74,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
connector.decrypt(connectorAction, (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor
.executeAction(connectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> {
actionListener.onResponse(taskResponse);
}, e -> { actionListener.onFailure(e); }));
.executeAction(connectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> {
actionListener.onResponse(taskResponse);
}, e -> { actionListener.onFailure(e); }));
}
}, e -> {
log.error("Failed to get connector " + connectorId, e);
Expand All @@ -96,4 +96,4 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ public void setup() {
MockitoAnnotations.openMocks(this);

ClusterState testState = new ClusterState(
new ClusterName("clusterName"),
123l,
"111111",
metaData,
null,
null,
null,
Map.of(),
0,
false
new ClusterName("clusterName"),
123l,
"111111",
metaData,
null,
null,
null,
Map.of(),
0,
false
);
when(clusterService.state()).thenReturn(testState);

Expand All @@ -103,14 +103,14 @@ public void setup() {
when(threadPool.getThreadContext()).thenReturn(threadContext);

action = new ExecuteConnectorTransportAction(
transportService,
actionFilters,
client,
clusterService,
scriptService,
xContentRegistry,
connectorAccessControlHelper,
encryptor
transportService,
actionFilters,
client,
clusterService,
scriptService,
xContentRegistry,
connectorAccessControlHelper,
encryptor
);
}

Expand Down
144 changes: 72 additions & 72 deletions plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,44 +27,44 @@ public class RestConnectorToolIT extends RestBaseAgentToolsIT {
public void setUp() throws Exception {
super.setUp();
String bedrockClaudeConnectorEntity = "{\n"
+ " \"name\": \"BedRock Claude instant-v1 Connector \",\n"
+ " \"description\": \"The connector to BedRock service for claude model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \""
+ GITHUB_CI_AWS_REGION
+ "\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n"
+ " \"max_tokens_to_sample\": 8000,\n"
+ " \"temperature\": 0.0001,\n"
+ " \"response_filter\": \"$.completion\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \""
+ AWS_ACCESS_KEY_ID
+ "\",\n"
+ " \"secret_key\": \""
+ AWS_SECRET_ACCESS_KEY
+ "\",\n"
+ " \"session_token\": \""
+ AWS_SESSION_TOKEN
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"execute\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
+ " \"name\": \"BedRock Claude instant-v1 Connector \",\n"
+ " \"description\": \"The connector to BedRock service for claude model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \""
+ GITHUB_CI_AWS_REGION
+ "\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n"
+ " \"max_tokens_to_sample\": 8000,\n"
+ " \"temperature\": 0.0001,\n"
+ " \"response_filter\": \"$.completion\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \""
+ AWS_ACCESS_KEY_ID
+ "\",\n"
+ " \"secret_key\": \""
+ AWS_SECRET_ACCESS_KEY
+ "\",\n"
+ " \"session_token\": \""
+ AWS_SESSION_TOKEN
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"execute\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
this.bedrockClaudeConnectorId = registerConnector(bedrockClaudeConnectorEntity);
}

Expand All @@ -76,23 +76,23 @@ public void tearDown() throws Exception {

public void testConnectorToolInFlowAgent_WrongAction() throws IOException, ParseException {
String registerAgentRequestBody = "{\n"
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
+ " \"description\": \"This is a demo agent for connector tool\",\n"
+ " \"app_type\": \"test1\",\n"
+ " \"tools\": [\n"
+ " {\n"
+ " \"type\": \"ConnectorTool\",\n"
+ " \"name\": \"bedrock_model\",\n"
+ " \"parameters\": {\n"
+ " \"connector_id\": \""
+ bedrockClaudeConnectorId
+ "\",\n"
+ " \"connector_action\": \"predict\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
+ " \"description\": \"This is a demo agent for connector tool\",\n"
+ " \"app_type\": \"test1\",\n"
+ " \"tools\": [\n"
+ " {\n"
+ " \"type\": \"ConnectorTool\",\n"
+ " \"name\": \"bedrock_model\",\n"
+ " \"parameters\": {\n"
+ " \"connector_id\": \""
+ bedrockClaudeConnectorId
+ "\",\n"
+ " \"connector_action\": \"predict\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
String agentId = createAgent(registerAgentRequestBody);
String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput));
Expand All @@ -101,23 +101,23 @@ public void testConnectorToolInFlowAgent_WrongAction() throws IOException, Parse

public void testConnectorToolInFlowAgent() throws IOException, ParseException {
String registerAgentRequestBody = "{\n"
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
+ " \"description\": \"This is a demo agent for connector tool\",\n"
+ " \"app_type\": \"test1\",\n"
+ " \"tools\": [\n"
+ " {\n"
+ " \"type\": \"ConnectorTool\",\n"
+ " \"name\": \"bedrock_model\",\n"
+ " \"parameters\": {\n"
+ " \"connector_id\": \""
+ bedrockClaudeConnectorId
+ "\",\n"
+ " \"connector_action\": \"execute\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
+ " \"description\": \"This is a demo agent for connector tool\",\n"
+ " \"app_type\": \"test1\",\n"
+ " \"tools\": [\n"
+ " {\n"
+ " \"type\": \"ConnectorTool\",\n"
+ " \"name\": \"bedrock_model\",\n"
+ " \"parameters\": {\n"
+ " \"connector_id\": \""
+ bedrockClaudeConnectorId
+ "\",\n"
+ " \"connector_action\": \"execute\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
String agentId = createAgent(registerAgentRequestBody);
String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
String result = executeAgent(agentId, agentInput);
Expand Down

0 comments on commit d02f38a

Please sign in to comment.