Skip to content

Commit

Permalink
add connector tool
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent 12beac2 commit 2d25764
Show file tree
Hide file tree
Showing 14 changed files with 972 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public enum FunctionName {
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;
AGENT,
CONNECTOR;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
}

public enum ActionType {
PREDICT
PREDICT,
EXECUTE
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.ml.common.transport.MLTaskResponse;

public class MLExecuteConnectorAction extends ActionType<MLTaskResponse> {
public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute";

private MLExecuteConnectorAction() {
super(NAME, MLTaskResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteConnectorRequest extends MLTaskRequest {

String connectorId;
MLInput mlInput;

@Builder
public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.connectorId = connectorId;
}

public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) {
this(connectorId, mlInput, true);
}

public MLExecuteConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.mlInput = new MLInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
this.mlInput.writeTo(out);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (this.mlInput == null) {
exception = addValidationError("ML input can't be null", exception);
} else if (this.mlInput.getInputDataset() == null) {
exception = addValidationError("input data can't be null", exception);
}

return exception;
}


public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLExecuteConnectorRequest) {
return (MLExecuteConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLExecuteConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;

public class MLExecuteConnectorRequestTests {
private MLExecuteConnectorRequest mlExecuteConnectorRequest;
private MLInput mlInput;
private String connectorId;

@Before
public void setUp(){
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build();
connectorId = "test_connector";
mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build();
mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build();
}

@Test
public void writeToSuccess() throws IOException {
BytesStreamOutput output = new BytesStreamOutput();
mlExecuteConnectorRequest.writeTo(output);
MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput());
assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm());
assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType());
assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input"));
}

@Test
public void validateSuccess() {
assertNull(mlExecuteConnectorRequest.validate());
}

@Test
public void testConstructor() {
MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, mlInput);
assertTrue(executeConnectorRequest.isDispatchTask());
}

@Test
public void validateWithNullMLInputException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder()
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage());
}

@Test
public void validateWithNullMLInputDataSetException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput())
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage());
}

@Test
public void fromActionRequestWithMLExecuteConnectorRequestSuccess() {
assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest);
}

@Test
public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlExecuteConnectorRequest.writeTo(out);
}
};
MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlExecuteConnectorRequest);
assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLExecuteConnectorRequest.fromActionRequest(actionRequest);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import java.util.List;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction;
import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* This tool supports running connector.
*/
@Log4j2
@ToolAnnotation(ConnectorTool.TYPE)
public class ConnectorTool implements Tool {
public static final String TYPE = "ConnectorTool";
public static final String CONNECTOR_ID = "connector_id";
public static final String CONNECTOR_ACTION = "connector_action";

@Setter
@Getter
private String name = ConnectorTool.TYPE;
@Getter
@Setter
private String description = Factory.DEFAULT_DESCRIPTION;
@Getter
private String version;
@Setter
private Parser inputParser;
@Setter
private Parser outputParser;

private Client client;
private String connectorId;

public ConnectorTool(Client client, String connectorId) {
this.client = client;
if (connectorId == null) {
throw new IllegalArgumentException("connector_id can't be null");
}
this.connectorId = connectorId;

outputParser = new Parser() {
@Override
public Object parse(Object o) {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
}
};
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build();
ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput);

client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput();
modelTensorOutput.getMlModelOutputs();
if (outputParser == null) {
listener.onResponse((T) modelTensorOutput.getMlModelOutputs());
} else {
listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs()));
}
}, e -> {
log.error("Failed to run model " + connectorId, e);
listener.onFailure(e);
}));
}

@Override
public String getType() {
return TYPE;
}

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.size() == 0) {
return false;
}
return true;
}

public static class Factory implements Tool.Factory<ConnectorTool> {
public static final String TYPE = "ConnectorTool";
public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";
private Client client;
private static Factory INSTANCE;

public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (ConnectorTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client) {
this.client = client;
}

@Override
public ConnectorTool create(Map<String, Object> map) {
return new ConnectorTool(client, (String) map.get(CONNECTOR_ID));
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}
}
}
Loading

0 comments on commit 2d25764

Please sign in to comment.