Skip to content

Commit

Permalink
Correct NeuralQueryBuilder doEquals() and doHashCode().
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Zhang <[email protected]>
  • Loading branch information
bzhangam committed Jan 8, 2025
1 parent fea0a7f commit fb35c1b
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 133 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private final List<QueryBuilder> queries = new ArrayList<>();

private String fieldName;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -255,7 +253,6 @@ protected boolean doEquals(HybridQueryBuilder obj) {
return false;
}
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queries, obj.queries);
return equalsBuilder.isEquals();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
Expand Down Expand Up @@ -236,7 +236,8 @@ public static NeuralQueryBuilder.Builder builder() {
public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.queryText = in.readOptionalString();
this.queryImage = in.readOptionalString();
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
this.modelId = in.readOptionalString();
Expand Down Expand Up @@ -265,7 +266,8 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeOptionalString(this.queryText);
out.writeOptionalString(this.queryImage);
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
Expand All @@ -285,6 +287,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
out.writeOptionalBoolean(this.expandNested);
}

if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -295,7 +298,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
if (Objects.nonNull(queryText)) {
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
}
if (Objects.nonNull(queryImage)) {
xContentBuilder.field(QUERY_IMAGE_FIELD.getPreferredName(), queryImage);
}
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
Expand Down Expand Up @@ -501,15 +509,39 @@ protected boolean doEquals(NeuralQueryBuilder obj) {
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queryText, obj.queryText);
equalsBuilder.append(queryImage, obj.queryImage);
equalsBuilder.append(modelId, obj.modelId);
equalsBuilder.append(k, obj.k);
equalsBuilder.append(maxDistance, obj.maxDistance);
equalsBuilder.append(minScore, obj.minScore);
equalsBuilder.append(expandNested, obj.expandNested);
equalsBuilder.append(getVector(vectorSupplier), getVector(obj.vectorSupplier));
equalsBuilder.append(filter, obj.filter);
equalsBuilder.append(methodParameters, obj.methodParameters);
equalsBuilder.append(rescoreContext, obj.rescoreContext);
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(k).toHashCode();
return Objects.hash(
fieldName,
queryText,
queryImage,
modelId,
k,
maxDistance,
minScore,
expandNested,
Arrays.hashCode(getVector(vectorSupplier)),
filter,
methodParameters,
rescoreContext
);
}

private float[] getVector(final Supplier<float[]> vectorSupplier) {
return Objects.isNull(vectorSupplier) ? null : vectorSupplier.get();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,6 @@ public void testStreams_whenWrittingToStream_thenSuccessful() {
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.k(K)
.vectorSupplier(TEST_VECTOR_SUPPLIER)
.build();

original.add(neuralQueryBuilder);
Expand Down
Loading

0 comments on commit fb35c1b

Please sign in to comment.