diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index a2c900f6cc..1113aaf59d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -28,6 +28,7 @@ public enum FunctionName { SPARSE_ENCODING, SPARSE_TOKENIZE, TEXT_SIMILARITY, + QUESTION_ANSWERING, AGENT; public static FunctionName from(String value) { @@ -42,7 +43,8 @@ public static FunctionName from(String value) { TEXT_EMBEDDING, TEXT_SIMILARITY, SPARSE_ENCODING, - SPARSE_TOKENIZE + SPARSE_TOKENIZE, + QUESTION_ANSWERING )); /** diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index c54fc1fedf..479cd09a73 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; @@ -219,6 +220,8 @@ public MLModel(StreamInput input) throws IOException { if (input.readBoolean()) { if (algorithm.equals(FunctionName.METRICS_CORRELATION)) { modelConfig = new MetricsCorrelationModelConfig(input); + } else if (algorithm.equals(FunctionName.QUESTION_ANSWERING)) { + modelConfig = new QuestionAnsweringModelConfig(input); } else { modelConfig = new TextEmbeddingModelConfig(input); } @@ -527,6 +530,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case MODEL_CONFIG_FIELD: if (FunctionName.METRICS_CORRELATION.name().equals(algorithmName)) { modelConfig = 
MetricsCorrelationModelConfig.parse(parser); + } else if (FunctionName.QUESTION_ANSWERING.name().equals(algorithmName)) { + modelConfig = QuestionAnsweringModelConfig.parse(parser); } else { modelConfig = TextEmbeddingModelConfig.parse(parser); } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java index 95ec709f0d..5432192f0a 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java @@ -13,5 +13,6 @@ public enum MLInputDataType { DATA_FRAME, TEXT_DOCS, REMOTE, - TEXT_SIMILARITY + TEXT_SIMILARITY, + QUESTION_ANSWERING } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java new file mode 100644 index 0000000000..204d7df149 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.dataset; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.FieldDefaults; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.annotation.InputDataSet; + +import java.io.IOException; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@InputDataSet(MLInputDataType.QUESTION_ANSWERING) +public class QuestionAnsweringInputDataSet extends MLInputDataset { + + String question; + + String context; + + @Builder(toBuilder = true) + public QuestionAnsweringInputDataSet(String question, String context) { + super(MLInputDataType.QUESTION_ANSWERING); + if(question == null) { + throw new 
IllegalArgumentException("Question is not provided"); + } + if(context == null) { + throw new IllegalArgumentException("Context is not provided"); + } + this.question = question; + this.context = context; + } + + public QuestionAnsweringInputDataSet(StreamInput in) throws IOException { + super(MLInputDataType.QUESTION_ANSWERING); + this.question = in.readString(); + this.context = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(question); + out.writeString(context); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index f2d74bf8c9..ed28cdfc1f 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -17,6 +17,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -63,6 +64,12 @@ public class MLInput implements Input { public static final String QUERY_TEXT_FIELD = "query_text"; public static final String PARAMETERS_FIELD = "parameters"; + // Input question in question answering model + public static final String QUESTION_FIELD = "question"; + + // Input context in question answering model + public static final String CONTEXT_FIELD = "context"; + // Algorithm name protected FunctionName algorithm; // ML algorithm parameters @@ -178,6 +185,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endArray(); } break; + case QUESTION_ANSWERING: + 
QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) this.inputDataset; + String question = qaInputDataSet.getQuestion(); + String context = qaInputDataSet.getContext(); + builder.field(QUESTION_FIELD, question); + builder.field(CONTEXT_FIELD, context); + break; case REMOTE: RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset; Map parameters = remoteInferenceInputDataSet.getParameters(); @@ -213,6 +227,8 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws List targetResponsePositions = new ArrayList<>(); List textDocs = new ArrayList<>(); String queryText = null; + String question = null; + String context = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -263,6 +279,12 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws case QUERY_TEXT_FIELD: queryText = parser.text(); break; + case QUESTION_FIELD: + question = parser.text(); + break; + case CONTEXT_FIELD: + context = parser.text(); + break; default: parser.skipChildren(); break; @@ -272,9 +294,10 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) { ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions); inputDataSet = new TextDocsInputDataSet(textDocs, filter); - } - if (algorithm == FunctionName.TEXT_SIMILARITY) { + } else if (algorithm == FunctionName.TEXT_SIMILARITY) { inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs); + } else if (algorithm == FunctionName.QUESTION_ANSWERING) { + inputDataSet = new QuestionAnsweringInputDataSet(question, context); } return new MLInput(algorithm, mlParameters, searchSourceBuilder, 
sourceIndices, dataFrame, inputDataSet); } diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java new file mode 100644 index 0000000000..2b69d2c345 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.input.nlp; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +/** + * MLInput which supports a question answering algorithm + * Inputs are question and context. 
Output is the answer + */ +@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.QUESTION_ANSWERING}) +public class QuestionAnsweringMLInput extends MLInput { + + public QuestionAnsweringMLInput(FunctionName algorithm, MLInputDataset dataset) { + super(algorithm, null, dataset); + } + + public QuestionAnsweringMLInput(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ALGORITHM_FIELD, algorithm.name()); + if(parameters != null) { + builder.field(ML_PARAMETERS_FIELD, parameters); + } + if(inputDataset != null) { + QuestionAnsweringInputDataSet ds = (QuestionAnsweringInputDataSet) this.inputDataset; + String question = ds.getQuestion(); + String context = ds.getContext(); + builder.field(QUESTION_FIELD, question); + builder.field(CONTEXT_FIELD, context); + } + builder.endObject(); + return builder; + } + + public QuestionAnsweringMLInput(XContentParser parser, FunctionName functionName) throws IOException { + super(); + this.algorithm = functionName; + String question = null; + String context = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case QUESTION_FIELD: + question = parser.text(); + break; + case CONTEXT_FIELD: + context = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + if(question == null) { + throw new IllegalArgumentException("Question is not provided"); + } + if(context == null) { + throw new IllegalArgumentException("Context is not provided"); + } + inputDataset = new QuestionAnsweringInputDataSet(question, context); + } + +} diff --git 
a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java index 0c4d9f9a7b..e86fe1df22 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java @@ -99,6 +99,7 @@ public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) t break; case QUERY_TEXT_FIELD: queryText = parser.text(); + break; default: parser.skipChildren(); break; diff --git a/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java new file mode 100644 index 0000000000..7b01f847a2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; + +import java.io.IOException; +import java.util.Locale; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Setter +@Getter +public class QuestionAnsweringModelConfig extends MLModelConfig { + public static final String PARSE_FIELD_NAME = FunctionName.QUESTION_ANSWERING.name(); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + QuestionAnsweringModelConfig.class, + new 
ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; + public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; + public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; + + private final FrameworkType frameworkType; + private final boolean normalizeResult; + private final Integer modelMaxLength; + + @Builder(toBuilder = true) + public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, boolean normalizeResult, Integer modelMaxLength) { + super(modelType, allConfig); + if (frameworkType == null) { + throw new IllegalArgumentException("framework type is null"); + } + this.frameworkType = frameworkType; + this.normalizeResult = normalizeResult; + this.modelMaxLength = modelMaxLength; + } + + public static QuestionAnsweringModelConfig parse(XContentParser parser) throws IOException { + String modelType = null; + FrameworkType frameworkType = null; + String allConfig = null; + boolean normalizeResult = false; + Integer modelMaxLength = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_TYPE_FIELD: + modelType = parser.text(); + break; + case FRAMEWORK_TYPE_FIELD: + frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); + break; + case ALL_CONFIG_FIELD: + allConfig = parser.text(); + break; + case NORMALIZE_RESULT_FIELD: + normalizeResult = parser.booleanValue(); + break; + case MODEL_MAX_LENGTH_FIELD: + modelMaxLength = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return new QuestionAnsweringModelConfig(modelType, frameworkType, allConfig, normalizeResult, modelMaxLength); + } + + @Override + public String getWriteableName() { + return PARSE_FIELD_NAME; + } + 
+ public QuestionAnsweringModelConfig(StreamInput in) throws IOException{ + super(in); + frameworkType = in.readEnum(FrameworkType.class); + normalizeResult = in.readBoolean(); + modelMaxLength = in.readOptionalInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(frameworkType); + out.writeBoolean(normalizeResult); + out.writeOptionalInt(modelMaxLength); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelType != null) { + builder.field(MODEL_TYPE_FIELD, modelType); + } + if (frameworkType != null) { + builder.field(FRAMEWORK_TYPE_FIELD, frameworkType); + } + if (allConfig != null) { + builder.field(ALL_CONFIG_FIELD, allConfig); + } + if (modelMaxLength != null) { + builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); + } + if (normalizeResult) { + builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); + } + builder.endObject(); + return builder; + } + public enum FrameworkType { + HUGGINGFACE_TRANSFORMERS, + SENTENCE_TRANSFORMERS, + HUGGINGFACE_TRANSFORMERS_NEURON; + + public static FrameworkType from(String value) { + try { + return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong framework type", e); // same message, cause preserved (null or unknown value) + } + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index d90294536c..2239f1b8d4 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -63,6 +63,11 @@ public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType da this.dataAsMap = dataAsMap; } + public ModelTensor(String name, String result) { + this.name = name; + this.result = result; + } + @Override public 
XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 1c1ea40c6b..4b3b3cfb0f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -23,6 +23,7 @@ import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; +import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; @@ -166,6 +167,8 @@ public MLRegisterModelInput(StreamInput in) throws IOException { if (in.readBoolean()) { if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) { this.modelConfig = new MetricsCorrelationModelConfig(in); + } else if (this.functionName.equals(FunctionName.QUESTION_ANSWERING)) { + this.modelConfig = new QuestionAnsweringModelConfig(in); } else { this.modelConfig = new TextEmbeddingModelConfig(in); } @@ -382,7 +385,11 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName modelFormat = MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT)); break; case MODEL_CONFIG_FIELD: - modelConfig = TextEmbeddingModelConfig.parse(parser); + if (FunctionName.QUESTION_ANSWERING.equals(functionName)) { + modelConfig = QuestionAnsweringModelConfig.parse(parser); + } else { + modelConfig = TextEmbeddingModelConfig.parse(parser); + } break; case CONNECTOR_FIELD: connector = createConnector(parser); @@ -493,7 +500,11 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo modelFormat = 
MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT)); break; case MODEL_CONFIG_FIELD: - modelConfig = TextEmbeddingModelConfig.parse(parser); + if (FunctionName.QUESTION_ANSWERING.equals(functionName)) { + modelConfig = QuestionAnsweringModelConfig.parse(parser); + } else { + modelConfig = TextEmbeddingModelConfig.parse(parser); + } break; case MODEL_NODE_IDS_FIELD: ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index 0b54a26f5c..e7ab3b7091 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; @@ -144,7 +145,11 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException { this.modelContentSizeInBytes = in.readOptionalLong(); this.modelContentHashValue = in.readString(); if (in.readBoolean()) { - modelConfig = new TextEmbeddingModelConfig(in); + if (this.functionName.equals(FunctionName.QUESTION_ANSWERING)) { + this.modelConfig = new QuestionAnsweringModelConfig(in); + } else { + this.modelConfig = new TextEmbeddingModelConfig(in); + } } this.totalChunks = in.readInt(); this.backendRoles = in.readOptionalStringList(); @@ -329,7 +334,11 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc modelContentHashValue = parser.text(); break; case MODEL_CONFIG_FIELD: - modelConfig = 
TextEmbeddingModelConfig.parse(parser); + if (FunctionName.QUESTION_ANSWERING.equals(functionName)) { + modelConfig = QuestionAnsweringModelConfig.parse(parser); + } else { + modelConfig = TextEmbeddingModelConfig.parse(parser); + } break; case TOTAL_CHUNKS_FIELD: totalChunks = parser.intValue(false); diff --git a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java index 533b525cfa..e25bf1fad3 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java @@ -171,8 +171,8 @@ private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws I @Test public void testClassLoader_MLInput() throws IOException { testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING); - testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE); testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING); + testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE); } @Test(expected = IllegalArgumentException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java new file mode 100644 index 0000000000..f332f18db5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.dataset; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertThrows; + +public class QuestionAnsweringInputDatasetTest { + + @Test + public void testStreaming() throws IOException { + String question = "What color is apple"; + String context = "I like Apples. They are red"; + QuestionAnsweringInputDataSet dataset = QuestionAnsweringInputDataSet.builder().question(question).context(context).build(); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + dataset.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + QuestionAnsweringInputDataSet newDs = (QuestionAnsweringInputDataSet) MLInputDataset.fromStream(in); + assert (question.equals(newDs.getQuestion())); // check the deserialized copy, not the local input + assert (context.equals(newDs.getContext())); // check the deserialized copy, not the local input + } + + @Test + public void noContext_ThenFail() { + String question = "What color is apple"; + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> QuestionAnsweringInputDataSet.builder().question(question).build()); + assert (e.getMessage().equals("Context is not provided")); + } + + @Test + public void noQuestion_ThenFail() { + String context = "I like Apples. 
They are red"; + assertThrows(IllegalArgumentException.class, + () -> QuestionAnsweringInputDataSet.builder().context(context).build()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java new file mode 100644 index 0000000000..dd91f4023f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.input.nlp; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertThrows; + +public class QuestionAnsweringMLInputTest { + + MLInput input; + + private final FunctionName algorithm = FunctionName.QUESTION_ANSWERING; + + @Before + 
public void setup() { + String question = "What color is apple"; + String context = "I like Apples. They are red"; + MLInputDataset dataset = QuestionAnsweringInputDataSet.builder().question(question).context(context).build(); + input = new QuestionAnsweringMLInput(algorithm, dataset); + } + + @Test + public void testXContent_IsInternallyConsistent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); + assert (parsedInput instanceof QuestionAnsweringMLInput); + QuestionAnsweringMLInput parsedQAMLI = (QuestionAnsweringMLInput) parsedInput; + String question = ((QuestionAnsweringInputDataSet) parsedQAMLI.getInputDataset()).getQuestion(); + String context = ((QuestionAnsweringInputDataSet) parsedQAMLI.getInputDataset()).getContext(); + assert (question.equals("What color is apple")); + assert (context.equals("I like Apples. They are red")); + } + + @Test + public void testXContent_String() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assert (jsonStr.equals("{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. They are red\"}")); + } + + @Test + public void testParseJson() throws IOException { + String json = "{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. 
They are red\"}"; + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, json); + parser.nextToken(); + + MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); + assert (parsedInput instanceof QuestionAnsweringMLInput); + QuestionAnsweringMLInput parsedQAMLI = (QuestionAnsweringMLInput) parsedInput; + String question = ((QuestionAnsweringInputDataSet) parsedQAMLI.getInputDataset()).getQuestion(); + String context = ((QuestionAnsweringInputDataSet) parsedQAMLI.getInputDataset()).getContext(); + assert (question.equals("What color is apple")); + assert (context.equals("I like Apples. They are red")); + } + + @Test + public void testStreaming() throws IOException { + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + input.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + QuestionAnsweringMLInput newInput = new QuestionAnsweringMLInput(in); + String newQuestion = ((QuestionAnsweringInputDataSet) newInput.getInputDataset()).getQuestion(); + String oldQuestion = ((QuestionAnsweringInputDataSet) input.getInputDataset()).getQuestion(); + assert (newQuestion.equals(oldQuestion)); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java new file mode 100644 index 0000000000..5136c187b7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; 
+import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class QuestionAnsweringModelConfigTests { + + QuestionAnsweringModelConfig config; + Function function; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + config = QuestionAnsweringModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .normalizeResult(false) + .frameworkType(QuestionAnsweringModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .build(); + function = parser -> { + try { + return QuestionAnsweringModelConfig.parse(parser); + } catch (IOException e) { + throw new RuntimeException("Failed to parse QuestionAnsweringModelConfig", e); + } + }; + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, EMPTY_PARAMS); + String configContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + } + + @Test + public void nullFields_ModelType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model type is null"); + config = QuestionAnsweringModelConfig.builder() + .build(); + } + + @Test + public void nullFields_FrameworkType() { + exceptionRule.expect(IllegalArgumentException.class); + 
exceptionRule.expectMessage("framework type is null"); + config = QuestionAnsweringModelConfig.builder() + .modelType("testModelType") + .build(); + } + + @Test + public void parse() throws IOException { + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"normalize_result\":false,\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + TestHelper.testParseFromString(config, content, function); + } + + @Test + public void frameworkType_wrongValue() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Wrong framework type"); + QuestionAnsweringModelConfig.FrameworkType.from("test_wrong_value"); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(config); + } + + public void readInputStream(QuestionAnsweringModelConfig config) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + config.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + QuestionAnsweringModelConfig parsedConfig = new QuestionAnsweringModelConfig(streamInput); + assertEquals(config.getModelType(), parsedConfig.getModelType()); + assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); + assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); + assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 64867554cb..50c514599e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -29,6 +29,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; import 
org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -62,6 +63,7 @@ public void downloadPrebuiltModelConfig( ActionListener listener ) { String modelName = registerModelInput.getModelName(); + FunctionName algorithm = registerModelInput.getFunctionName(); String version = registerModelInput.getVersion(); MLModelFormat modelFormat = registerModelInput.getModelFormat(); Boolean isHidden = registerModelInput.getIsHidden(); @@ -90,6 +92,10 @@ public void downloadPrebuiltModelConfig( MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder(); + String functionName = config.containsKey("function_name") + ? (String) config.get("function_name") + : (String) config.get("model_task_type"); + builder .modelName(modelName) .version(version) @@ -98,7 +104,7 @@ public void downloadPrebuiltModelConfig( .modelNodeIds(modelNodeIds) .isHidden(isHidden) .modelGroupId(modelGroupId) - .functionName(FunctionName.from((String) config.get("model_task_type"))); + .functionName(FunctionName.from((functionName))); config.entrySet().forEach(entry -> { switch (entry.getKey().toString()) { @@ -106,47 +112,74 @@ public void downloadPrebuiltModelConfig( builder.modelFormat(MLModelFormat.from(entry.getValue().toString())); break; case MLRegisterModelInput.MODEL_CONFIG_FIELD: - TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder configBuilder = TextEmbeddingModelConfig.builder(); - Map configMap = (Map) entry.getValue(); - for (Map.Entry configEntry : configMap.entrySet()) { - switch (configEntry.getKey().toString()) { - case MLModelConfig.MODEL_TYPE_FIELD: - configBuilder.modelType(configEntry.getValue().toString()); - break; - case MLModelConfig.ALL_CONFIG_FIELD: - configBuilder.allConfig(configEntry.getValue().toString()); - break; - case 
TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD: - configBuilder.embeddingDimension(((Double) configEntry.getValue()).intValue()); - break; - case TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD: - configBuilder - .frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString())); - break; - case TextEmbeddingModelConfig.POOLING_MODE_FIELD: - configBuilder - .poolingMode( - TextEmbeddingModelConfig.PoolingMode - .from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)) - ); - break; - case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD: - configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString())); - break; - case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD: - configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue()); - break; - case TextEmbeddingModelConfig.QUERY_PREFIX: - configBuilder.queryPrefix(configEntry.getValue().toString()); - break; - case TextEmbeddingModelConfig.PASSAGE_PREFIX: - configBuilder.passagePrefix(configEntry.getValue().toString()); - break; - default: - break; + if (FunctionName.QUESTION_ANSWERING.equals(algorithm)) { + QuestionAnsweringModelConfig.QuestionAnsweringModelConfigBuilder configBuilder = + QuestionAnsweringModelConfig.builder(); + Map configMap = (Map) entry.getValue(); + for (Map.Entry configEntry : configMap.entrySet()) { + switch (configEntry.getKey().toString()) { + case MLModelConfig.MODEL_TYPE_FIELD: + configBuilder.modelType(configEntry.getValue().toString()); + break; + case MLModelConfig.ALL_CONFIG_FIELD: + configBuilder.allConfig(configEntry.getValue().toString()); + break; + case QuestionAnsweringModelConfig.FRAMEWORK_TYPE_FIELD: + configBuilder + .frameworkType( + QuestionAnsweringModelConfig.FrameworkType.from(configEntry.getValue().toString()) + ); + break; + default: + break; + } + } + builder.modelConfig(configBuilder.build()); + } else { + TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder configBuilder = 
TextEmbeddingModelConfig.builder(); + Map configMap = (Map) entry.getValue(); + for (Map.Entry configEntry : configMap.entrySet()) { + switch (configEntry.getKey().toString()) { + case MLModelConfig.MODEL_TYPE_FIELD: + configBuilder.modelType(configEntry.getValue().toString()); + break; + case MLModelConfig.ALL_CONFIG_FIELD: + configBuilder.allConfig(configEntry.getValue().toString()); + break; + case TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD: + configBuilder.embeddingDimension(((Double) configEntry.getValue()).intValue()); + break; + case TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD: + configBuilder + .frameworkType( + TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()) + ); + break; + case TextEmbeddingModelConfig.POOLING_MODE_FIELD: + configBuilder + .poolingMode( + TextEmbeddingModelConfig.PoolingMode + .from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)) + ); + break; + case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD: + configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString())); + break; + case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD: + configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue()); + break; + case TextEmbeddingModelConfig.QUERY_PREFIX: + configBuilder.queryPrefix(configEntry.getValue().toString()); + break; + case TextEmbeddingModelConfig.PASSAGE_PREFIX: + configBuilder.passagePrefix(configEntry.getValue().toString()); + break; + default: + break; + } } + builder.modelConfig(configBuilder.build()); } - builder.modelConfig(configBuilder.build()); break; case MLRegisterModelInput.MODEL_CONTENT_HASH_VALUE_FIELD: builder.hashValue(entry.getValue().toString()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java new file mode 100644 index 0000000000..80d9f1a61d --- 
/dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.question_answering; + +import static org.opensearch.ml.engine.ModelHelper.*; + +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.algorithms.DLModel; +import org.opensearch.ml.engine.annotation.Function; + +import ai.djl.inference.Predictor; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@Function(FunctionName.QUESTION_ANSWERING) +public class QuestionAnsweringModel extends DLModel { + + @Override + public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException { + String question = "How is the weather?"; + String context = "The weather is nice, it is beautiful day."; + Input input = new Input(); + input.add(question); + input.add(context); + + // First request takes longer time. Predict once to warm up model. 
+ predictor.predict(input); + } + + @Override + public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { + MLInputDataset inputDataSet = mlInput.getInputDataset(); + List tensorOutputs = new ArrayList<>(); + Output output; + QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) inputDataSet; + String question = qaInputDataSet.getQuestion(); + String context = qaInputDataSet.getContext(); + Input input = new Input(); + input.add(question); + input.add(context); + output = getPredictor().predict(input); + tensorOutputs.add(parseModelTensorOutput(output, null)); + return new ModelTensorOutput(tensorOutputs); + } + + @Override + public Translator getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException { + return new QuestionAnsweringTranslator(); + } + + @Override + public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) { + return null; + } + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringTranslator.java new file mode 100644 index 0000000000..cd3684d717 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringTranslator.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.question_answering; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import 
ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.TranslatorContext; + +public class QuestionAnsweringTranslator extends SentenceTransformerTranslator { + private List tokens; + + @Override + public NDList processInput(TranslatorContext ctx, Input input) { + NDManager manager = ctx.getNDManager(); + String question = input.getAsString(0); + String context = input.getAsString(1); + NDList ndList = new NDList(); + + Encoding encodings = tokenizer.encode(question, context); + tokens = Arrays.asList(encodings.getTokens()); + ctx.setAttachment("encoding", encodings); + long[] indices = encodings.getIds(); + long[] attentionMask = encodings.getAttentionMask(); + + NDArray indicesArray = manager.create(indices); + indicesArray.setName("input_ids"); + + NDArray attentionMaskArray = manager.create(attentionMask); + attentionMaskArray.setName("attention_mask"); + + ndList.add(indicesArray); + ndList.add(attentionMaskArray); + return ndList; + } + + @Override + public Output processOutput(TranslatorContext ctx, NDList list) { + Output output = new Output(200, "OK"); + + List outputs = new ArrayList<>(); + + NDArray startLogits = list.get(0); + NDArray endLogits = list.get(1); + int startIdx = (int) startLogits.argMax().getLong(); + int endIdx = (int) endLogits.argMax().getLong(); + if (startIdx >= endIdx) { + int tmp = startIdx; + startIdx = endIdx; + endIdx = tmp; + } + String answer = tokenizer.buildSentence(tokens.subList(startIdx, endIdx + 1)); + + outputs.add(new ModelTensor(null, answer)); + + ModelTensors modelTensorOutput = new ModelTensors(outputs); + output.add(modelTensorOutput.toBytes()); + return output; + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java new file mode 100644 index 
0000000000..0999ef9a9e --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java @@ -0,0 +1,285 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.engine.algorithms.question_answering; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.DLModel.*; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.engine.utils.FileUtils; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import 
ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.TranslatorContext; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class QuestionAnsweringModelTest { + + private File modelZipFile; + private MLModel model; + private ModelHelper modelHelper; + private Map params; + private QuestionAnsweringModel questionAnsweringModel; + private Path mlCachePath; + private QuestionAnsweringInputDataSet inputDataSet; + private MLEngine mlEngine; + private Encryptor encryptor; + + @Before + public void setUp() throws URISyntaxException { + mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + mlEngine = new MLEngine(mlCachePath, encryptor); + model = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.QUESTION_ANSWERING) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); + modelHelper = new ModelHelper(mlEngine); + params = new HashMap<>(); + modelZipFile = new File(getClass().getResource("question_answering_pt.zip").toURI()); + params.put(MODEL_ZIP_FILE, modelZipFile); + params.put(MODEL_HELPER, modelHelper); + params.put(ML_ENGINE, mlEngine); + questionAnsweringModel = new QuestionAnsweringModel(); + + inputDataSet = QuestionAnsweringInputDataSet.builder().question("What color is apple").context("Apples are red").build(); + } + + @Test + public void test_QuestionAnswering_ProcessInput_ProcessOutput() throws URISyntaxException, IOException { + QuestionAnsweringTranslator questionAnsweringTranslator = new QuestionAnsweringTranslator(); + TranslatorContext translatorContext = mock(TranslatorContext.class); + Model mlModel = mock(Model.class); + when(translatorContext.getModel()).thenReturn(mlModel); + 
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent()); + questionAnsweringTranslator.prepare(translatorContext); + + NDManager manager = mock(NDManager.class); + when(translatorContext.getNDManager()).thenReturn(manager); + Input input = mock(Input.class); + String question = "What color is apple"; + String context = "Apples are red"; + when(input.getAsString(0)).thenReturn(question); + when(input.getAsString(1)).thenReturn(context); + NDArray indiceNdArray = mock(NDArray.class); + when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l }); + when(manager.create((long[]) any())).thenReturn(indiceNdArray); + doNothing().when(indiceNdArray).setName(any()); + NDList outputList = questionAnsweringTranslator.processInput(translatorContext, input); + assertEquals(2, outputList.size()); + Iterator iterator = outputList.iterator(); + while (iterator.hasNext()) { + NDArray ndArray = iterator.next(); + long[] output = ndArray.toLongArray(); + assertEquals(2, output.length); + } + + NDArray startLogits = mock(NDArray.class); + NDArray endLogits = mock(NDArray.class); + when(startLogits.argMax()).thenReturn(startLogits); + when(startLogits.getLong()).thenReturn(3L); + when(endLogits.argMax()).thenReturn(endLogits); + when(endLogits.getLong()).thenReturn(7L); + + List ndArrayList = new ArrayList<>(); + ndArrayList.add(startLogits); + ndArrayList.add(endLogits); + NDList ndList = new NDList(ndArrayList); + + Output output = questionAnsweringTranslator.processOutput(translatorContext, ndList); + assertNotNull(output); + byte[] bytes = output.getData().getAsBytes(); + ModelTensors tensorOutput = ModelTensors.fromBytes(bytes); + List modelTensorsList = tensorOutput.getMlModelTensors(); + assertEquals(1, modelTensorsList.size()); + } + + @Test + public void initModel_predict_TorchScript_QuestionAnswering() throws URISyntaxException { + questionAnsweringModel.initModel(model, params, encryptor); + 
MLInput mlInput = MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) questionAnsweringModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(1, mlModelOutputs.size()); + for (int i = 0; i < mlModelOutputs.size(); i++) { + ModelTensors tensors = mlModelOutputs.get(i); + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + } + questionAnsweringModel.close(); + } + + // ONNX is working fine but the model is too big to upload to git. Trying to find small models @Test + @Test + public void initModel_predict_ONNX_QuestionAnswering() throws URISyntaxException { + model = MLModel + .builder() + .modelFormat(MLModelFormat.ONNX) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.TEXT_SIMILARITY) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); + modelZipFile = new File(getClass().getResource("question_answering_onnx.zip").toURI()); + params.put(MODEL_ZIP_FILE, modelZipFile); + + questionAnsweringModel.initModel(model, params, encryptor); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) questionAnsweringModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(1, mlModelOutputs.size()); + for (int i = 1; i < mlModelOutputs.size(); i++) { + ModelTensors tensors = mlModelOutputs.get(i); + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + } + questionAnsweringModel.close(); + } + + @Test + public void initModel_NullModelHelper() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("question_answering_pt.zip").toURI())); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> 
questionAnsweringModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model helper is null")); + } + + @Test + public void initModel_NullMLEngine() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("question_answering_pt.zip").toURI())); + params.put(MODEL_HELPER, modelHelper); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("ML engine is null")); + } + + @Test + public void initModel_NullModelId() { + model.setModelId(null); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model id is null")); + } + + @Test + public void initModel_WrongModelFile() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("../text_embedding/wrong_zip_with_2_pt_file.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + MLException e = assertThrows(MLException.class, () -> questionAnsweringModel.initModel(model, params, encryptor)); + Throwable rootCause = e.getCause(); + assert (rootCause instanceof IllegalArgumentException); + assert (rootCause.getMessage().equals("found multiple models")); + } + + @Test + public void initModel_WrongFunctionName() { + MLModel mlModel = model.toBuilder().algorithm(FunctionName.KMEANS).build(); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel.initModel(mlModel, params, encryptor) + ); + assert (e.getMessage().equals("wrong function name")); + } + + @Test + public void predict_NullModelHelper() { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel + 
.predict(MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build()) + ); + assert (e.getMessage().equals("model not deployed")); + } + + @Test + public void predict_NullModelId() { + model.setModelId(null); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel.initModel(model, params, encryptor) + ); + assert (e.getMessage().equals("model id is null")); + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> questionAnsweringModel + .predict(MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build()) + ); + assert (e2.getMessage().equals("model not deployed")); + } + + @Test + public void predict_AfterModelClosed() { + questionAnsweringModel.initModel(model, params, encryptor); + questionAnsweringModel.close(); + MLException e = assertThrows( + MLException.class, + () -> questionAnsweringModel + .predict(MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build()) + ); + log.info(e.getMessage()); + assert (e.getMessage().startsWith("Failed to inference QUESTION_ANSWERING")); + } + + @After + public void tearDown() { + FileUtils.deleteFileQuietly(mlCachePath); + } + +} diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_onnx.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_onnx.zip new file mode 100644 index 0000000000..59e2b9eb41 Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_onnx.zip differ diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_pt.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_pt.zip new file mode 100644 
index 0000000000..38cdde960d Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/question_answering/question_answering_pt.zip differ