Skip to content

Commit

Permalink
create custom builder class to enforce valid neural query builder ins…
Browse files Browse the repository at this point in the history
…tantiation
  • Loading branch information
will-hwang committed Dec 31, 2024
1 parent b98d15e commit 344c2ef
Show file tree
Hide file tree
Showing 14 changed files with 452 additions and 440 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
- Support for builder constructor in Neural Query Builder ([#1047](https://github.com/opensearch-project/neural-search/pull/1047))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ private HybridQueryBuilder getQueryBuilder(
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName("passage_embedding")
.modelId(modelId)
.queryText(QUERY)
.k(5)
.build();
if (expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Objects;
import java.util.function.Supplier;

import lombok.Builder;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
Expand All @@ -32,6 +31,7 @@
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -68,9 +68,8 @@
@Getter
@Setter
@Accessors(chain = true, fluent = true)
@Builder(toBuilder = true)
@NoArgsConstructor
@AllArgsConstructor
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> implements ModelInferenceQueryBuilder {

public static final String NAME = "neural";
Expand Down Expand Up @@ -110,6 +109,130 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;

/**
* A custom builder class to enforce valid Neural Query Builder instance by validating the required fields are initialized
*/
public static class Builder {
private String fieldName;
private String queryText;
private String queryImage;
private String modelId;
private Integer k = null;
private Float maxDistance = null;
private Float minScore = null;
private Boolean expandNested;
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;
private String queryName;
private float boost = DEFAULT_BOOST;

public Builder() {}

public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
return this;
}

public Builder queryText(String queryText) {
this.queryText = queryText;
return this;
}

public Builder queryImage(String queryImage) {
this.queryImage = queryImage;
return this;
}

public Builder modelId(String modelId) {
this.modelId = modelId;
return this;
}

public Builder k(Integer k) {
this.k = k;
return this;
}

public Builder maxDistance(Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

public Builder minScore(Float minScore) {
this.minScore = minScore;
return this;
}

public Builder expandNested(Boolean expandNested) {
this.expandNested = expandNested;
return this;
}

public Builder vectorSupplier(Supplier<float[]> vectorSupplier) {
this.vectorSupplier = vectorSupplier;
return this;
}

public Builder filter(QueryBuilder filter) {
this.filter = filter;
return this;
}

public Builder methodParameters(Map<String, ?> methodParameters) {
this.methodParameters = methodParameters;
return this;
}

public Builder queryName(String queryName) {
this.queryName = queryName;
return this;
}

public Builder boost(float boost) {
this.boost = boost;
return this;
}

public Builder rescoreContext(RescoreContext rescoreContext) {
this.rescoreContext = rescoreContext;
return this;
}

public NeuralQueryBuilder build() {
validate();
int k = this.k == null ? 0 : this.k;
return new NeuralQueryBuilder(
fieldName,
queryText,
queryImage,
modelId,
k,
maxDistance,
minScore,
expandNested,
vectorSupplier,
filter,
methodParameters,
rescoreContext
).boost(boost).queryName(queryName);
}

private void validate() {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("Field name must be provided for neural query");
}
if (Strings.isNullOrEmpty(queryText) && Strings.isNullOrEmpty(queryImage)) {
throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
}
}
}

public static NeuralQueryBuilder.Builder builder() {
return new NeuralQueryBuilder.Builder();
}

/**
* Constructor from stream input
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() {
createSearchRequestProcessor(modelId, search_pipeline);
createPipelineProcessor(modelId, ingest_pipeline, ProcessorType.TEXT_EMBEDDING);
updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline));
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1);
neuralQueryBuilder.queryText("Hello World");
neuralQueryBuilder.k(1);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText("Hello World")
.k(1)
.build();
Map<String, Object> response = search(index, neuralQueryBuilder, 2);
assertFalse(response.isEmpty());
} finally {
Expand Down Expand Up @@ -112,10 +113,11 @@ public void testNeuralQueryEnricherProcessor_whenHybridQueryBuilderAndNoModelIdP
createSearchRequestProcessor(modelId, search_pipeline);
createPipelineProcessor(modelId, ingest_pipeline, ProcessorType.TEXT_EMBEDDING);
updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline));
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1);
neuralQueryBuilder.queryText("Hello World");
neuralQueryBuilder.k(1);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText("Hello World")
.k(1)
.build();
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(neuralQueryBuilder);
Map<String, Object> response = search(index, hybridQueryBuilder, 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public void testFactory_whenModelIdIsNotString_thenFail() {

public void testProcessRequest_whenVisitingQueryBuilder_thenSuccess() throws Exception {
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder().fieldName("field_name").queryText("query_text").build();
SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder));
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,13 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
modelId = prepareModel();
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText(TEST_DOC_TEXT1)
.modelId(modelId)
.k(5)
.build();

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand Down Expand Up @@ -140,20 +133,13 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
modelId = prepareModel();
createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText(TEST_DOC_TEXT1)
.modelId(modelId)
.k(5)
.build();

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand Down Expand Up @@ -182,20 +168,13 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
int totalExpectedDocQty = 6;

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
6,
null,
null,
null,
null,
null,
null,
null
);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText(TEST_DOC_TEXT1)
.modelId(modelId)
.k(6)
.build();

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand Down
Loading

0 comments on commit 344c2ef

Please sign in to comment.