Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 29, 2023
1 parent cc3bef8 commit 9a0c209
Show file tree
Hide file tree
Showing 5 changed files with 398 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,6 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

public void addSubQuery(final Query query) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
subQueries.add(query);
}

/**
* Create the Weight used to score this query
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public boolean searchWith(
query = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
validateHybridQuery(query);
validateQuery(query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

Expand Down Expand Up @@ -112,7 +112,7 @@ && mightBeWrappedHybridQuery(query)
return query;
}

private void validateHybridQuery(final Query query) {
private void validateQuery(final Query query) {
if (query instanceof BooleanQuery) {
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList());
}

@SneakyThrows
protected void addKnnDoc(
String index,
String docId,
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
) {
addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList());
}

/**
* Add a set of knn vectors and text to an index
*
Expand All @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
* @param vectors List of vectors corresponding to those fields
* @param textFieldNames List of text fields to be added
* @param texts List of text corresponding to those fields
* @param nestedFieldNames List of nested fields to be added
* @param nestedFields List of fields and values corresponding to those fields
*/
@SneakyThrows
protected void addKnnDoc(
Expand All @@ -430,7 +444,9 @@ protected void addKnnDoc(
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
List<String> texts,
List<String> nestedFieldNames,
List<Map<String, String>> nestedFields
) {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
Expand All @@ -441,6 +457,16 @@ protected void addKnnDoc(
for (int i = 0; i < textFieldNames.size(); i++) {
builder.field(textFieldNames.get(i), texts.get(i));
}

for (int i = 0; i < nestedFieldNames.size(); i++) {
builder.field(nestedFieldNames.get(i));
builder.startObject();
Map<String, String> nestedValues = nestedFields.get(i);
for (Map.Entry<String, String> entry : nestedValues.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();

request.setJsonEntity(builder.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.index.query.QueryBuilders.matchQuery;
import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;

Expand All @@ -21,11 +22,13 @@

import lombok.SneakyThrows;

import org.apache.lucene.search.join.ScoreMode;
import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -51,9 +54,14 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1";
private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2";
private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user";

private static final int TEST_DIMENSION = 768;
private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2;
private static final String NESTED_FIELD_1 = "firstname";
private static final String NESTED_FIELD_2 = "lastname";
private static final String NESTED_FIELD_1_VALUE = "john";
private static final String NESTED_FIELD_2_VALUE = "black";
private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
Expand Down Expand Up @@ -271,6 +279,39 @@ public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() {
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD);

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT);
NestedQueryBuilder nestedQueryBuilder = QueryBuilders.nestedQuery(
TEST_NESTED_TYPE_FIELD_NAME_1,
matchQuery(TEST_NESTED_TYPE_FIELD_NAME_1 + "." + NESTED_FIELD_1, NESTED_FIELD_1_VALUE),
ScoreMode.Total
);
HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
hybridQueryBuilderOnlyTerm.add(termQueryBuilder);
hybridQueryBuilderOnlyTerm.add(nestedQueryBuilder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD,
hybridQueryBuilderOnlyTerm,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertEquals(1, getHitCount(searchResponseAsMap));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(getMaxScore(searchResponseAsMap).get() > 0);

Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(1, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
Expand Down Expand Up @@ -344,13 +385,23 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
indexName,
buildIndexConfiguration(
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
List.of("user"),
List.of(TEST_NESTED_TYPE_FIELD_NAME_1),
1
),
""
);

addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD);
addKnnDoc(
TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD,
"4",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
List.of(),
List.of(),
List.of(TEST_NESTED_TYPE_FIELD_NAME_1),
List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE))
);
}
}

Expand Down
Loading

0 comments on commit 9a0c209

Please sign in to comment.