Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Mar 18, 2024
1 parent 8cae1a8 commit 6c67b1c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down
20 changes: 12 additions & 8 deletions common/src/main/java/org/opensearch/ml/common/model/MLGuard.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
Expand All @@ -47,6 +48,8 @@ public class MLGuard {
private Map<String, List<String>> stopWordsIndicesOutput = new HashMap<>();
private List<String> inputRegex;
private List<String> outputRegex;
private List<Pattern> inputRegexPattern;
private List<Pattern> outputRegexPattern;
private NamedXContentRegistry xContentRegistry;
private Client client;

Expand All @@ -61,10 +64,12 @@ public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Cl
if (inputGuardrail != null) {
fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput);
inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex());
inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
}
if (outputGuardrail != null) {
fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput);
outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex());
outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
}
}

Expand All @@ -81,25 +86,24 @@ private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map<String, List<S
/**
 * Validates {@code input} against the guardrail selected by {@code type}. The regex check runs
 * first; because of the short-circuiting {@code &&}, the stop-word check only runs when the
 * regex check returned true.
 *
 * NOTE(review): this span is a rendered diff hunk without +/- markers. Adjacent near-duplicate
 * lines look like a replaced line followed by its replacement - confirm against the commit
 * before treating this as compilable source.
 *
 * @param input text to validate
 * @param type  0 selects the input guardrail, 1 selects the output guardrail
 * @return the regex-list result AND-ed with the stop-word result for the selected guardrail
 */
public Boolean validate(String input, int type) {
switch (type) {
case 0: // validate input
// Variant that passes the raw regex strings (inputRegex is a List<String>).
return validateRegexList(input, inputRegex) && validateStopWords(input, stopWordsIndicesInput);
// Variant that passes the Pattern objects compiled in the constructor (inputRegexPattern).
return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput);
case 1: // validate output
// Variant that passes the raw regex strings (outputRegex is a List<String>).
return validateRegexList(input, outputRegex) && validateStopWords(input, stopWordsIndicesOutput);
// Variant that passes the Pattern objects compiled in the constructor (outputRegexPattern).
return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput);
default:
// Two alternative fallbacks are shown for any other type value: accept it (return true),
// or reject it with an IllegalArgumentException.
return true;
throw new IllegalArgumentException("Unsupported type to validate for guardrails.");
}
}

/**
 * Runs {@code validateRegex} on {@code input} for every entry of the list and stops at the
 * first entry for which it returns false.
 *
 * NOTE(review): two signatures and two loop headers are shown because this is a diff hunk
 * rendered without +/- markers; one takes regex strings, the other takes Pattern objects.
 * Presumably the String-based lines are the ones being replaced - confirm against the commit.
 *
 * @param input text to check
 * @return false as soon as one entry fails {@code validateRegex}; true otherwise, including
 *         when the list is empty
 */
// String-based variant: each element is a regex source string.
public Boolean validateRegexList(String input, List<String> regexList) {
for (String regex : regexList) {
if (!validateRegex(input, regex)) {
// Pattern-based variant: each element is an already-compiled Pattern.
public Boolean validateRegexList(String input, List<Pattern> regexPatterns) {
for (Pattern pattern : regexPatterns) {
if (!validateRegex(input, pattern)) {
return false;
}
}
return true;
}

/**
 * Returns true when the pattern does NOT match the whole of {@code input}, and false when it
 * does ({@code Matcher.matches()} is negated), so a full match means the input is rejected.
 *
 * NOTE(review): two signatures are shown because this is a diff hunk rendered without +/-
 * markers. Presumably the String-based lines are the ones being replaced - confirm against
 * the commit.
 *
 * @param input text to check
 * @return {@code !matcher.matches()}
 */
// String-based variant: compiles the regex on every call.
public Boolean validateRegex(String input, String regex) {
Pattern pattern = Pattern.compile(regex);
// Pattern-based variant: the caller supplies an already-compiled Pattern.
public Boolean validateRegex(String input, Pattern pattern) {
Matcher matcher = pattern.matcher(input);
return !matcher.matches();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Data;
import lombok.Builder;
import lombok.Getter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -46,6 +47,8 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
// request
public static final String GUARDRAILS_FIELD = "guardrails";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS = Version.V_2_13_0;

@Getter
private String modelId;
private String description;
Expand Down Expand Up @@ -81,6 +84,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St
}

public MLUpdateModelInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
modelId = in.readString();
description = in.readOptionalString();
version = in.readOptionalString();
Expand All @@ -101,8 +105,10 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
connector = new MLCreateConnectorInput(in);
}
lastUpdateTime = in.readOptionalInstant();
if (in.readBoolean()) {
this.guardrails = new Guardrails(in);
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
if (in.readBoolean()) {
this.guardrails = new Guardrails(in);
}
}
}

Expand Down Expand Up @@ -193,6 +199,7 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(modelId);
out.writeOptionalString(description);
out.writeOptionalString(version);
Expand Down Expand Up @@ -225,11 +232,13 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalInstant(lastUpdateTime);
if (guardrails != null) {
out.writeBoolean(true);
guardrails.writeTo(out);
} else {
out.writeBoolean(false);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
if (guardrails != null) {
out.writeBoolean(true);
guardrails.writeTo(out);
} else {
out.writeBoolean(false);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.regex.Pattern;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
Expand All @@ -54,6 +55,7 @@ public class MLGuardTests {

StopWords stopWords;
String[] regex;
List<Pattern> regexPatterns;
Guardrail inputGuardrail;
Guardrail outputGuardrail;
Guardrails guardrails;
Expand All @@ -70,6 +72,7 @@ public void setUp() {

stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]);
regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*"));
inputGuardrail = new Guardrail(List.of(stopWords), regex);
outputGuardrail = new Guardrail(List.of(stopWords), regex);
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
Expand All @@ -95,33 +98,31 @@ public void validateOutput() {
// Input that does not contain the phrase "stop words" is expected to pass validateRegexList.
// NOTE(review): diff hunk rendered without +/- markers - the List<String> lines and the
// regexPatterns line are alternative versions of the same call; confirm against the commit.
@Test
public void validateRegexListSuccess() {
String input = "\n\nHuman:hello good words.\n\nAssistant:";
// Variant calling with a List<String> built from the regex array.
List<String> regexList = List.of(regex);
Boolean res = mlGuard.validateRegexList(input, regexList);
// Variant calling with the List<Pattern> prepared in setUp.
Boolean res = mlGuard.validateRegexList(input, regexPatterns);

Assert.assertTrue(res);
}

// Input that contains the phrase "stop words" is expected to be rejected by validateRegexList.
// NOTE(review): diff hunk rendered without +/- markers - the List<String> lines and the
// regexPatterns line are alternative versions of the same call; confirm against the commit.
@Test
public void validateRegexListFailed() {
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
// Variant calling with a List<String> built from the regex array.
List<String> regexList = List.of(regex);
Boolean res = mlGuard.validateRegexList(input, regexList);
// Variant calling with the List<Pattern> prepared in setUp.
Boolean res = mlGuard.validateRegexList(input, regexPatterns);

Assert.assertFalse(res);
}

// Input that does not contain the phrase "stop words" is expected to pass validateRegex.
// NOTE(review): diff hunk rendered without +/- markers - the two "Boolean res" lines are
// alternative versions of the same call; confirm against the commit.
@Test
public void validateRegexSuccess() {
String input = "\n\nHuman:hello good words.\n\nAssistant:";
// Variant passing the regex source string.
Boolean res = mlGuard.validateRegex(input, regex[0]);
// Variant passing the compiled Pattern from setUp.
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));

Assert.assertTrue(res);
}

// Input that contains the phrase "stop words" is expected to be rejected by validateRegex.
// NOTE(review): diff hunk rendered without +/- markers - the two "Boolean res" lines are
// alternative versions of the same call; confirm against the commit.
@Test
public void validateRegexFailed() {
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
// Variant passing the regex source string.
Boolean res = mlGuard.validateRegex(input, regex[0]);
// Variant passing the compiled Pattern from setUp.
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));

Assert.assertFalse(res);
}
Expand Down

0 comments on commit 6c67b1c

Please sign in to comment.