diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java index 28bf7866f..bd13b78b1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java @@ -23,6 +23,7 @@ import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap; import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.isNumeric; import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource; import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource; import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria; @@ -113,7 +114,6 @@ public void rescoreSearchResponse( final ActionListener> listener ) { SearchHit[] searchHits = response.getHits().getHits(); - SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator; if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) { @@ -162,26 +162,41 @@ public void rescoreSearchResponse( */ public void byFieldSearchHitValidator(final SearchHit hit) { if (!hit.hasSource()) { - log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())); + log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%s]", hit.getId())); throw new IllegalArgumentException( - String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()) + String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%s]", hit.getId()) ); } Map sourceMap = hit.getSourceAsMap(); if 
(!mappingExistsInSource(sourceMap, targetField)) { - log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId())); + log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%s]", targetField, hit.getId())); - throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId())); + throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%s]", hit.getId())); } Optional val = getValueFromSource(sourceMap, targetField); - if (!(val.get() instanceof Number)) { - log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null))); + if (!(isNumeric(val.get()))) { + // Report the value's simple class name (no package prefix); the mapping check above guarantees a value is present, so no null check is needed + String typeOfMapping = val.get().getClass().getSimpleName(); + log.error( + String.format( + Locale.ROOT, + "The field mapping to rerank [%s: %s] is not Numerical, instead of type [%s]", + targetField, + val.orElse(null), + typeOfMapping + ) + ); throw new IllegalArgumentException( - String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null)) + String.format( + Locale.ROOT, + "The field mapping to rerank by [%s] is not Numerical, instead of type [%s]", + val.orElse(null), + typeOfMapping + ) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java index a6a377843..d799f323f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -47,6 +47,7 @@ public interface SearchHitValidator { * for each SearchHit follows the correct form as specified by the validator. 
* When just one of the conditions fail (as specified by the validator) the exception will be thrown to the listener. * @param searchHits from the SearchResponse + * @param validator the validator applied to every search hit to check that it is in the correct form * @param listener returns an error to the listener in case on of the conditions fail * @return The status indicating that the SearchHits are in correct form to perform the Rerank */ @@ -77,6 +78,9 @@ public static boolean validateRerankCriteria( */ public static float getScoreFromSourceMap(final Map sourceAsMap, final String targetField) { Object val = getValueFromSource(sourceAsMap, targetField).get(); + if (val instanceof String) { + return Float.parseFloat((String) val); + } return ((Number) val).floatValue(); } @@ -180,4 +184,29 @@ public static boolean mappingExistsInSource(final Map sourceAsMa return getValueFromSource(sourceAsMap, pathToValue).isPresent(); } + /** + * @param value the value to check; may be of any type, including null + * @return whether the value can be turned into a number + */ + public static boolean isNumeric(Object value) { + if (value == null) { + return false; + } + + if (value instanceof Number) { + return true; + } + + if (value instanceof String) { + String string = (String) value; + try { + Double.parseDouble(string); + return true; + } catch (NumberFormatException e) { + return false; + } + } + + return false; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java index 31dda262f..f38d5834e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java @@ -757,6 +757,90 @@ public void testRerank_keepsTargetFieldAndHasNoPreviousScore_WhenByFieldHasDefau } } + public void 
testRerank_reranksHits_WhenTargetFieldIsNumericalString() throws IOException { + String targetField = "ml.info.score"; + setUpValidSearchResultsWithNestedTargetValueWithNumericalString(); + List> sortedScoresDescending = sampleIndexMLScorePairs.stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .toList(); + + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field and numerical string", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + assertEquals(sortedScoresDescending.getFirst().getValue(), searchResponse.getHits().getMaxScore(), 0.0001); + + for (int i = 0; i < sortedScoresDescending.size(); i++) { + int docId = sortedScoresDescending.get(i).getKey(); + float ml_score = sortedScoresDescending.get(i).getValue(); + assertEquals(docId, searchResponse.getHits().getAt(i).docId()); + assertEquals(ml_score, searchResponse.getHits().getAt(i).getScore(), 0.001); + + // Test that the path to targetField is valid + Map currentMap = searchResponse.getHits().getAt(i).getSourceAsMap(); + String[] keys = targetField.split("\\."); + String lastKey = keys[keys.length - 1]; + for (int keyIndex = 0; keyIndex < keys.length - 1; keyIndex++) { + String key = keys[keyIndex]; + assertTrue("The key:" + key + "does not exist in" + currentMap, currentMap.containsKey(key)); + currentMap = (Map) currentMap.get(key); + } + 
assertTrue("The key:" + lastKey + "does not exist in" + currentMap, currentMap.containsKey(lastKey)); + + } + } + + /** + * Sets up a search response whose target field holds a numerical string, for example "3.2", + * which the processor can use to rerank documents. + */ + private void setUpValidSearchResultsWithNestedTargetValueWithNumericalString() { + SearchHit[] hits = new SearchHit[sampleIndexMLScorePairs.size()]; + + String templateString = """ + { + "my_field" : "%s", + "ml": { + "info" : { + "score": "%s" + } + } + } + """.replace("\n", ""); + + for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) { + int docId = sampleIndexMLScorePairs.get(i).getKey(); + String mlScore = sampleIndexMLScorePairs.get(i).getValue() + ""; + + String sourceMap = templateString.formatted(i, mlScore); + + hits[i] = new SearchHit(docId, docId + "", Collections.emptyMap(), Collections.emptyMap()); + hits[i].sourceRef(new BytesArray(sourceMap)); + hits[i].score(1); + } + + TotalHits totalHits = new TotalHits(sampleIndexMLScorePairs.size(), TotalHits.Relation.EQUAL_TO); + + SearchHits searchHits = new SearchHits(hits, totalHits, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new SearchResponse.Clusters(1, 1, 0), null); + } + /** * Creates a searchResponse where the value to reRank by is Nested. 
* The location where the target is within a map of size 1 meaning after @@ -891,7 +975,7 @@ public void testRerank_throwsExceptionOnNoSource_WhenSearchResponseHasNoSourceMa ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argumentCaptor.capture()); - assertEquals("There is no source field to be able to perform rerank on hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assertEquals("There is no source field to be able to perform rerank on hit [" + 2 + "]", argumentCaptor.getValue().getMessage()); assert (argumentCaptor.getValue() instanceof IllegalArgumentException); } @@ -929,7 +1013,7 @@ public void testRerank_throwsExceptionOnMappingNotExistingInSource_WhenSearchRes ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argumentCaptor.capture()); - assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assertEquals("The field to rerank by is not found at hit [" + 2 + "]", argumentCaptor.getValue().getMessage()); assert (argumentCaptor.getValue() instanceof IllegalArgumentException); } @@ -969,7 +1053,7 @@ public void testRerank_throwsExceptionOnHavingEmptyMapping_WhenTargetFieldHasNul ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argumentCaptor.capture()); - assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assertEquals("The field to rerank by is not found at hit [" + 2 + "]", argumentCaptor.getValue().getMessage()); assert (argumentCaptor.getValue() instanceof IllegalArgumentException); } @@ -1007,7 +1091,10 @@ public void testRerank_throwsExceptionOnHavingNonNumericValue_WhenTargetFieldHas ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argumentCaptor.capture()); - assertEquals("The 
field mapping to rerank by [hello world] is not Numerical", argumentCaptor.getValue().getMessage()); + assertEquals( + "The field mapping to rerank by [hello world] is not Numerical, instead of type [String]", + argumentCaptor.getValue().getMessage() + ); assert (argumentCaptor.getValue() instanceof IllegalArgumentException); }