Skip to content

Commit

Permalink
Merge branch 'main' into multi-modal-preprocess-func
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jun 11, 2024
2 parents 3591d25 + 06d1742 commit 38d8ddf
Show file tree
Hide file tree
Showing 75 changed files with 4,954 additions and 1,098 deletions.
10 changes: 10 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

package org.opensearch.ml.common;

import com.google.common.collect.ImmutableSet;
import org.opensearch.Version;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.controller.MLController;

import java.util.Set;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD;
Expand Down Expand Up @@ -70,6 +74,7 @@ public class CommonValue {
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down Expand Up @@ -527,4 +532,9 @@ public class CommonValue {
+ "\": {\"type\": \"long\"}\n"
+ " }\n"
+ "}";
// Calculate Versions independently of OpenSearch core version
public static final Version VERSION_2_11_0 = Version.fromString("2.11.0");
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public enum FunctionName {
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;
AGENT,
CONNECTOR;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;

Expand Down Expand Up @@ -47,7 +48,7 @@ public class MLAgent implements ToXContentObject, Writeable {
public static final String APP_TYPE_FIELD = "app_type";
public static final String IS_HIDDEN_FIELD = "is_hidden";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = Version.V_2_13_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0;

private String name;
private String type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class AbstractConnector implements Connector {
@Setter
protected ConnectorClientConfig connectorClientConfig;

protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
protected Map<String, String> createDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
return null;
}
Expand Down Expand Up @@ -116,9 +116,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
}

@Override
public Optional<ConnectorAction> findPredictAction() {
public Optional<ConnectorAction> findAction(String action) {
if (actions != null) {
return actions.stream().filter(a -> a.getActionType() == ConnectorAction.ActionType.PREDICT).findFirst();
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
}
return Optional.empty();
}
Expand All @@ -131,12 +131,12 @@ public void removeCredential() {
}

@Override
public String getPredictEndpoint(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (!predictAction.isPresent()) {
public String getActionEndpoint(String action, Map<String, String> parameters) {
Optional<ConnectorAction> actionEndpoint = findAction(action);
if (!actionEndpoint.isPresent()) {
return null;
}
String predictEndpoint = predictAction.get().getUrl();
String predictEndpoint = actionEndpoint.get().getUrl();
if (parameters != null && parameters.size() > 0) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,18 @@ public interface Connector extends ToXContentObject, Writeable {

ConnectorClientConfig getConnectorClientConfig();

String getPredictEndpoint(Map<String, String> parameters);
String getActionEndpoint(String action, Map<String, String> parameters);

String getPredictHttpMethod();
String getActionHttpMethod(String action);

<T> T createPredictPayload(Map<String, String> parameters);
<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(Function<String, String> function);
void decrypt(String action, Function<String, String> function);
void encrypt(Function<String, String> function);

Connector cloneConnector();

Optional<ConnectorAction> findPredictAction();
Optional<ConnectorAction> findAction(String action);

void removeCredential();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
}

public enum ActionType {
PREDICT
PREDICT,
EXECUTE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
Expand Down Expand Up @@ -307,10 +308,10 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
public <T> T createPayload(String action, Map<String, String> parameters) {
Optional<ConnectorAction> connectorAction = findAction(action);
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);
Expand Down Expand Up @@ -348,15 +349,15 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(Function<String, String> function) {
public void decrypt(String action, Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> predictAction = findPredictAction();
Map<String, String> headers = predictAction.isPresent() ? predictAction.get().getHeaders() : null;
this.decryptedHeaders = createPredictDecryptedHeaders(headers);
Optional<ConnectorAction> connectorAction = findAction(action);
Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
this.decryptedHeaders = createDecryptedHeaders(headers);
}

@Override
Expand All @@ -378,8 +379,9 @@ public void encrypt(Function<String, String> function) {
}
}

public String getPredictHttpMethod() {
return findPredictAction().get().getMethod();
@Override
public String getActionHttpMethod(String action) {
return findAction(action).get().getMethod();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.output.model.ModelResultFilter;

Expand All @@ -29,7 +30,7 @@ public class TextDocsInputDataSet extends MLInputDataset{

private List<String> docs;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = CommonValue.VERSION_2_11_0;

@Builder(toBuilder = true)
public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {
Expand Down
96 changes: 7 additions & 89 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,19 @@

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.client.Client;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
public abstract class Guardrail implements ToXContentObject {

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";
public abstract void writeTo(StreamOutput out) throws IOException;

private List<StopWords> stopWords;
private String[] regex;
public abstract Boolean validate(String input, Map<String, String> parameters);

@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
}
Loading

0 comments on commit 38d8ddf

Please sign in to comment.