Skip to content

Commit

Permalink
Guardrails for remote model input and output (#2209)
Browse files Browse the repository at this point in the history
* guardrails

Signed-off-by: Jing Zhang <[email protected]>

* update guardrails

Signed-off-by: Jing Zhang <[email protected]>

* bug fix

Signed-off-by: Jing Zhang <[email protected]>

* add some UT

Signed-off-by: Jing Zhang <[email protected]>

* change stop words search to unblocking way

Signed-off-by: Jing Zhang <[email protected]>

* add more UT

Signed-off-by: Jing Zhang <[email protected]>

* address comments

Signed-off-by: Jing Zhang <[email protected]>

* add latch countdown when catching exception

Signed-off-by: Jing Zhang <[email protected]>

---------

Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit 2d401bc)
  • Loading branch information
jngz-es authored and github-actions[bot] committed Mar 19, 2024
1 parent f63bc89 commit 02db35a
Show file tree
Hide file tree
Showing 20 changed files with 1,121 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -265,7 +265,10 @@ public class CommonValue {
+ MLModel.CONNECTOR_FIELD
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"
+ USER_FIELD_MAPPING
+ " }\n"
+ " },\n"
+ " \""
+ MLModel.GUARDRAILS_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ "}";

public static final String ML_TASK_INDEX_MAPPING = "{\n"
Expand Down
24 changes: 23 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -84,6 +85,7 @@ public class MLModel implements ToXContentObject {
public static final String IS_HIDDEN_FIELD = "is_hidden";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";
public static final String GUARDRAILS_FIELD = "guardrails";

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -127,6 +129,7 @@ public class MLModel implements ToXContentObject {
@Setter
private Connector connector;
private String connectorId;
private Guardrails guardrails;

@Builder(toBuilder = true)
public MLModel(String name,
Expand Down Expand Up @@ -158,7 +161,8 @@ public MLModel(String name,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
String connectorId,
Guardrails guardrails) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand Down Expand Up @@ -190,6 +194,7 @@ public MLModel(String name,
this.isHidden = isHidden;
this.connector = connector;
this.connectorId = connectorId;
this.guardrails = guardrails;
}

public MLModel(StreamInput input) throws IOException {
Expand Down Expand Up @@ -243,6 +248,9 @@ public MLModel(StreamInput input) throws IOException {
connector = Connector.fromStream(input);
}
connectorId = input.readOptionalString();
if (input.readBoolean()) {
this.guardrails = new Guardrails(input);
}
}
}

Expand Down Expand Up @@ -308,6 +316,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
if (guardrails != null) {
out.writeBoolean(true);
guardrails.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -406,6 +420,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (guardrails != null) {
builder.field(GUARDRAILS_FIELD, guardrails);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -448,6 +465,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
boolean isHidden = false;
Connector connector = null;
String connectorId = null;
Guardrails guardrails = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -571,6 +589,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case LAST_UNDEPLOYED_TIME_FIELD:
lastUndeployedTime = Instant.ofEpochMilli(parser.longValue());
break;
case GUARDRAILS_FIELD:
guardrails = Guardrails.parse(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -608,6 +629,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.isHidden(isHidden)
.connector(connector)
.connectorId(connectorId)
.guardrails(guardrails)
.build();
}

Expand Down
105 changes: 105 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";

private List<StopWords> stopWords;
private String[] regex;

@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
}
125 changes: 125 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrails.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrails implements ToXContentObject {
public static final String TYPE_FIELD = "type";
public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled";
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";

private String type;
private Boolean engDetectionEnabled;
private Guardrail inputGuardrail;
private Guardrail outputGuardrail;

@Builder(toBuilder = true)
public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) {
this.type = type;
this.engDetectionEnabled = engDetectionEnabled;
this.inputGuardrail = inputGuardrail;
this.outputGuardrail = outputGuardrail;
}

public Guardrails(StreamInput input) throws IOException {
type = input.readString();
engDetectionEnabled = input.readBoolean();
if (input.readBoolean()) {
inputGuardrail = new Guardrail(input);
}
if (input.readBoolean()) {
outputGuardrail = new Guardrail(input);
}
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(type);
out.writeBoolean(engDetectionEnabled);
if (inputGuardrail != null) {
out.writeBoolean(true);
inputGuardrail.writeTo(out);
} else {
out.writeBoolean(false);
}
if (outputGuardrail != null) {
out.writeBoolean(true);
outputGuardrail.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (type != null) {
builder.field(TYPE_FIELD, type);
}
if (engDetectionEnabled != null) {
builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled);
}
if (inputGuardrail != null) {
builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail);
}
if (outputGuardrail != null) {
builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail);
}
builder.endObject();
return builder;
}

public static Guardrails parse(XContentParser parser) throws IOException {
String type = null;
Boolean engDetectionEnabled = null;
Guardrail inputGuardrail = null;
Guardrail outputGuardrail = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case TYPE_FIELD:
type = parser.text();
break;
case ENGLISH_DETECTION_ENABLED_FIELD:
engDetectionEnabled = parser.booleanValue();
break;
case INPUT_GUARDRAIL_FIELD:
inputGuardrail = Guardrail.parse(parser);
break;
case OUTPUT_GUARDRAIL_FIELD:
outputGuardrail = Guardrail.parse(parser);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrails.builder()
.type(type)
.engDetectionEnabled(engDetectionEnabled)
.inputGuardrail(inputGuardrail)
.outputGuardrail(outputGuardrail)
.build();
}
}
Loading

0 comments on commit 02db35a

Please sign in to comment.