Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] ByField Rerank Improvements #1117

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.isNumeric;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria;
Expand Down Expand Up @@ -113,7 +114,6 @@ public void rescoreSearchResponse(
final ActionListener<List<Float>> listener
) {
SearchHit[] searchHits = response.getHits().getHits();

SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator;

if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) {
Expand Down Expand Up @@ -162,26 +162,41 @@ public void rescoreSearchResponse(
*/
public void byFieldSearchHitValidator(final SearchHit hit) {
if (!hit.hasSource()) {
log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()));
log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%s]", hit.getId()));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())
String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%s]", hit.getId())
);
}

Map<String, Object> sourceMap = hit.getSourceAsMap();
if (!mappingExistsInSource(sourceMap, targetField)) {
log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId()));
log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%s]", targetField, hit.getId()));

throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId()));
throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%s]", hit.getId()));
}

Optional<Object> val = getValueFromSource(sourceMap, targetField);

if (!(val.get() instanceof Number)) {
log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null)));
if (!(isNumeric(val.get()))) {
// Strictly get the type of value removing the prefix of getClass() having a value is guaranteed so no NPE check
String typeOfMapping = val.get().getClass().getSimpleName();
log.error(
String.format(
Locale.ROOT,
"The field mapping to rerank [%s: %s] is not Numerical, instead of type [%s]",
targetField,
val.orElse(null),
typeOfMapping
)
);

throw new IllegalArgumentException(
String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null))
String.format(
Locale.ROOT,
"The field mapping to rerank by [%s] is not Numerical, instead of type [%s]",
val.orElse(null),
typeOfMapping
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public interface SearchHitValidator {
* for each SearchHit follows the correct form as specified by the validator.
* When just one of the conditions fail (as specified by the validator) the exception will be thrown to the listener.
* @param searchHits from the SearchResponse
* @param validator The given validator used to check every search hit being correct
* @param listener returns an error to the listener in case on of the conditions fail
* @return The status indicating that the SearchHits are in correct form to perform the Rerank
*/
Expand Down Expand Up @@ -77,6 +78,9 @@ public static boolean validateRerankCriteria(
*/
public static float getScoreFromSourceMap(final Map<String, Object> sourceAsMap, final String targetField) {
Object val = getValueFromSource(sourceAsMap, targetField).get();
if (val instanceof String) {
return Float.parseFloat((String) val);
}
return ((Number) val).floatValue();
}

Expand Down Expand Up @@ -180,4 +184,29 @@ public static boolean mappingExistsInSource(final Map<String, Object> sourceAsMa
return getValueFromSource(sourceAsMap, pathToValue).isPresent();
}

/**
* @param value Any value to be determined to be numerical
* @return whether the value can be turned into a number
*/
public static boolean isNumeric(Object value) {
if (value == null) {
return false;
}

if (value instanceof Number) {
return true;
}

if (value instanceof String) {
String string = (String) value;
try {
Double.parseDouble(string);
return true;
} catch (NumberFormatException e) {
return false;
}
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,90 @@ public void testRerank_keepsTargetFieldAndHasNoPreviousScore_WhenByFieldHasDefau
}
}

public void testRerank_reranksHits_WhenTargetFieldIsNumericalString() throws IOException {
String targetField = "ml.info.score";
setUpValidSearchResultsWithNestedTargetValueWithNumericalString();
List<Map.Entry<Integer, Float>> sortedScoresDescending = sampleIndexMLScorePairs.stream()
.sorted(Map.Entry.<Integer, Float>comparingByValue().reversed())
.toList();

Map<String, Object> config = new HashMap<>(
Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField)))
);
processor = (ByFieldRerankProcessor) factory.create(
Map.of(),
"rerank processor",
"processor for 2nd level reranking based on provided field, This will check a nested field and numerical string",
false,
config,
pipelineContext
);
ActionListener<SearchResponse> listener = mock(ActionListener.class);
processor.rerank(response, Map.of(), listener);
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);

verify(listener, times(1)).onResponse(argCaptor.capture());
SearchResponse searchResponse = argCaptor.getValue();

assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length);
assertEquals(sortedScoresDescending.getFirst().getValue(), searchResponse.getHits().getMaxScore(), 0.0001);

for (int i = 0; i < sortedScoresDescending.size(); i++) {
int docId = sortedScoresDescending.get(i).getKey();
float ml_score = sortedScoresDescending.get(i).getValue();
assertEquals(docId, searchResponse.getHits().getAt(i).docId());
assertEquals(ml_score, searchResponse.getHits().getAt(i).getScore(), 0.001);

// Test that the path to targetField is valid
Map<String, Object> currentMap = searchResponse.getHits().getAt(i).getSourceAsMap();
String[] keys = targetField.split("\\.");
String lastKey = keys[keys.length - 1];
for (int keyIndex = 0; keyIndex < keys.length - 1; keyIndex++) {
String key = keys[keyIndex];
assertTrue("The key:" + key + "does not exist in" + currentMap, currentMap.containsKey(key));
currentMap = (Map<String, Object>) currentMap.get(key);
}
assertTrue("The key:" + lastKey + "does not exist in" + currentMap, currentMap.containsKey(lastKey));

}
}

/**
* Setups a search response that has a target field with a numerical string for example "3.2"
* Which can be used by the processor to rerank documents.
*/
private void setUpValidSearchResultsWithNestedTargetValueWithNumericalString() {
SearchHit[] hits = new SearchHit[sampleIndexMLScorePairs.size()];

String templateString = """
{
"my_field" : "%s",
"ml": {
"info" : {
"score": "%s"
}
}
}
""".replace("\n", "");

for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) {
int docId = sampleIndexMLScorePairs.get(i).getKey();
String mlScore = sampleIndexMLScorePairs.get(i).getValue() + "";

String sourceMap = templateString.formatted(i, mlScore);

hits[i] = new SearchHit(docId, docId + "", Collections.emptyMap(), Collections.emptyMap());
hits[i].sourceRef(new BytesArray(sourceMap));
hits[i].score(1);
}

TotalHits totalHits = new TotalHits(sampleIndexMLScorePairs.size(), TotalHits.Relation.EQUAL_TO);

SearchHits searchHits = new SearchHits(hits, totalHits, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new SearchResponse.Clusters(1, 1, 0), null);
}

/**
* Creates a searchResponse where the value to reRank by is Nested.
* The location where the target is within a map of size 1 meaning after
Expand Down Expand Up @@ -891,7 +975,7 @@ public void testRerank_throwsExceptionOnNoSource_WhenSearchResponseHasNoSourceMa
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("There is no source field to be able to perform rerank on hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assertEquals("There is no source field to be able to perform rerank on hit [" + 2 + "]", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
}

Expand Down Expand Up @@ -929,7 +1013,7 @@ public void testRerank_throwsExceptionOnMappingNotExistingInSource_WhenSearchRes
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assertEquals("The field to rerank by is not found at hit [" + 2 + "]", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
}

Expand Down Expand Up @@ -969,7 +1053,7 @@ public void testRerank_throwsExceptionOnHavingEmptyMapping_WhenTargetFieldHasNul
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assertEquals("The field to rerank by is not found at hit [" + 2 + "]", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
}

Expand Down Expand Up @@ -1007,7 +1091,10 @@ public void testRerank_throwsExceptionOnHavingNonNumericValue_WhenTargetFieldHas
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("The field mapping to rerank by [hello world] is not Numerical", argumentCaptor.getValue().getMessage());
assertEquals(
"The field mapping to rerank by [hello world] is not Numerical, instead of type [String]",
argumentCaptor.getValue().getMessage()
);
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);

}
Expand Down
Loading