-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Guardrails for remote model input and output (#2209)
* guardrails Signed-off-by: Jing Zhang <[email protected]> * update guardrails Signed-off-by: Jing Zhang <[email protected]> * bug fix Signed-off-by: Jing Zhang <[email protected]> * add some UT Signed-off-by: Jing Zhang <[email protected]> * change stop words search to unblocking way Signed-off-by: Jing Zhang <[email protected]> * add more UT Signed-off-by: Jing Zhang <[email protected]> * address comments Signed-off-by: Jing Zhang <[email protected]> * add latch countdown when catching exception Signed-off-by: Jing Zhang <[email protected]> --------- Signed-off-by: Jing Zhang <[email protected]> (cherry picked from commit 2d401bc)
- Loading branch information
1 parent
f63bc89
commit 02db35a
Showing
20 changed files
with
1,121 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.model; | ||
|
||
import lombok.Builder; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.Getter; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
@EqualsAndHashCode | ||
@Getter | ||
public class Guardrail implements ToXContentObject { | ||
public static final String STOP_WORDS_FIELD = "stop_words"; | ||
public static final String REGEX_FIELD = "regex"; | ||
|
||
private List<StopWords> stopWords; | ||
private String[] regex; | ||
|
||
@Builder(toBuilder = true) | ||
public Guardrail(List<StopWords> stopWords, String[] regex) { | ||
this.stopWords = stopWords; | ||
this.regex = regex; | ||
} | ||
|
||
public Guardrail(StreamInput input) throws IOException { | ||
if (input.readBoolean()) { | ||
stopWords = new ArrayList<>(); | ||
int size = input.readInt(); | ||
for (int i=0; i<size; i++) { | ||
stopWords.add(new StopWords(input)); | ||
} | ||
} | ||
regex = input.readStringArray(); | ||
} | ||
|
||
public void writeTo(StreamOutput out) throws IOException { | ||
if (stopWords != null && stopWords.size() > 0) { | ||
out.writeBoolean(true); | ||
out.writeInt(stopWords.size()); | ||
for (StopWords e : stopWords) { | ||
e.writeTo(out); | ||
} | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
out.writeStringArray(regex); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (stopWords != null && stopWords.size() > 0) { | ||
builder.field(STOP_WORDS_FIELD, stopWords); | ||
} | ||
if (regex != null) { | ||
builder.field(REGEX_FIELD, regex); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public static Guardrail parse(XContentParser parser) throws IOException { | ||
List<StopWords> stopWords = null; | ||
String[] regex = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case STOP_WORDS_FIELD: | ||
stopWords = new ArrayList<>(); | ||
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_ARRAY) { | ||
stopWords.add(StopWords.parse(parser)); | ||
} | ||
break; | ||
case REGEX_FIELD: | ||
regex = parser.list().toArray(new String[0]); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return Guardrail.builder() | ||
.stopWords(stopWords) | ||
.regex(regex) | ||
.build(); | ||
} | ||
} |
125 changes: 125 additions & 0 deletions
125
common/src/main/java/org/opensearch/ml/common/model/Guardrails.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.model; | ||
|
||
import lombok.Builder; | ||
import lombok.EqualsAndHashCode; | ||
import lombok.Getter; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
@EqualsAndHashCode | ||
@Getter | ||
public class Guardrails implements ToXContentObject { | ||
public static final String TYPE_FIELD = "type"; | ||
public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled"; | ||
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail"; | ||
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail"; | ||
|
||
private String type; | ||
private Boolean engDetectionEnabled; | ||
private Guardrail inputGuardrail; | ||
private Guardrail outputGuardrail; | ||
|
||
@Builder(toBuilder = true) | ||
public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) { | ||
this.type = type; | ||
this.engDetectionEnabled = engDetectionEnabled; | ||
this.inputGuardrail = inputGuardrail; | ||
this.outputGuardrail = outputGuardrail; | ||
} | ||
|
||
public Guardrails(StreamInput input) throws IOException { | ||
type = input.readString(); | ||
engDetectionEnabled = input.readBoolean(); | ||
if (input.readBoolean()) { | ||
inputGuardrail = new Guardrail(input); | ||
} | ||
if (input.readBoolean()) { | ||
outputGuardrail = new Guardrail(input); | ||
} | ||
} | ||
|
||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(type); | ||
out.writeBoolean(engDetectionEnabled); | ||
if (inputGuardrail != null) { | ||
out.writeBoolean(true); | ||
inputGuardrail.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
if (outputGuardrail != null) { | ||
out.writeBoolean(true); | ||
outputGuardrail.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
if (type != null) { | ||
builder.field(TYPE_FIELD, type); | ||
} | ||
if (engDetectionEnabled != null) { | ||
builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled); | ||
} | ||
if (inputGuardrail != null) { | ||
builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail); | ||
} | ||
if (outputGuardrail != null) { | ||
builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public static Guardrails parse(XContentParser parser) throws IOException { | ||
String type = null; | ||
Boolean engDetectionEnabled = null; | ||
Guardrail inputGuardrail = null; | ||
Guardrail outputGuardrail = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case TYPE_FIELD: | ||
type = parser.text(); | ||
break; | ||
case ENGLISH_DETECTION_ENABLED_FIELD: | ||
engDetectionEnabled = parser.booleanValue(); | ||
break; | ||
case INPUT_GUARDRAIL_FIELD: | ||
inputGuardrail = Guardrail.parse(parser); | ||
break; | ||
case OUTPUT_GUARDRAIL_FIELD: | ||
outputGuardrail = Guardrail.parse(parser); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return Guardrails.builder() | ||
.type(type) | ||
.engDetectionEnabled(engDetectionEnabled) | ||
.inputGuardrail(inputGuardrail) | ||
.outputGuardrail(outputGuardrail) | ||
.build(); | ||
} | ||
} |
Oops, something went wrong.