From 02db35a4e6d6841934c83823a07af9717385a79e Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 18 Mar 2024 21:02:56 -0700 Subject: [PATCH] Guardrails for remote model input and output (#2209) * guardrails Signed-off-by: Jing Zhang * update guardrails Signed-off-by: Jing Zhang * bug fix Signed-off-by: Jing Zhang * add some UT Signed-off-by: Jing Zhang * change stop words search to unblocking way Signed-off-by: Jing Zhang * add more UT Signed-off-by: Jing Zhang * address comments Signed-off-by: Jing Zhang * add latch countdown when catching exception Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang (cherry picked from commit 2d401bc0504ca09bd9d729c3c4d06dad08e86e14) --- .../org/opensearch/ml/common/CommonValue.java | 7 +- .../org/opensearch/ml/common/MLModel.java | 24 +- .../opensearch/ml/common/model/Guardrail.java | 105 ++++++++ .../ml/common/model/Guardrails.java | 125 ++++++++++ .../opensearch/ml/common/model/MLGuard.java | 162 ++++++++++++ .../opensearch/ml/common/model/StopWords.java | 85 +++++++ .../transport/model/MLUpdateModelInput.java | 36 ++- .../register/MLRegisterModelInput.java | 36 ++- .../ml/common/model/GuardrailTests.java | 69 ++++++ .../ml/common/model/GuardrailsTests.java | 84 +++++++ .../ml/common/model/MLGuardTests.java | 230 ++++++++++++++++++ .../ml/common/model/StopWordsTests.java | 58 +++++ .../remote/AwsConnectorExecutor.java | 7 + .../remote/HttpJsonConnectorExecutor.java | 7 + .../remote/RemoteConnectorExecutor.java | 8 + .../engine/algorithms/remote/RemoteModel.java | 3 + .../models/UpdateModelTransportAction.java | 6 +- .../org/opensearch/ml/model/MLModelCache.java | 3 + .../ml/model/MLModelCacheHelper.java | 35 +++ .../opensearch/ml/model/MLModelManager.java | 45 +++- 20 files changed, 1121 insertions(+), 14 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/model/Guardrail.java create mode 100644 common/src/main/java/org/opensearch/ml/common/model/Guardrails.java create mode 100644 common/src/main/java/org/opensearch/ml/common/model/MLGuard.java create mode 100644 common/src/main/java/org/opensearch/ml/common/model/StopWords.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 139ecf82be..196b1eb478 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -56,7 +56,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; @@ -265,7 +265,10 @@ public class CommonValue { + MLModel.CONNECTOR_FIELD + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," + USER_FIELD_MAPPING - + " }\n" + + " },\n" + + " \"" + + MLModel.GUARDRAILS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + "}"; public static final String ML_TASK_INDEX_MAPPING = "{\n" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index cec2805891..87b1862c57 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; @@ -84,6 +85,7 @@ public class MLModel implements ToXContentObject { public static final String IS_HIDDEN_FIELD = "is_hidden"; public static final String CONNECTOR_FIELD = "connector"; public static final String CONNECTOR_ID_FIELD = "connector_id"; + public static final String GUARDRAILS_FIELD = "guardrails"; private String name; private String modelGroupId; @@ -127,6 +129,7 @@ public class MLModel implements ToXContentObject { @Setter private Connector connector; private String connectorId; + private Guardrails guardrails; @Builder(toBuilder = true) public MLModel(String name, @@ -158,7 +161,8 @@ public MLModel(String name, boolean deployToAllNodes, Boolean isHidden, Connector connector, - String connectorId) { + String connectorId, + Guardrails guardrails) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -190,6 +194,7 @@ public MLModel(String name, this.isHidden = isHidden; this.connector = connector; this.connectorId = connectorId; + this.guardrails = guardrails; } public MLModel(StreamInput input) throws IOException { @@ -243,6 +248,9 @@ public MLModel(StreamInput input) throws IOException { connector = Connector.fromStream(input); } connectorId = input.readOptionalString(); + if (input.readBoolean()) { + this.guardrails = new Guardrails(input); + } } } @@ -308,6 +316,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalString(connectorId); + if (guardrails != null) { + out.writeBoolean(true); + guardrails.writeTo(out); + } else { + out.writeBoolean(false); + } } @Override @@ -406,6 +420,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (connectorId != null) { builder.field(CONNECTOR_ID_FIELD, connectorId); } + if (guardrails != null) { + builder.field(GUARDRAILS_FIELD, guardrails); + } builder.endObject(); return builder; } @@ -448,6 +465,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws boolean isHidden = false; Connector connector = null; String connectorId = null; + Guardrails guardrails = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -571,6 +589,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case LAST_UNDEPLOYED_TIME_FIELD: lastUndeployedTime = Instant.ofEpochMilli(parser.longValue()); break; + case GUARDRAILS_FIELD: + guardrails = Guardrails.parse(parser); + break; default: parser.skipChildren(); break; @@ -608,6 +629,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .isHidden(isHidden) .connector(connector) .connectorId(connectorId) + .guardrails(guardrails) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java new file mode 100644 index 0000000000..d690fdce7f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@EqualsAndHashCode +@Getter +public class Guardrail implements ToXContentObject { + public static final String STOP_WORDS_FIELD = "stop_words"; + public static final String REGEX_FIELD = "regex"; + + private List stopWords; + private String[] regex; + + @Builder(toBuilder = true) + public Guardrail(List stopWords, String[] regex) { + this.stopWords = stopWords; + this.regex = regex; + } + + public Guardrail(StreamInput input) throws IOException { + if (input.readBoolean()) { + stopWords = new ArrayList<>(); + int size = input.readInt(); + for (int i=0; i 0) { + out.writeBoolean(true); + out.writeInt(stopWords.size()); + for (StopWords e : stopWords) { + e.writeTo(out); + } + } else { + out.writeBoolean(false); + } + out.writeStringArray(regex); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (stopWords != null && stopWords.size() > 0) { + builder.field(STOP_WORDS_FIELD, stopWords); + } + if (regex != null) { + builder.field(REGEX_FIELD, regex); + } + builder.endObject(); + return builder; + } + + public static Guardrail parse(XContentParser parser) throws IOException { + List stopWords = null; + String[] regex = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case STOP_WORDS_FIELD: + stopWords = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + stopWords.add(StopWords.parse(parser)); + } + break; + case REGEX_FIELD: + regex = parser.list().toArray(new String[0]); + break; + default: + parser.skipChildren(); + break; + } + } + return Guardrail.builder() + .stopWords(stopWords) + .regex(regex) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java new file mode 100644 index 0000000000..7dc27d75c8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@EqualsAndHashCode +@Getter +public class Guardrails implements ToXContentObject { + public static final String TYPE_FIELD = "type"; + public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled"; + public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail"; + public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail"; + + private String type; + private Boolean engDetectionEnabled; + private Guardrail inputGuardrail; + private Guardrail outputGuardrail; + + @Builder(toBuilder = true) + public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) { + this.type = type; + this.engDetectionEnabled = engDetectionEnabled; + this.inputGuardrail = inputGuardrail; + this.outputGuardrail = outputGuardrail; + } + + public Guardrails(StreamInput input) throws IOException { + type = input.readString(); + engDetectionEnabled = input.readBoolean(); + if (input.readBoolean()) { + inputGuardrail = new Guardrail(input); + } + if (input.readBoolean()) { + outputGuardrail = new Guardrail(input); + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeBoolean(engDetectionEnabled); + if (inputGuardrail != null) { + out.writeBoolean(true); + inputGuardrail.writeTo(out); + } else { + out.writeBoolean(false); + } + if (outputGuardrail != null) { + out.writeBoolean(true); + outputGuardrail.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (engDetectionEnabled != null) { + builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled); + } + if (inputGuardrail != null) { + builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail); + } + if (outputGuardrail != null) { + builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail); + } + builder.endObject(); + return builder; + } + + public static Guardrails parse(XContentParser parser) throws IOException { + String type = null; + Boolean engDetectionEnabled = null; + Guardrail inputGuardrail = null; + Guardrail outputGuardrail = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case ENGLISH_DETECTION_ENABLED_FIELD: + engDetectionEnabled = parser.booleanValue(); + break; + case INPUT_GUARDRAIL_FIELD: + inputGuardrail = Guardrail.parse(parser); + break; + case OUTPUT_GUARDRAIL_FIELD: + outputGuardrail = Guardrail.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + return Guardrails.builder() + .type(type) + .engDetectionEnabled(engDetectionEnabled) + .inputGuardrail(inputGuardrail) + .outputGuardrail(outputGuardrail) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java new file mode 100644 index 0000000000..f42df5bc62 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -0,0 +1,162 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +@Getter +public class MLGuard { + private Map> stopWordsIndicesInput = new HashMap<>(); + private Map> stopWordsIndicesOutput = new HashMap<>(); + private List inputRegex; + private List outputRegex; + private List inputRegexPattern; + private List outputRegexPattern; + private NamedXContentRegistry xContentRegistry; + private Client client; + + public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { + this.xContentRegistry = xContentRegistry; + this.client = client; + if (guardrails == null) { + return; + } + Guardrail inputGuardrail = guardrails.getInputGuardrail(); + Guardrail outputGuardrail = guardrails.getOutputGuardrail(); + if (inputGuardrail != null) { + fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput); + inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex()); + inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + } + if (outputGuardrail != null) { + fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput); + outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex()); + outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + } + } + + private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map> map) { + List stopWords = guardrail.getStopWords(); + if (stopWords == null || stopWords.isEmpty()) { + return; + } + for (StopWords e : stopWords) { + map.put(e.getIndex(), Arrays.asList(e.getSourceFields())); + } + } + + public Boolean validate(String input, int type) { + switch (type) { + case 0: // validate input + return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput); + case 1: // validate output + return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput); + default: + throw new IllegalArgumentException("Unsupported type to validate for guardrails."); + } + } + + public Boolean validateRegexList(String input, List regexPatterns) { + for (Pattern pattern : regexPatterns) { + if (!validateRegex(input, pattern)) { + return false; + } + } + return true; + } + + public Boolean validateRegex(String input, Pattern pattern) { + Matcher matcher = pattern.matcher(input); + return !matcher.matches(); + } + + public Boolean validateStopWords(String input, Map> stopWordsIndices) { + for (Map.Entry entry : stopWordsIndices.entrySet()) { + if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) { + return false; + } + } + return true; + } + + public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { + SearchRequest searchRequest; + AtomicBoolean hitStopWords = new AtomicBoolean(false); + String queryBody; + Map documentMap = new HashMap<>(); + for (String field : fieldNames) { + documentMap.put(field, input); + } + Map queryBodyMap = Map + .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); + CountDownLatch latch = new CountDownLatch(1); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. + searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); + context.restore(); + client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch), () -> context.restore())); + } catch (Exception e) { + log.error("[validateStopWords] Searching stop words index failed.", e); + latch.countDown(); + hitStopWords.set(true); + } + + try { + latch.await(5, SECONDS); + } catch (InterruptedException e) { + log.error("[validateStopWords] Searching stop words index was timeout.", e); + throw new IllegalStateException(e); + } + return hitStopWords.get(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java new file mode 100644 index 0000000000..19307b398d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@EqualsAndHashCode +@Getter +public class StopWords implements ToXContentObject { + public static final String INDEX_NAME_FIELD = "index_name"; + public static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private String index; + private String[] sourceFields; + + @Builder(toBuilder = true) + public StopWords(String index, String[] sourceFields) { + this.index = index; + this.sourceFields = sourceFields; + } + + public StopWords(StreamInput input) throws IOException { + index = input.readString(); + sourceFields = input.readStringArray(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + out.writeStringArray(sourceFields); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (index != null) { + builder.field(INDEX_NAME_FIELD, index); + } + if (sourceFields != null) { + builder.field(SOURCE_FIELDS_FIELD, sourceFields); + } + builder.endObject(); + return builder; + } + + public static StopWords parse(XContentParser parser) throws IOException { + String index = null; + String[] sourceFields = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case INDEX_NAME_FIELD: + index = parser.text(); + break; + case SOURCE_FIELDS_FIELD: + sourceFields = parser.list().toArray(new String[0]); + break; + default: + parser.skipChildren(); + break; + } + } + return StopWords.builder() + .index(index) + .sourceFields(sourceFields) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 065dff69e0..15ce509756 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -8,6 +8,7 @@ import lombok.Data; import lombok.Builder; import lombok.Getter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,6 +16,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -43,6 +45,9 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { public static final String CONNECTOR_FIELD = "connector"; // optional public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update // request + public static final String GUARDRAILS_FIELD = "guardrails"; + + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS = Version.V_2_13_0; @Getter private String modelId; @@ -57,11 +62,12 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private String connectorId; private MLCreateConnectorInput connector; private Instant lastUpdateTime; + private Guardrails guardrails; @Builder(toBuilder = true) public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig, - Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime) { + Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, Guardrails guardrails) { this.modelId = modelId; this.description = description; this.version = version; @@ -74,9 +80,11 @@ public MLUpdateModelInput(String modelId, String description, String version, St this.connectorId = connectorId; this.connector = connector; this.lastUpdateTime = lastUpdateTime; + this.guardrails = guardrails; } public MLUpdateModelInput(StreamInput in) throws IOException { + Version streamInputVersion = in.getVersion(); modelId = in.readString(); description = in.readOptionalString(); version = in.readOptionalString(); @@ -97,6 +105,11 @@ public MLUpdateModelInput(StreamInput in) throws IOException { connector = new MLCreateConnectorInput(in); } lastUpdateTime = in.readOptionalInstant(); + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) { + if (in.readBoolean()) { + this.guardrails = new Guardrails(in); + } + } } @Override @@ -136,6 +149,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lastUpdateTime != null) { builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); } + if (guardrails != null) { + builder.field(GUARDRAILS_FIELD, guardrails); + } builder.endObject(); return builder; } @@ -174,12 +190,16 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa if (lastUpdateTime != null) { builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); } + if (guardrails != null) { + builder.field(GUARDRAILS_FIELD, guardrails); + } builder.endObject(); return builder; } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(modelId); out.writeOptionalString(description); out.writeOptionalString(version); @@ -212,6 +232,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalInstant(lastUpdateTime); + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) { + if (guardrails != null) { + out.writeBoolean(true); + guardrails.writeTo(out); + } else { + out.writeBoolean(false); + } + } } public static MLUpdateModelInput parse(XContentParser parser) throws IOException { @@ -227,6 +255,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException String connectorId = null; MLCreateConnectorInput connector = null; Instant lastUpdateTime = null; + Guardrails guardrails = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -257,6 +286,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case CONNECTOR_FIELD: connector = MLCreateConnectorInput.parse(parser, true); break; + case GUARDRAILS_FIELD: + guardrails = Guardrails.parse(parser); + break; default: parser.skipChildren(); break; @@ -265,6 +297,6 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException // Model ID can only be set through RestRequest. Model version can only be set // automatically. return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter, - modelConfig, updatedConnector, connectorId, connector, lastUpdateTime); + modelConfig, updatedConnector, connectorId, connector, lastUpdateTime, guardrails); } } \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index e448994eb5..16def08f86 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -18,6 +18,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; @@ -59,9 +60,11 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; + public static final String GUARDRAILS_FIELD = "guardrails"; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.V_2_12_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS = Version.V_2_13_0; private FunctionName functionName; private String modelName; @@ -87,6 +90,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Boolean doesVersionCreateModelGroup; private Boolean isHidden; + private Guardrails guardrails; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -108,7 +112,8 @@ public MLRegisterModelInput(FunctionName functionName, Boolean addAllBackendRoles, AccessMode accessMode, Boolean doesVersionCreateModelGroup, - Boolean isHidden) { + Boolean isHidden, + Guardrails guardrails) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -144,6 +149,7 @@ public MLRegisterModelInput(FunctionName functionName, this.accessMode = accessMode; this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; this.isHidden = isHidden; + this.guardrails = guardrails; } public MLRegisterModelInput(StreamInput in) throws IOException { @@ -188,6 +194,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { } this.isHidden = in.readOptionalBoolean(); } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) { + if (in.readBoolean()) { + this.guardrails = new Guardrails(in); + } + } } @Override @@ -247,6 +258,14 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalBoolean(isHidden); } + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) { + if (guardrails != null) { + out.writeBoolean(true); + guardrails.writeTo(out); + } else { + out.writeBoolean(false); + } + } } @Override @@ -306,6 +325,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isHidden != null) { builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); } + if (guardrails != null) { + builder.field(GUARDRAILS_FIELD, guardrails); + } builder.endObject(); return builder; } @@ -329,6 +351,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName AccessMode accessMode = null; Boolean doesVersionCreateModelGroup = null; Boolean isHidden = null; + Guardrails guardrails = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -392,6 +415,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case DOES_VERSION_CREATE_MODEL_GROUP: doesVersionCreateModelGroup = parser.booleanValue(); break; + case GUARDRAILS_FIELD: + guardrails = Guardrails.parse(parser); + break; default: parser.skipChildren(); break; @@ -400,7 +426,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, rateLimiter, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden); + isHidden, guardrails); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -423,6 +449,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo Boolean addAllBackendRoles = null; Boolean doesVersionCreateModelGroup = null; Boolean isHidden = null; + Guardrails guardrails = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -493,6 +520,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case MLModel.IS_HIDDEN_FIELD: isHidden = parser.booleanValue(); break; + case GUARDRAILS_FIELD: + guardrails = Guardrails.parse(parser); + break; default: parser.skipChildren(); break; @@ -500,6 +530,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo } return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, rateLimiter, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, - connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); + connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden, guardrails); } } diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java new file mode 100644 index 0000000000..b6b140d119 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.*; + +public class GuardrailTests { + StopWords stopWords; + String[] regex; + + @Before + public void setUp() { + stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + regex = List.of("regex1").toArray(new String[0]); + } + + @Test + public void writeTo() throws IOException { + Guardrail guardrail = new Guardrail(List.of(stopWords), regex); + BytesStreamOutput output = new BytesStreamOutput(); + guardrail.writeTo(output); + Guardrail guardrail1 = new Guardrail(output.bytes().streamInput()); + + Assert.assertArrayEquals(guardrail.getStopWords().toArray(), guardrail1.getStopWords().toArray()); + Assert.assertArrayEquals(guardrail.getRegex(), guardrail1.getRegex()); + } + + @Test + public void toXContent() throws IOException { + Guardrail guardrail = new Guardrail(List.of(stopWords), regex); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + guardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + Guardrail guardrail = Guardrail.parse(parser); + + Assert.assertArrayEquals(guardrail.getStopWords().toArray(), List.of(stopWords).toArray()); + Assert.assertArrayEquals(guardrail.getRegex(), regex); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java new file mode 100644 index 0000000000..43e1464214 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.*; + +public class GuardrailsTests { + StopWords stopWords; + String[] regex; + Guardrail inputGuardrail; + Guardrail outputGuardrail; + + @Before + public void setUp() { + stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + regex = List.of("regex1").toArray(new String[0]); + inputGuardrail = new Guardrail(List.of(stopWords), regex); + outputGuardrail = new Guardrail(List.of(stopWords), regex); + } + + @Test + public void writeTo() throws IOException { + Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail); + BytesStreamOutput output = new BytesStreamOutput(); + guardrails.writeTo(output); + Guardrails guardrails1 = new Guardrails(output.bytes().streamInput()); + + Assert.assertEquals(guardrails.getType(), guardrails1.getType()); + Assert.assertEquals(guardrails.getEngDetectionEnabled(), guardrails1.getEngDetectionEnabled()); + Assert.assertEquals(guardrails.getInputGuardrail(), guardrails1.getInputGuardrail()); + Assert.assertEquals(guardrails.getOutputGuardrail(), guardrails1.getOutputGuardrail()); + } + + @Test + public void toXContent() throws IOException { + Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"type\":\"test_type\"," + + "\"english_detection_enabled\":false," + + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}", + content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"type\":\"test_type\"," + + "\"english_detection_enabled\":false," + + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + Guardrails guardrails = Guardrails.parse(parser); + + Assert.assertEquals(guardrails.getType(), "test_type"); + Assert.assertEquals(guardrails.getEngDetectionEnabled(), false); + Assert.assertEquals(guardrails.getInputGuardrail(), inputGuardrail); + Assert.assertEquals(guardrails.getOutputGuardrail(), outputGuardrail); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java new file mode 100644 index 0000000000..413f2f6307 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java @@ -0,0 +1,230 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.apache.lucene.search.TotalHits; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.regex.Pattern; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class MLGuardTests { + + NamedXContentRegistry xContentRegistry; + @Mock + Client client; + @Mock + ThreadPool threadPool; + ThreadContext threadContext; + + StopWords stopWords; + String[] regex; + List regexPatterns; + Guardrail inputGuardrail; + Guardrail outputGuardrail; + Guardrails guardrails; + MLGuard mlGuard; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + Settings settings = Settings.builder().build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + + stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); + regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); + inputGuardrail = new Guardrail(List.of(stopWords), regex); + outputGuardrail = new Guardrail(List.of(stopWords), regex); + guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail); + mlGuard = new MLGuard(guardrails, xContentRegistry, client); + } + + @Test + public void validateInput() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = mlGuard.validate(input, 0); + + Assert.assertFalse(res); + } + + @Test + public void validateOutput() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = mlGuard.validate(input, 1); + + Assert.assertFalse(res); + } + + @Test + public void validateRegexListSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = mlGuard.validateRegexList(input, regexPatterns); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexListFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = mlGuard.validateRegexList(input, regexPatterns); + + Assert.assertFalse(res); + } + + @Test + public void validateRegexSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); + + Assert.assertFalse(res); + } + + @Test + public void validateStopWords() throws IOException { + Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = mlGuard.validateStopWords("hello world", stopWordsIndices); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsSingleIndex() throws IOException { + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); + Assert.assertTrue(res); + } + + private SearchResponse createSearchResponse(int size) throws IOException { + XContentBuilder content = guardrails.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + SearchHit[] hits = new SearchHit[size]; + if (size > 0) { + hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); + } + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { + return new ActionFuture<>() { + @Override + public SearchResponse actionGet() { + return searchResponse; + } + + @Override + public SearchResponse actionGet(String timeout) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeoutMillis) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeout, TimeUnit unit) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(TimeValue timeout) { + return searchResponse; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public SearchResponse get() throws InterruptedException, ExecutionException { + return searchResponse; + } + + @Override + public SearchResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return searchResponse; + } + }; + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java b/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java new file mode 100644 index 0000000000..19764bb736 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class StopWordsTests { + + @Test + public void writeTo() throws IOException { + StopWords stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + BytesStreamOutput output = new BytesStreamOutput(); + stopWords.writeTo(output); + StopWords stopWords1 = new StopWords(output.bytes().streamInput()); + + Assert.assertEquals(stopWords.getIndex(), stopWords1.getIndex()); + Assert.assertArrayEquals(stopWords.getSourceFields(), stopWords1.getSourceFields()); + } + + @Test + public void toXContent() throws IOException { + StopWords stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + stopWords.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + StopWords stopWords = StopWords.parse(parser); + + Assert.assertEquals(stopWords.getIndex(), "test_index"); + Assert.assertArrayEquals(stopWords.getSourceFields(), List.of("test_field").toArray(new String[0])); + } +} \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 49f651eadb..c933732674 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -28,6 +28,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.script.ScriptService; @@ -64,6 +65,9 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Setter @Getter private Client client; + @Setter + @Getter + private MLGuard mlGuard; public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { this.connector = (AwsConnector) connector; @@ -146,6 +150,9 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); + if (getMlGuard() != null && !getMlGuard().validate(modelResponse, 1)) { + throw new IllegalArgumentException("guardrails triggered for LLM output"); + } if (statusCode < 200 || statusCode >= 300) { throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 1c7b2eaf36..bea550ef4a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -32,6 +32,7 @@ import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; @@ -60,6 +61,9 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Setter @Getter private Client client; + @Setter + @Getter + private MLGuard mlGuard; private CloseableHttpClient httpClient; @@ -135,6 +139,9 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S return null; }); String modelResponse = responseRef.get(); + if (getMlGuard() != null && !getMlGuard().validate(modelResponse, 1)) { + throw new IllegalArgumentException("guardrails triggered for LLM output"); + } Integer statusCode = statusCodeRef.get(); if (statusCode < 200 || statusCode >= 300) { throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index bb3e13b24a..76637514ab 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -27,6 +27,7 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.script.ScriptService; @@ -87,6 +88,8 @@ default void setScriptService(ScriptService scriptService) {} Map getUserRateLimiterMap(); + MLGuard getMlGuard(); + Client getClient(); default void setClient(Client client) {} @@ -99,6 +102,8 @@ default void setRateLimiter(TokenBucket rateLimiter) {} default void setUserRateLimiterMap(Map userRateLimiterMap) {} + default void setMlGuard(MLGuard mlGuard) {} + default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List tensorOutputs) { Connector connector = getConnector(); @@ -137,6 +142,9 @@ && getUserRateLimiterMap().get(user.getName()) != null RestStatus.TOO_MANY_REQUESTS ); } else { + if (getMlGuard() != null && !getMlGuard().validate(payload, 0)) { + throw new IllegalArgumentException("guardrails triggered for user input"); + } invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index add3906163..8774bcc40c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -16,6 +16,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.Predictable; @@ -37,6 +38,7 @@ public class RemoteModel implements Predictable { public static final String XCONTENT_REGISTRY = "xcontent_registry"; public static final String RATE_LIMITER = "rate_limiter"; public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map"; + public static final String GUARDRAILS = "guardrails"; private RemoteConnectorExecutor connectorExecutor; @@ -90,6 +92,7 @@ public void initModel(MLModel model, Map params, Encryptor encry this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER)); this.connectorExecutor.setUserRateLimiterMap((Map) params.get(USER_RATE_LIMITER_MAP)); + this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS)); } catch (RuntimeException e) { log.error("Failed to init remote model.", e); throw e; diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 57e0361fae..3c9c38248c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -206,11 +206,13 @@ private void updateRemoteOrTextEmbeddingModel( String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); // This flag is used to decide if we need to re-deploy the predictor(model) when updating the model cache. - // If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter needs update, we + // If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter and guardrails + // need update, we // need to perform a re-deploy. boolean isPredictorUpdate = (updateModelInput.getConnector() != null) || (newConnectorId != null) - || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()); + || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()) + || (updateModelInput.getGuardrails() != null); if (MLRateLimiter.updateValidityPreCheck(mlModel.getRateLimiter(), updateModelInput.getRateLimiter())) { MLRateLimiter updatedRateLimiterConfig = MLRateLimiter.update(mlModel.getRateLimiter(), updateModelInput.getRateLimiter()); updateModelInput.setRateLimiter(updatedRateLimiterConfig); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 16ed1be826..06509c30ca 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -17,6 +17,7 @@ import org.opensearch.common.util.TokenBucket; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.Predictable; @@ -45,6 +46,7 @@ public class MLModelCache { private final Queue predictRequestDurationQueue; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationGPU; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLGuard mlGuard; // In rare case, this could be null, e.g. model info not synced up yet a predict request comes in. @Setter @@ -166,6 +168,7 @@ public void clear() { isModelEnabled = null; rateLimiter = null; userRateLimiterMap = null; + mlGuard = null; } public void addModelInferenceDuration(double duration, long maxRequestCount) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 99ccc9cce1..a71bd78fb9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.MLExecutable; @@ -160,6 +161,40 @@ public TokenBucket getUserRateLimiter(String modelId, String user) { return userRateLimiterMap.get(user); } + /** + * Set a ml guard + * + * @param modelId model id + * @param mlGuard mlGuard + */ + public synchronized void setMLGuard(String modelId, MLGuard mlGuard) { + log.debug("Setting ML guard {} for Model {}", mlGuard, modelId); + getExistingModelCache(modelId).setMlGuard(mlGuard); + } + + /** + * Get the current ML guard for the model. + * + * @param modelId model id + */ + public MLGuard getMLGuard(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getMlGuard(); + } + + /** + * Remove the ML guard from cache + * + * @param modelId model id + */ + public synchronized void removeMLGuard(String modelId) { + log.debug("Removing the ML guard from Model {}", modelId); + getExistingModelCache(modelId).setMlGuard(null); + } + /** * Set a quota flag to control if the model can still receive request * diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 24f7375934..72031145ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -26,6 +26,7 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.GUARDRAILS; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.RATE_LIMITER; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.USER_RATE_LIMITER_MAP; @@ -106,6 +107,8 @@ import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.Guardrails; +import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -527,6 +530,7 @@ private void indexRemoteModel( .createdTime(now) .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) + .guardrails(registerModelInput.getGuardrails()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); @@ -591,6 +595,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .createdTime(now) .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) + .guardrails(registerModelInput.getGuardrails()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { @@ -655,6 +660,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .createdTime(now) .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) + .guardrails(registerModelInput.getGuardrails()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (functionName == FunctionName.METRICS_CORRELATION) { @@ -738,6 +744,7 @@ private void registerModel( .createdTime(now) .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) + .guardrails(registerModelInput.getGuardrails()) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { @@ -991,6 +998,7 @@ public void deployModel( } setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); + setupMLGuard(modelId, mlModel.getGuardrails()); deployControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); @@ -1052,6 +1060,7 @@ public void deployModel( private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener wrappedListener) { String modelId = mlModel.getModelId(); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); + setupMLGuard(modelId, mlModel.getGuardrails()); if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1079,6 +1088,7 @@ private void setupParamsAndPredictable(String modelId, MLModel mlModel) { private Map setUpParameterMap(String modelId) { TokenBucket rateLimiter = getRateLimiter(modelId); Map userRateLimiterMap = getUserRateLimiterMap(modelId); + MLGuard mlGuard = getMLGuard(modelId); Map params = new HashMap<>(); params.put(ML_ENGINE, mlEngine); @@ -1089,21 +1099,24 @@ private Map setUpParameterMap(String modelId) { if (rateLimiter == null && userRateLimiterMap == null) { log.info("Setting up basic ML predictor parameters."); - return Collections.unmodifiableMap(params); } else if (rateLimiter != null && userRateLimiterMap == null) { params.put(RATE_LIMITER, rateLimiter); log.info("Setting up basic ML predictor parameters with model level throttling."); - return Collections.unmodifiableMap(params); } else if (rateLimiter == null) { params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); log.info("Setting up basic ML predictor parameters with user level throttling."); - return Collections.unmodifiableMap(params); } else { params.put(RATE_LIMITER, rateLimiter); params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); log.info("Setting up basic ML predictor parameters with both model and user level throttling."); - return Collections.unmodifiableMap(params); } + + if (mlGuard != null) { + params.put(GUARDRAILS, mlGuard); + log.info("Setting up ML guard parameter for ML predictor."); + } + + return Collections.unmodifiableMap(params); } private void handleDeployModelException(String modelId, FunctionName functionName, ActionListener listener, Exception e) { @@ -1125,6 +1138,7 @@ public synchronized void updateModelCache(String modelId, ActionListener int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); + setupMLGuard(modelId, mlModel.getGuardrails()); if (mlModel.getAlgorithm() == FunctionName.REMOTE) { if (mlModel.getConnector() != null) { setupParamsAndPredictable(modelId, mlModel); @@ -1421,6 +1435,29 @@ public Map getUserRateLimiterMap(String modelId) { return modelCacheHelper.getUserRateLimiterMap(modelId); } + private void setupMLGuard(String modelId, Guardrails guardrails) { + if (guardrails != null) { + modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client)); + } else { + modelCacheHelper.removeMLGuard(modelId); + } + } + + private MLGuard createMLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { + + return new MLGuard(guardrails, xContentRegistry, client); + } + + /** + * Get ML guard with model id. + * + * @param modelId model id + * @return a ML guard + */ + public MLGuard getMLGuard(String modelId) { + return modelCacheHelper.getMLGuard(modelId); + } + /** * Get model from model index. *