Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different embedding types of model response #1007

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
### Enhancements
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public class VectorUtil {
* @param vectorAsList {@link List} of {@link Number} values (for example {@link Float}) representing the vector
* @return array of floats produced from input list
*/
// NOTE(review): this is a diff hunk, so the signature and the element assignment each appear twice.
// The List<Float> lines look like the pre-change text and the List<Number> lines like the
// replacement; confirm against the pull request before treating this as compilable code.
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
// One primitive slot per element of the input list.
float[] vector = new float[vectorAsList.size()];
for (int i = 0; i < vectorAsList.size(); i++) {
vector[i] = vectorAsList.get(i);
// Number#floatValue() converts whichever boxed numeric type the list holds to a primitive float;
// for wider types such as Double this is a narrowing conversion and may round.
vector[i] = vectorAsList.get(i).floatValue();
}
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class MLCommonsClientAccessor {
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
if (response.size() != 1) {
Expand Down Expand Up @@ -82,7 +82,7 @@ public void inferenceSentence(
// Convenience overload: forwards to the inferenceSentences overload that takes an explicit
// response-filter list, passing TARGET_RESPONSE_FILTERS as that list.
public void inferenceSentences(
@NonNull final String modelId,
@NonNull final List<String> inputText,
// NOTE(review): diff pair — the listener's element type is Float on the first line and Number on
// the second; only one of the two lines exists in any real version of the file.
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
}
Expand All @@ -103,7 +103,7 @@ public void inferenceSentences(
@NonNull final List<String> targetResponseFilters,
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}
Expand All @@ -128,7 +128,7 @@ public void inferenceSentencesWithMapResult(
// Overload for map-shaped input (Map<String, String> inputObjects) that reports a single vector to
// the listener rather than a list of vectors.
public void inferenceSentences(
@NonNull final String modelId,
@NonNull final Map<String, String> inputObjects,
// NOTE(review): diff pair — listener element type is Float on the first line, Number on the second.
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {
// Delegates with TARGET_RESPONSE_FILTERS and 0 as the retryTime argument (i.e. the first attempt).
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}
Expand Down Expand Up @@ -177,11 +177,11 @@ private void retryableInferenceSentencesWithVectorResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
final ActionListener<List<List<Number>>> listener
) {
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
Expand All @@ -202,7 +202,8 @@ private void retryableInferenceSimilarityWithVectorResult(
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
final List<List<Float>> tensors = buildVectorFromResponse(mlOutput);
final List<Float> scores = tensors.stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
Expand All @@ -224,14 +225,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<T>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
}
}
return vector;
Expand All @@ -255,8 +256,8 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return resultMaps;
}

// Returns the first vector produced by buildVectorFromResponse, or a new empty list when that
// result contains no vectors.
// NOTE(review): diff pair — the first signature/local pair is typed with Float, the second with a
// method-level type parameter T extends Number.
// NOTE(review): in the generic variant T is chosen by the caller and erased at runtime, and the
// elements come from an unchecked (T) cast inside buildVectorFromResponse. A caller that binds T to
// a concrete type such as Float would presumably hit a ClassCastException on element access if the
// model returned another Number subtype — confirm callers only read elements as Number.
private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

Expand All @@ -265,11 +266,11 @@ private void retryableInferenceSentencesWithSingleVectorResult(
final String modelId,
final Map<String, String> inputObjects,
final int retryTime,
final ActionListener<List<Float>> listener
final ActionListener<List<Number>> listener
) {
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest

}

private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
Expand Down Expand Up @@ -164,7 +164,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
Map<String, Object> result = new LinkedHashMap<>();
result.put(knnKey, modelTensorList);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
public class VectorUtilTests extends OpenSearchTestCase {

// Verifies VectorUtil.vectorAsListToArray for a three-element list and for an empty list.
// NOTE(review): diff hunk — each List<Float> line is immediately followed by its List<Number>
// counterpart, so the local declarations and one assertion appear twice.
public void testVectorAsListToArray() {
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

// The array must have the same length as the list, and each element must match with a delta of 0.0f.
assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
// With Number elements the expected value is converted explicitly via floatValue().
assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
}

// An empty input list must yield a zero-length array.
List<Float> vectorAsList_withNoElements = Collections.emptyList();
List<Number> vectorAsList_withNoElements = Collections.emptyList();
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
assertEquals(0, vectorAsArray_withNoElements.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
private ActionListener<List<List<Float>>> resultListener;
private ActionListener<List<List<Number>>> resultListener;

@Mock
private ActionListener<List<Float>> singleSentenceResultListener;
private ActionListener<List<Number>> singleSentenceResultListener;

@Mock
private ActionListener<List<Float>> similarityResultListener;

@Mock
private MachineLearningNodeClient client;
Expand All @@ -53,7 +56,7 @@ public void setup() {
}

public void testInferenceSentence_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand All @@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand All @@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
}

public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand Down Expand Up @@ -278,7 +281,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
}

public void testInferenceMultimodal_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand Down Expand Up @@ -337,13 +340,13 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Expand All @@ -358,13 +361,13 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
Expand All @@ -382,12 +385,12 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,10 @@ public void testHashAndEquals() {
@SneakyThrows
public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K);
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(2);
ActionListener<List<Number>> listener = invocation.getArgument(2);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
Expand Down Expand Up @@ -682,10 +682,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
.queryImage(IMAGE_TEXT)
.modelId(MODEL_ID)
.k(K);
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(2);
ActionListener<List<Number>> listener = invocation.getArgument(2);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ protected float[] runInference(final String modelId, final String queryText) {
List<Object> output = (List<Object>) result.get("output");
assertEquals(1, output.size());
Map<String, Object> map = (Map<String, Object>) output.get(0);
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
return vectorAsListToArray(data);
}

Expand Down
Loading