Skip to content

Commit

Permalink
Fix script score queries not getting cached
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jan 4, 2024
1 parent f519fe2 commit a65a430
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 70 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when training a model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
### Infrastructure
* Upgrade gradle to 8.4 [#1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.core.rest.RestStatus;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -107,7 +108,7 @@ private void validateKNNInnerProductScriptScoreSearch(String testIndex, String t
params.put(QUERY_VALUE, queryVector);
params.put(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue());

Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k);
Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap());
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,21 @@
package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.ScriptFactory;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.Map;

public class KNNScoreScriptFactory implements ScoreScript.LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
private String similaritySpace;
private String field;
private Object query;
private KNNScoringSpace knnScoringSpace;

private IndexSearcher searcher;

public KNNScoreScriptFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();
this.params = params;
this.lookup = lookup;
this.field = getValue(params, "field").toString();
this.similaritySpace = getValue(params, "space_type").toString();
this.query = getValue(params, "query_value");
this.searcher = searcher;

this.knnScoringSpace = KNNScoringSpaceFactory.create(
this.similaritySpace,
this.query,
lookup.doc().mapperService().fieldType(this.field)
);
}

private Object getValue(Map<String, Object> params, String fieldName) {
final Object value = params.get(fieldName);
if (value != null) return value;

KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Missing parameter [" + fieldName + "]");
}

public class KNNScoreScriptFactory implements ScoreScript.Factory, ScriptFactory {
@Override
public boolean needs_score() {
return false;
public boolean isResultDeterministic() {
// This implies the results are cacheable
return true;
}

/**
* For each segment, supply the KNNScoreScript that should be used to re-score the documents returned from the
* query. Because the method to score the documents was set during factory construction, the scripts are agnostic of
* the similarity space. The KNNScoringSpace will return the correct script, given the query, the field type, and
* the similarity space.
*
* @param ctx LeafReaderContext for the segment
* @return ScoreScript to be executed
*/
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher);
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return new KNNScoreScriptLeafFactory(params, lookup, indexSearcher);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.script;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.script.ScoreScript;
import org.opensearch.search.lookup.SearchLookup;

/**
 * Leaf-level factory for k-NN score scripts.
 *
 * <p>At construction time the {@code field}, {@code space_type} and {@code query_value} script parameters are read
 * and handed to {@link KNNScoringSpaceFactory} to obtain the matching {@link KNNScoringSpace}. That scoring space is
 * then asked for a {@link ScoreScript} each time {@link #newInstance(LeafReaderContext)} is called.
 */
public class KNNScoreScriptLeafFactory implements ScoreScript.LeafFactory {
    private final Map<String, Object> params;
    private final SearchLookup lookup;
    private final String similaritySpace;
    private final String field;
    private final Object query;
    private final KNNScoringSpace knnScoringSpace;
    private final IndexSearcher searcher;

    /**
     * Counts the script query request, extracts the required parameters and resolves the scoring space.
     *
     * @param params   script parameters; {@code field}, {@code space_type} and {@code query_value} are required
     * @param lookup   search lookup used to resolve the mapped field type of {@code field}
     * @param searcher index searcher passed along when score scripts are created
     * @throws IllegalArgumentException if any of the required parameters is missing
     */
    public KNNScoreScriptLeafFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher searcher) {
        KNNCounter.SCRIPT_QUERY_REQUESTS.increment();

        // Resolve the mandatory parameters first, in the order field -> space_type -> query_value.
        final String fieldName = requireParam(params, "field").toString();
        final String spaceName = requireParam(params, "space_type").toString();
        final Object queryValue = requireParam(params, "query_value");

        this.params = params;
        this.lookup = lookup;
        this.searcher = searcher;
        this.field = fieldName;
        this.similaritySpace = spaceName;
        this.query = queryValue;
        this.knnScoringSpace = KNNScoringSpaceFactory.create(
            spaceName,
            queryValue,
            lookup.doc().mapperService().fieldType(fieldName)
        );
    }

    /**
     * Looks up a mandatory script parameter.
     *
     * @param source script parameters to read from
     * @param key    name of the parameter
     * @return the non-null value stored under {@code key}
     * @throws IllegalArgumentException if no value is present; the script query error counter is incremented first
     */
    private Object requireParam(Map<String, Object> source, String key) {
        final Object found = source.get(key);
        if (found == null) {
            KNNCounter.SCRIPT_QUERY_ERRORS.increment();
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing parameter [%s]", key));
        }
        return found;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean needs_score() {
        return false;
    }

    /**
     * Supplies the score script used to re-score the documents of one segment. The scoring space was chosen when
     * this factory was constructed, so this method only delegates to it with the stored parameters.
     *
     * @param ctx LeafReaderContext for the segment
     * @return ScoreScript to be executed
     * @throws IOException declared by the overridden method signature
     */
    @Override
    public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
        return knnScoringSpace.getScoreScript(params, field, lookup, ctx, searcher);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public <FactoryType> FactoryType compile(String name, String code, ScriptContext
KNNCounter.SCRIPT_COMPILATION_ERRORS.increment();
throw new IllegalArgumentException("Unknown script name " + code);
}
ScoreScript.Factory factory = KNNScoreScriptFactory::new;
ScoreScript.Factory factory = new KNNScoreScriptFactory();
return context.factoryClazz.cast(factory);
}

Expand Down
6 changes: 4 additions & 2 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ public void testL2PainlessScriptingWithByteVectorDataType() throws Exception {
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
4
4,
Collections.emptyMap()
);

Response response = client().performRequest(request);
Expand All @@ -370,7 +371,8 @@ public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception {
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
4
4,
Collections.emptyMap()
);

Response response = client().performRequest(request);
Expand Down
117 changes: 113 additions & 4 deletions src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.script;

import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -25,9 +26,11 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -449,7 +452,7 @@ public void testHammingScriptScore_Long() throws Exception {
params1.put("field", FIELD_NAME);
params1.put("query_value", queryValue1);
params1.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4);
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4, Collections.emptyMap());
Response response1 = client().performRequest(request1);
assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -487,7 +490,7 @@ public void testHammingScriptScore_Long() throws Exception {
params2.put("field", FIELD_NAME);
params2.put("query_value", queryValue2);
params2.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4);
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap());
Response response2 = client().performRequest(request2);
assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -555,7 +558,7 @@ public void testHammingScriptScore_Base64() throws Exception {
params1.put("field", FIELD_NAME);
params1.put("query_value", queryValue1);
params1.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4);
Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4, Collections.emptyMap());
Response response1 = client().performRequest(request1);
assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -593,7 +596,7 @@ public void testHammingScriptScore_Base64() throws Exception {
params2.put("field", FIELD_NAME);
params2.put("query_value", queryValue2);
params2.put("space_type", SpaceType.HAMMING_BIT.getValue());
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4);
Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4, Collections.emptyMap());
Response response2 = client().performRequest(request2);
assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode()));

Expand Down Expand Up @@ -673,4 +676,110 @@ public void testKNNInnerProdScriptScore() throws Exception {
assertEquals("4", results.get(2).getDocId());
assertEquals("1", results.get(3).getDocId());
}

/**
 * Runs the same k-NN script score query twice with {@code request_cache=true} and checks, via the index stats,
 * that the first execution is counted as a request cache miss and the second one as a request cache hit.
 */
public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception {
    // Create the k-NN index and populate it with four 2-dimensional vectors.
    createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
    Float[] f1 = { 6.0f, 6.0f };
    addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

    Float[] f2 = { 2.0f, 2.0f };
    addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2);

    Float[] f3 = { 4.0f, 4.0f };
    addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3);

    Float[] f4 = { 3.0f, 3.0f };
    addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4);

    /*
     * Construct the search request. Script parameters sent with the query:
     * "params": {
     *   "field": FIELD_NAME,
     *   "query_value": [1.0, 1.0],
     *   "space_type": SpaceType.L2
     * }
     */
    QueryBuilder qb = new MatchAllQueryBuilder();
    Map<String, Object> scriptParams = new HashMap<>();
    float[] queryVector = { 1.0f, 1.0f };
    scriptParams.put("field", FIELD_NAME);
    scriptParams.put("query_value", queryVector);
    scriptParams.put("space_type", SpaceType.L2.getValue());
    Map<String, Object> searchParams = new HashMap<>();
    searchParams.put("request_cache", true);

    List<String> expectedDocids = Arrays.asList("2", "4", "3", "1");

    // First request with the request cache enabled.
    Request firstScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams);
    Response firstScriptQueryResponse = client().performRequest(firstScriptQueryRequest);
    assertEquals(
        firstScriptQueryRequest.getEndpoint() + ": failed",
        RestStatus.OK,
        RestStatus.fromCode(firstScriptQueryResponse.getStatusLine().getStatusCode())
    );

    List<KNNResult> firstResults = parseSearchResponse(EntityUtils.toString(firstScriptQueryResponse.getEntity()), FIELD_NAME);
    assertEquals(4, firstResults.size());
    assertEquals(expectedDocids, firstResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()));

    // The first execution must be recorded as a request cache miss, with no hits yet.
    Map<String, Object> firstRequestCacheStats = readRequestCacheStats(INDEX_NAME);
    assertEquals(1, firstRequestCacheStats.get("miss_count"));
    assertEquals(0, firstRequestCacheStats.get("hit_count"));

    // Second, identical request with the request cache enabled.
    Request secondScriptQueryRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams, 4, searchParams);
    Response secondScriptQueryResponse = client().performRequest(secondScriptQueryRequest);
    assertEquals(
        secondScriptQueryRequest.getEndpoint() + ": failed",
        RestStatus.OK,
        RestStatus.fromCode(secondScriptQueryResponse.getStatusLine().getStatusCode())
    );

    // The repeated request must return the documents in the same order as the first one.
    List<KNNResult> secondResults = parseSearchResponse(EntityUtils.toString(secondScriptQueryResponse.getEntity()), FIELD_NAME);
    assertEquals(expectedDocids, secondResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()));

    // The miss count stays at one and the second execution is recorded as a request cache hit.
    Map<String, Object> secondRequestCacheStats = readRequestCacheStats(INDEX_NAME);
    assertEquals(1, secondRequestCacheStats.get("miss_count"));
    assertEquals(1, secondRequestCacheStats.get("hit_count"));
}

/**
 * Fetches {@code /<indexName>/_stats} and returns the {@code indices.<indexName>.total.request_cache} section of
 * the response.
 *
 * @param indexName index whose stats are requested
 * @return the parsed {@code request_cache} map
 * @throws IllegalStateException if any level of that path is absent from the stats response
 */
@SuppressWarnings("unchecked") // the parsed stats body is an untyped map of nested maps
private Map<String, Object> readRequestCacheStats(String indexName) throws Exception {
    Request statsRequest = new Request("GET", "/" + indexName + "/_stats");
    Response statsResponse = client().performRequest(statsRequest);
    assertEquals(
        statsRequest.getEndpoint() + ": failed",
        RestStatus.OK,
        RestStatus.fromCode(statsResponse.getStatusLine().getStatusCode())
    );
    String statsResponseBody = EntityUtils.toString(statsResponse.getEntity());
    return Optional.ofNullable(createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), statsResponseBody).map())
        .map(r -> (Map<String, Object>) r.get("indices"))
        .map(i -> (Map<String, Object>) i.get(indexName))
        .map(ind -> (Map<String, Object>) ind.get("total"))
        .map(t -> (Map<String, Object>) t.get("request_cache"))
        .orElseThrow(() -> new IllegalStateException("Request cache stats not found for index [" + indexName + "]"));
}
}
Loading

0 comments on commit a65a430

Please sign in to comment.