Skip to content

Commit

Permalink
refactor code to remove duplicate
Browse files Browse the repository at this point in the history
Signed-off-by: will-hwang <[email protected]>
  • Loading branch information
will-hwang committed Jan 2, 2025
1 parent e4835f6 commit d8b7fe5
Showing 1 changed file with 23 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -201,8 +200,12 @@ public Builder rescoreContext(RescoreContext rescoreContext) {
}

public NeuralQueryBuilder build() {
validate();
int k = this.k == null ? 0 : this.k;
validateQueryParameters(fieldName, queryText, queryImage);
int k = this.k == null ? DEFAULT_K : this.k;
boolean queryTypeIsProvided = validateKNNQueryType(k, maxDistance, minScore);
if (queryTypeIsProvided == false) {
k = DEFAULT_K;
}
return new NeuralQueryBuilder(
fieldName,
queryText,
Expand All @@ -219,14 +222,6 @@ public NeuralQueryBuilder build() {
).boost(boost).queryName(queryName);
}

private void validate() {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException("Field name must be provided for neural query");
}
if (Strings.isNullOrEmpty(queryText) && Strings.isNullOrEmpty(queryImage)) {
throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
}
}
}

public static NeuralQueryBuilder.Builder builder() {
Expand Down Expand Up @@ -371,15 +366,16 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
+ "]"
);
}
if (StringUtils.isBlank(neuralQueryBuilder.queryText()) && StringUtils.isBlank(neuralQueryBuilder.queryImage())) {
throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
}
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
validateQueryParameters(neuralQueryBuilder.fieldName(), neuralQueryBuilder.queryText(), neuralQueryBuilder.queryImage());
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}

boolean queryTypeIsProvided = validateKNNQueryType(neuralQueryBuilder);
boolean queryTypeIsProvided = validateKNNQueryType(
neuralQueryBuilder.k(),
neuralQueryBuilder.maxDistance(),
neuralQueryBuilder.minScore()
);
if (queryTypeIsProvided == false) {
neuralQueryBuilder.k(DEFAULT_K);
}
Expand Down Expand Up @@ -522,15 +518,22 @@ public String getWriteableName() {
return NAME;
}

private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
private static void validateQueryParameters(String fieldName, String queryText, String queryImage) {
if (StringUtils.isBlank(queryText) && StringUtils.isBlank(queryImage)) {
throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
}
requireValue(fieldName, "Field name must be provided for neural query");
}

private static boolean validateKNNQueryType(Integer k, Float maxDistance, Float minScore) {
int queryCount = 0;
if (neuralQueryBuilder.k() != null) {
if (k != null) {
queryCount++;
}
if (neuralQueryBuilder.maxDistance() != null) {
if (maxDistance != null) {
queryCount++;
}
if (neuralQueryBuilder.minScore() != null) {
if (minScore != null) {
queryCount++;
}
if (queryCount > 1) {
Expand Down

0 comments on commit d8b7fe5

Please sign in to comment.