diff --git a/CHANGELOG.md b/CHANGELOG.md index e4448dea02fde..41dea2ef96700 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support searching from doc_value using termQueryCaseInsensitive/termQuery in flat_object/keyword field([#16974](https://github.com/opensearch-project/OpenSearch/pull/16974/)) - Added a new `time` field to replace the deprecated `getTime` field in `GetStats`. ([#17009](https://github.com/opensearch-project/OpenSearch/pull/17009)) - Improve performance of the bitmap filtering([#16936](https://github.com/opensearch-project/OpenSearch/pull/16936/)) +- Introduce Template query ([#16818](https://github.com/opensearch-project/OpenSearch/pull/16818)) ### Dependencies - Bump `com.google.cloud:google-cloud-core-http` from 2.23.0 to 2.47.0 ([#16504](https://github.com/opensearch-project/OpenSearch/pull/16504)) diff --git a/server/build.gradle b/server/build.gradle index 1b40fc980a818..873a423b1380f 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -70,7 +70,6 @@ dependencies { api project(":libs:opensearch-telemetry") api project(":libs:opensearch-task-commons") - compileOnly project(':libs:opensearch-plugin-classloader') testRuntimeOnly project(':libs:opensearch-plugin-classloader') diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index dfec2e1fda738..898174d60de76 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -476,7 +476,7 @@ private void executeRequest( } else { Rewriteable.rewriteAndFetch( sr.source(), - searchService.getRewriteContext(timeProvider::getAbsoluteStartMillis), + searchService.getRewriteContext(timeProvider::getAbsoluteStartMillis, searchRequest), rewriteListener ); } diff --git a/server/src/main/java/org/opensearch/index/query/BaseQueryRewriteContext.java b/server/src/main/java/org/opensearch/index/query/BaseQueryRewriteContext.java new file mode 100644 index 0000000000000..7cfaf9edb4709 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/BaseQueryRewriteContext.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.CountDown; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.LongSupplier; + +/** + * BaseQueryRewriteContext is a base implementation of the QueryRewriteContext interface. + * It provides core functionality for query rewriting operations in OpenSearch. + * + * This class manages the context for query rewriting, including handling of asynchronous actions, + * access to content registries, and time-related operations. + */ +public class BaseQueryRewriteContext implements QueryRewriteContext { + private final NamedXContentRegistry xContentRegistry; + private final NamedWriteableRegistry writeableRegistry; + protected final Client client; + protected final LongSupplier nowInMillis; + private final List>> asyncActions = new ArrayList<>(); + private final boolean validate; + + public BaseQueryRewriteContext( + NamedXContentRegistry xContentRegistry, + NamedWriteableRegistry writeableRegistry, + Client client, + LongSupplier nowInMillis + ) { + this(xContentRegistry, writeableRegistry, client, nowInMillis, false); + } + + public BaseQueryRewriteContext( + NamedXContentRegistry xContentRegistry, + NamedWriteableRegistry writeableRegistry, + Client client, + LongSupplier nowInMillis, + boolean validate + ) { + + this.xContentRegistry = xContentRegistry; + this.writeableRegistry = writeableRegistry; + this.client = client; + this.nowInMillis = nowInMillis; + this.validate = validate; + } + + /** + * The registry used to build new {@link XContentParser}s. Contains registered named parsers needed to parse the query. + */ + public NamedXContentRegistry getXContentRegistry() { + return xContentRegistry; + } + + /** + * Returns the time in milliseconds that is shared across all resources involved. Even across shards and nodes. + */ + public long nowInMillis() { + return nowInMillis.getAsLong(); + } + + public NamedWriteableRegistry getWriteableRegistry() { + return writeableRegistry; + } + + /** + * Returns an instance of {@link QueryShardContext} if available of null otherwise + */ + public QueryShardContext convertToShardContext() { + return null; + } + + /** + * Registers an async action that must be executed before the next rewrite round in order to make progress. + * This should be used if a rewriteabel needs to fetch some external resources in order to be executed ie. a document + * from an index. + */ + public void registerAsyncAction(BiConsumer> asyncAction) { + asyncActions.add(asyncAction); + } + + /** + * Returns true if there are any registered async actions. + */ + public boolean hasAsyncActions() { + return asyncActions.isEmpty() == false; + } + + /** + * Executes all registered async actions and notifies the listener once it's done. The value that is passed to the listener is always + * null. The list of registered actions is cleared once this method returns. + */ + public void executeAsyncActions(ActionListener listener) { + if (asyncActions.isEmpty()) { + listener.onResponse(null); + return; + } + + CountDown countDown = new CountDown(asyncActions.size()); + ActionListener internalListener = new ActionListener() { + @Override + public void onResponse(Object o) { + if (countDown.countDown()) { + listener.onResponse(null); + } + } + + @Override + public void onFailure(Exception e) { + if (countDown.fastForward()) { + listener.onFailure(e); + } + } + }; + // make a copy to prevent concurrent modification exception + List>> biConsumers = new ArrayList<>(asyncActions); + asyncActions.clear(); + for (BiConsumer> action : biConsumers) { + action.accept(client, internalListener); + } + } + + public boolean validate() { + return validate; + } +} diff --git a/server/src/main/java/org/opensearch/index/query/QueryBuilders.java b/server/src/main/java/org/opensearch/index/query/QueryBuilders.java index 387d21830aa38..1debba73136b2 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryBuilders.java +++ b/server/src/main/java/org/opensearch/index/query/QueryBuilders.java @@ -50,6 +50,7 @@ import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.Map; /** * Utility class to create search queries. @@ -780,4 +781,13 @@ public static GeoShapeQueryBuilder geoDisjointQuery(String name, String indexedS public static ExistsQueryBuilder existsQuery(String name) { return new ExistsQueryBuilder(name); } + + /** + * A query that contains a template with holder that should be resolved by search processors + * + * @param content The content of the template + */ + public static TemplateQueryBuilder templateQuery(Map content) { + return new TemplateQueryBuilder(content); + } } diff --git a/server/src/main/java/org/opensearch/index/query/QueryCoordinatorContext.java b/server/src/main/java/org/opensearch/index/query/QueryCoordinatorContext.java new file mode 100644 index 0000000000000..c99a952ee42e3 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/QueryCoordinatorContext.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.opensearch.client.Client; +import org.opensearch.common.annotation.PublicApi; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.pipeline.PipelinedRequest; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiConsumer; + +/** + * The QueryCoordinatorContext class implements the QueryRewriteContext interface and provides + * additional functionality for coordinating query rewriting in OpenSearch. + * + * This class acts as a wrapper around a QueryRewriteContext instance and a PipelinedRequest, + * allowing access to both rewrite context methods and pass over search request information. + * + * @since 2.19.0 + */ +@PublicApi(since = "2.19.0") +public class QueryCoordinatorContext implements QueryRewriteContext { + private final QueryRewriteContext rewriteContext; + private final PipelinedRequest searchRequest; + + public QueryCoordinatorContext(QueryRewriteContext rewriteContext, PipelinedRequest searchRequest) { + this.rewriteContext = rewriteContext; + this.searchRequest = searchRequest; + } + + @Override + public NamedXContentRegistry getXContentRegistry() { + return rewriteContext.getXContentRegistry(); + } + + @Override + public long nowInMillis() { + return rewriteContext.nowInMillis(); + } + + @Override + public NamedWriteableRegistry getWriteableRegistry() { + return rewriteContext.getWriteableRegistry(); + } + + @Override + public QueryShardContext convertToShardContext() { + return rewriteContext.convertToShardContext(); + } + + @Override + public void registerAsyncAction(BiConsumer> asyncAction) { + rewriteContext.registerAsyncAction(asyncAction); + } + + @Override + public boolean hasAsyncActions() { + return rewriteContext.hasAsyncActions(); + } + + @Override + public void executeAsyncActions(ActionListener listener) { + rewriteContext.executeAsyncActions(listener); + } + + @Override + public boolean validate() { + return rewriteContext.validate(); + } + + @Override + public QueryCoordinatorContext convertToCoordinatorContext() { + return this; + } + + public Map getContextVariables() { + + // Read from pipeline context + Map contextVariables = new HashMap<>(searchRequest.getPipelineProcessingContext().getAttributes()); + + return contextVariables; + } +} diff --git a/server/src/main/java/org/opensearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/opensearch/index/query/QueryRewriteContext.java index 15a6d0b5a774e..aec5914066ab5 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/opensearch/index/query/QueryRewriteContext.java @@ -33,16 +33,12 @@ import org.opensearch.client.Client; import org.opensearch.common.annotation.PublicApi; -import org.opensearch.common.util.concurrent.CountDown; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import java.util.ArrayList; -import java.util.List; import java.util.function.BiConsumer; -import java.util.function.LongSupplier; /** * Context object used to rewrite {@link QueryBuilder} instances into simplified version. @@ -50,60 +46,27 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class QueryRewriteContext { - private final NamedXContentRegistry xContentRegistry; - private final NamedWriteableRegistry writeableRegistry; - protected final Client client; - protected final LongSupplier nowInMillis; - private final List>> asyncActions = new ArrayList<>(); - private final boolean validate; - - public QueryRewriteContext( - NamedXContentRegistry xContentRegistry, - NamedWriteableRegistry writeableRegistry, - Client client, - LongSupplier nowInMillis - ) { - this(xContentRegistry, writeableRegistry, client, nowInMillis, false); - } - - public QueryRewriteContext( - NamedXContentRegistry xContentRegistry, - NamedWriteableRegistry writeableRegistry, - Client client, - LongSupplier nowInMillis, - boolean validate - ) { - - this.xContentRegistry = xContentRegistry; - this.writeableRegistry = writeableRegistry; - this.client = client; - this.nowInMillis = nowInMillis; - this.validate = validate; - } - +public interface QueryRewriteContext { /** * The registry used to build new {@link XContentParser}s. Contains registered named parsers needed to parse the query. */ - public NamedXContentRegistry getXContentRegistry() { - return xContentRegistry; - } + NamedXContentRegistry getXContentRegistry(); /** * Returns the time in milliseconds that is shared across all resources involved. Even across shards and nodes. */ - public long nowInMillis() { - return nowInMillis.getAsLong(); - } + long nowInMillis(); - public NamedWriteableRegistry getWriteableRegistry() { - return writeableRegistry; - } + NamedWriteableRegistry getWriteableRegistry(); /** * Returns an instance of {@link QueryShardContext} if available of null otherwise */ - public QueryShardContext convertToShardContext() { + default QueryShardContext convertToShardContext() { + return null; + } + + default QueryCoordinatorContext convertToCoordinatorContext() { return null; } @@ -112,51 +75,18 @@ public QueryShardContext convertToShardContext() { * This should be used if a rewriteabel needs to fetch some external resources in order to be executed ie. a document * from an index. */ - public void registerAsyncAction(BiConsumer> asyncAction) { - asyncActions.add(asyncAction); - } + void registerAsyncAction(BiConsumer> asyncAction); /** * Returns true if there are any registered async actions. */ - public boolean hasAsyncActions() { - return asyncActions.isEmpty() == false; - } + boolean hasAsyncActions(); /** * Executes all registered async actions and notifies the listener once it's done. The value that is passed to the listener is always * null. The list of registered actions is cleared once this method returns. */ - public void executeAsyncActions(ActionListener listener) { - if (asyncActions.isEmpty()) { - listener.onResponse(null); - } else { - CountDown countDown = new CountDown(asyncActions.size()); - ActionListener internalListener = new ActionListener() { - @Override - public void onResponse(Object o) { - if (countDown.countDown()) { - listener.onResponse(null); - } - } - - @Override - public void onFailure(Exception e) { - if (countDown.fastForward()) { - listener.onFailure(e); - } - } - }; - // make a copy to prevent concurrent modification exception - List>> biConsumers = new ArrayList<>(asyncActions); - asyncActions.clear(); - for (BiConsumer> action : biConsumers) { - action.accept(client, internalListener); - } - } - } + void executeAsyncActions(ActionListener listener); - public boolean validate() { - return validate; - } + boolean validate(); } diff --git a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java index d026c5b7b7c57..69599a7b84d54 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java +++ b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java @@ -101,7 +101,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class QueryShardContext extends QueryRewriteContext { +public class QueryShardContext extends BaseQueryRewriteContext { private final ScriptService scriptService; private final IndexSettings indexSettings; diff --git a/server/src/main/java/org/opensearch/index/query/TemplateQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/TemplateQueryBuilder.java new file mode 100644 index 0000000000000..85d119ab704ec --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/TemplateQueryBuilder.java @@ -0,0 +1,198 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.apache.lucene.search.Query; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * A query builder that constructs a query based on a template and context variables. + * This query is designed to be rewritten with variables from search processors. + */ + +public class TemplateQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "template"; + public static final String queryName = "template"; + private final Map content; + + /** + * Constructs a new TemplateQueryBuilder with the given content. + * + * @param content The template content as a map. + */ + public TemplateQueryBuilder(Map content) { + this.content = content; + } + + /** + * Creates a TemplateQueryBuilder from XContent. + * + * @param parser The XContentParser to read from. + * @return A new TemplateQueryBuilder instance. + * @throws IOException If there's an error parsing the content. + */ + public static TemplateQueryBuilder fromXContent(XContentParser parser) throws IOException { + return new TemplateQueryBuilder(parser.map()); + } + + /** + * Constructs a TemplateQueryBuilder from a stream input. + * + * @param in The StreamInput to read from. + * @throws IOException If there's an error reading from the stream. + */ + public TemplateQueryBuilder(StreamInput in) throws IOException { + super(in); + this.content = in.readMap(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeMap(content); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(NAME, content); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + throw new IllegalStateException( + "Template queries cannot be converted directly to a query. Template Query must be rewritten first during doRewrite." + ); + } + + @Override + protected boolean doEquals(TemplateQueryBuilder other) { + return Objects.equals(this.content, other.content); + } + + @Override + protected int doHashCode() { + return Objects.hash(content); + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * Gets the content of this template query. + * + * @return The template content as a map. + */ + public Map getContent() { + return content; + } + + /** + * Rewrites the template query by substituting variables from the context. + * + * @param queryCoordinatorContext The context for query rewriting. + * @return A rewritten QueryBuilder. + * @throws IOException If there's an error during rewriting. + */ + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryCoordinatorContext) throws IOException { + // the queryRewrite is expected at QueryCoordinator level + if (!(queryCoordinatorContext instanceof QueryCoordinatorContext)) { + throw new IllegalStateException( + "Template Query must be rewritten at the coordinator node. Rewriting at shard level is not supported." + ); + } + + QueryCoordinatorContext convertedQueryCoordinateContext = (QueryCoordinatorContext) queryCoordinatorContext; + Map contextVariables = convertedQueryCoordinateContext.getContextVariables(); + String queryString; + + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.map(this.content); + queryString = builder.toString(); + } + + // Convert Map to Map with proper JSON escaping + Map variablesMap = null; + if (contextVariables != null) { + variablesMap = contextVariables.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { + try { + return JsonXContent.contentBuilder().value(entry.getValue()).toString(); + } catch (IOException e) { + throw new RuntimeException("Error converting contextVariables to JSON string", e); + } + })); + } + String newQueryContent = replaceVariables(queryString, variablesMap); + + try { + XContentParser parser = XContentType.JSON.xContent() + .createParser(queryCoordinatorContext.getXContentRegistry(), LoggingDeprecationHandler.INSTANCE, newQueryContent); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + QueryBuilder newQueryBuilder = parseInnerQueryBuilder(parser); + + return newQueryBuilder; + + } catch (Exception e) { + throw new IllegalArgumentException("Failed to rewrite template query: " + newQueryContent, e); + } + } + + private String replaceVariables(String template, Map variables) { + if (template == null || template.equals("null")) { + throw new IllegalArgumentException("Template string cannot be null. A valid template must be provided."); + } + if (template.isEmpty() || template.equals("{}")) { + throw new IllegalArgumentException("Template string cannot be empty. A valid template must be provided."); + } + if (variables == null || variables.isEmpty()) { + return template; + } + + StringBuilder result = new StringBuilder(); + int start = 0; + while (true) { + int startVar = template.indexOf("\"${", start); + if (startVar == -1) { + result.append(template.substring(start)); + break; + } + result.append(template, start, startVar); + int endVar = template.indexOf("}\"", startVar); + if (endVar == -1) { + throw new IllegalArgumentException("Unclosed variable in template: " + template.substring(startVar)); + } + String varName = template.substring(startVar + 3, endVar); + String replacement = variables.get(varName); + if (replacement == null) { + throw new IllegalArgumentException("Variable not found: " + varName); + } + result.append(replacement); + start = endVar + 2; + } + return result.toString(); + } + +} diff --git a/server/src/main/java/org/opensearch/indices/IndicesService.java b/server/src/main/java/org/opensearch/indices/IndicesService.java index b9ef074065359..2738e8248add6 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesService.java +++ b/server/src/main/java/org/opensearch/indices/IndicesService.java @@ -124,6 +124,7 @@ import org.opensearch.index.mapper.IdFieldMapper; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.merge.MergeStats; +import org.opensearch.index.query.BaseQueryRewriteContext; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.recovery.RecoveryStats; @@ -2004,7 +2005,7 @@ public QueryRewriteContext getValidationRewriteContext(LongSupplier nowInMillis) * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ private QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, boolean validate) { - return new QueryRewriteContext(xContentRegistry, namedWriteableRegistry, client, nowInMillis, validate); + return new BaseQueryRewriteContext(xContentRegistry, namedWriteableRegistry, client, nowInMillis, validate); } /** diff --git a/server/src/main/java/org/opensearch/search/SearchModule.java b/server/src/main/java/org/opensearch/search/SearchModule.java index b8d3a13e0df20..4711b64383ce9 100644 --- a/server/src/main/java/org/opensearch/search/SearchModule.java +++ b/server/src/main/java/org/opensearch/search/SearchModule.java @@ -86,6 +86,7 @@ import org.opensearch.index.query.SpanOrQueryBuilder; import org.opensearch.index.query.SpanTermQueryBuilder; import org.opensearch.index.query.SpanWithinQueryBuilder; +import org.opensearch.index.query.TemplateQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.query.TermsSetQueryBuilder; @@ -1208,7 +1209,7 @@ private void registerQueryParsers(List plugins) { registerQuery( new QuerySpec<>(MatchBoolPrefixQueryBuilder.NAME, MatchBoolPrefixQueryBuilder::new, MatchBoolPrefixQueryBuilder::fromXContent) ); - + registerQuery(new QuerySpec<>(TemplateQueryBuilder.NAME, TemplateQueryBuilder::new, TemplateQueryBuilder::fromXContent)); if (ShapesAvailability.JTS_AVAILABLE && ShapesAvailability.SPATIAL4J_AVAILABLE) { registerQuery(new QuerySpec<>(GeoShapeQueryBuilder.NAME, GeoShapeQueryBuilder::new, GeoShapeQueryBuilder::fromXContent)); } diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index b20f8222d6b7a..42749a8cdbf0c 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -86,6 +86,7 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryCoordinatorContext; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.Rewriteable; @@ -127,6 +128,7 @@ import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.lookup.SearchLookup; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.profile.Profilers; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QuerySearchRequest; @@ -1776,8 +1778,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis) { - return indicesService.getRewriteContext(nowInMillis); + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, PipelinedRequest searchRequest) { + return new QueryCoordinatorContext(indicesService.getRewriteContext(nowInMillis), searchRequest); } /** diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java b/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java index 7e86c30ddbbd9..c7fad1363cf2f 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java @@ -57,4 +57,8 @@ public void addProcessorExecutionDetail(ProcessorExecutionDetail detail) { public List getProcessorExecutionDetails() { return Collections.unmodifiableList(processorExecutionDetails); } + + public Map getAttributes() { + return attributes; + } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index f5ce94946dd32..b35784aef5582 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -61,4 +61,8 @@ public void transformSearchPhaseResults( Pipeline getPipeline() { return pipeline; } + + public PipelineProcessingContext getPipelineProcessingContext() { + return requestContext; + } } diff --git a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java index 297c0e3e356dd..377addfb26396 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java @@ -75,6 +75,7 @@ import org.opensearch.index.mapper.DateFieldMapper.Resolution; import org.opensearch.index.mapper.MappedFieldType.Relation; import org.opensearch.index.mapper.ParseContext.Document; +import org.opensearch.index.query.BaseQueryRewriteContext; import org.opensearch.index.query.DateRangeIncludingNowQuery; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; @@ -99,7 +100,7 @@ public class DateFieldTypeTests extends FieldTypeTestCase { private static final long nowInMillis = 0; public void testIsFieldWithinRangeEmptyReader() throws IOException { - QueryRewriteContext context = new QueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); + QueryRewriteContext context = new BaseQueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); IndexReader reader = new MultiReader(); DateFieldType ft = new DateFieldType("my_date"); assertEquals( @@ -136,7 +137,7 @@ public void isFieldWithinRangeTestCase(DateFieldType ft) throws IOException { doTestIsFieldWithinQuery(ft, reader, DateTimeZone.UTC, null); doTestIsFieldWithinQuery(ft, reader, DateTimeZone.UTC, alternateFormat); - QueryRewriteContext context = new QueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); + QueryRewriteContext context = new BaseQueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); // Fields with no value indexed. DateFieldType ft2 = new DateFieldType("my_date2"); @@ -148,7 +149,7 @@ public void isFieldWithinRangeTestCase(DateFieldType ft) throws IOException { private void doTestIsFieldWithinQuery(DateFieldType ft, DirectoryReader reader, DateTimeZone zone, DateMathParser alternateFormat) throws IOException { - QueryRewriteContext context = new QueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); + QueryRewriteContext context = new BaseQueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis); assertEquals( Relation.INTERSECTS, ft.isFieldWithinQuery(reader, "2015-10-09", "2016-01-02", randomBoolean(), randomBoolean(), null, null, context) diff --git a/server/src/test/java/org/opensearch/index/query/RewriteableTests.java b/server/src/test/java/org/opensearch/index/query/RewriteableTests.java index 6385a57f9f370..6e58023ecc7e2 100644 --- a/server/src/test/java/org/opensearch/index/query/RewriteableTests.java +++ b/server/src/test/java/org/opensearch/index/query/RewriteableTests.java @@ -45,7 +45,7 @@ public class RewriteableTests extends OpenSearchTestCase { public void testRewrite() throws IOException { - QueryRewriteContext context = new QueryRewriteContext(null, null, null, null); + QueryRewriteContext context = new BaseQueryRewriteContext(null, null, null, null); TestRewriteable rewrite = Rewriteable.rewrite( new TestRewriteable(randomIntBetween(0, Rewriteable.MAX_REWRITE_ROUNDS)), context, @@ -65,7 +65,7 @@ public void testRewrite() throws IOException { } public void testRewriteAndFetch() throws ExecutionException, InterruptedException { - QueryRewriteContext context = new QueryRewriteContext(null, null, null, null); + BaseQueryRewriteContext context = new BaseQueryRewriteContext(null, null, null, null); PlainActionFuture future = new PlainActionFuture<>(); Rewriteable.rewriteAndFetch(new TestRewriteable(randomIntBetween(0, Rewriteable.MAX_REWRITE_ROUNDS), true), context, future); TestRewriteable rewrite = future.get(); @@ -83,7 +83,7 @@ public void testRewriteAndFetch() throws ExecutionException, InterruptedExceptio } public void testRewriteList() throws IOException { - QueryRewriteContext context = new QueryRewriteContext(null, null, null, null); + BaseQueryRewriteContext context = new BaseQueryRewriteContext(null, null, null, null); List rewriteableList = new ArrayList<>(); int numInstances = randomIntBetween(1, 10); rewriteableList.add(new TestRewriteable(randomIntBetween(1, Rewriteable.MAX_REWRITE_ROUNDS))); diff --git a/server/src/test/java/org/opensearch/index/query/TemplateQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/TemplateQueryBuilderTests.java new file mode 100644 index 0000000000000..4ea01818ca32e --- /dev/null +++ b/server/src/test/java/org/opensearch/index/query/TemplateQueryBuilderTests.java @@ -0,0 +1,834 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.opensearch.client.Client; +import org.opensearch.common.geo.GeoPoint; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.function.BiConsumer; + +import static org.opensearch.index.query.TemplateQueryBuilder.NAME; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TemplateQueryBuilderTests extends OpenSearchTestCase { + + /** + * Tests the fromXContent method of TemplateQueryBuilder. + * Verifies that a TemplateQueryBuilder can be correctly created from XContent. + */ + public void testFromXContent() throws IOException { + /* + { + "template": { + "term": { + "message": { + "value": "foo" + } + } + } + } + */ + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "foo"); + term.put("message", message); + template.put("term", term); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(template); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + TemplateQueryBuilder templateQueryBuilder = TemplateQueryBuilder.fromXContent(contentParser); + + assertEquals(NAME, templateQueryBuilder.getWriteableName()); + assertEquals(template, templateQueryBuilder.getContent()); + + SearchSourceBuilder source = new SearchSourceBuilder().query(templateQueryBuilder); + assertEquals(source.toString(), "{\"query\":{\"template\":{\"term\":{\"message\":{\"value\":\"foo\"}}}}}"); + } + + /** + * Tests the query source generation of TemplateQueryBuilder. + * Verifies that the correct query source is generated from a TemplateQueryBuilder. + */ + public void testQuerySource() { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "foo"); + term.put("message", message); + template.put("term", term); + QueryBuilder incomingQuery = new TemplateQueryBuilder(template); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + assertEquals(source.toString(), "{\"query\":{\"template\":{\"term\":{\"message\":{\"value\":\"foo\"}}}}}"); + } + + /** + * Tests parsing a TemplateQueryBuilder from a JSON string. + * Verifies that the parsed query matches the expected structure and can be serialized and deserialized. + */ + public void testFromJson() throws IOException { + String jsonString = "{\n" + + " \"geo_shape\": {\n" + + " \"location\": {\n" + + " \"shape\": {\n" + + " \"type\": \"Envelope\",\n" + + " \"coordinates\": \"${modelPredictionOutcome}\"\n" + + " },\n" + + " \"relation\": \"intersects\"\n" + + " },\n" + + " \"ignore_unmapped\": false,\n" + + " \"boost\": 42.0\n" + + " }\n" + + "}"; + + XContentParser parser = XContentType.JSON.xContent() + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, jsonString); + parser.nextToken(); + TemplateQueryBuilder parsed = TemplateQueryBuilder.fromXContent(parser); + + // Check if the parsed query is an instance of TemplateQueryBuilder + assertNotNull(parsed); + assertTrue(parsed instanceof TemplateQueryBuilder); + + // Check if the content of the parsed query matches the expected content + Map expectedContent = new HashMap<>(); + Map geoShape = new HashMap<>(); + Map location = new HashMap<>(); + Map shape = new HashMap<>(); + + shape.put("type", "Envelope"); + shape.put("coordinates", "${modelPredictionOutcome}"); + location.put("shape", shape); + location.put("relation", "intersects"); + geoShape.put("location", location); + geoShape.put("ignore_unmapped", false); + geoShape.put("boost", 42.0); + expectedContent.put("geo_shape", geoShape); + + Map actualContent = new HashMap<>(); + actualContent.put("template", expectedContent); + assertEquals(expectedContent, parsed.getContent()); + + // Test that the query can be serialized and deserialized + BytesStreamOutput out = new BytesStreamOutput(); + parsed.writeTo(out); + StreamInput in = out.bytes().streamInput(); + TemplateQueryBuilder deserializedQuery = new TemplateQueryBuilder(in); + assertEquals(parsed.getContent(), deserializedQuery.getContent()); + + // Test that the query can be converted to XContent + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + parsed.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + Map expectedJson = new HashMap<>(); + Map template = new HashMap<>(); + template.put("geo_shape", geoShape); + expectedJson.put("template", template); + + XContentParser jsonParser = XContentType.JSON.xContent() + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, builder.toString()); + Map actualJson = jsonParser.map(); + + assertEquals(expectedJson, actualJson); + } + + /** + * Tests the constructor and getter methods of TemplateQueryBuilder. + * Verifies that the content and writeable name are correctly set and retrieved. + */ + public void testConstructorAndGetters() { + Map content = new HashMap<>(); + content.put("key", "value"); + TemplateQueryBuilder builder = new TemplateQueryBuilder(content); + + assertEquals(content, builder.getContent()); + assertEquals(NAME, builder.getWriteableName()); + } + + /** + * Tests the equals and hashCode methods of TemplateQueryBuilder. + * Verifies that two builders with the same content are equal and have the same hash code, + * while builders with different content are not equal and have different hash codes. + */ + public void testEqualsAndHashCode() { + Map content1 = new HashMap<>(); + content1.put("key", "value"); + TemplateQueryBuilder builder1 = new TemplateQueryBuilder(content1); + + Map content2 = new HashMap<>(); + content2.put("key", "value"); + TemplateQueryBuilder builder2 = new TemplateQueryBuilder(content2); + + Map content3 = new HashMap<>(); + content3.put("key", "different_value"); + TemplateQueryBuilder builder3 = new TemplateQueryBuilder(content3); + + assertTrue(builder1.equals(builder2)); + assertTrue(builder1.hashCode() == builder2.hashCode()); + assertFalse(builder1.equals(builder3)); + assertFalse(builder1.hashCode() == builder3.hashCode()); + } + + /** + * Tests the doToQuery method of TemplateQueryBuilder. + * Verifies that calling doToQuery throws an IllegalStateException. + */ + public void testDoToQuery() { + Map content = new HashMap<>(); + content.put("key", "value"); + TemplateQueryBuilder builder = new TemplateQueryBuilder(content); + + QueryShardContext mockContext = mock(QueryShardContext.class); + expectThrows(IllegalStateException.class, () -> builder.doToQuery(mockContext)); + } + + /** + * Tests the serialization and deserialization of TemplateQueryBuilder. + * Verifies that a builder can be written to a stream and read back correctly. + */ + public void testStreamRoundTrip() throws IOException { + Map content = new HashMap<>(); + content.put("key", "value"); + TemplateQueryBuilder original = new TemplateQueryBuilder(content); + + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + TemplateQueryBuilder deserialized = new TemplateQueryBuilder(in); + + assertEquals(original, deserialized); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a simple term query. + * Verifies that the template is correctly rewritten to a TermQueryBuilder. + */ + public void testDoRewrite() throws IOException { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "foo"); + term.put("message", message); + template.put("term", term); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + TermQueryBuilder termQueryBuilder = new TermQueryBuilder("message", "foo"); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + TermQueryBuilder newQuery = (TermQueryBuilder) templateQueryBuilder.doRewrite(queryRewriteContext); + + assertEquals(newQuery, termQueryBuilder); + assertEquals( + "{\n" + + " \"term\" : {\n" + + " \"message\" : {\n" + + " \"value\" : \"foo\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a string variable. + * Verifies that the template is correctly rewritten with the variable substituted. + */ + public void testDoRewriteWithString() throws IOException { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "${response}"); + term.put("message", message); + template.put("term", term); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + TermQueryBuilder termQueryBuilder = new TermQueryBuilder("message", "foo"); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + contextVariables.put("response", "foo"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + TermQueryBuilder newQuery = (TermQueryBuilder) templateQueryBuilder.doRewrite(queryRewriteContext); + + assertEquals(newQuery, termQueryBuilder); + assertEquals( + "{\n" + + " \"term\" : {\n" + + " \"message\" : {\n" + + " \"value\" : \"foo\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a list variable. + * Verifies that the template is correctly rewritten with the list variable substituted. + */ + public void testDoRewriteWithList() throws IOException { + ArrayList termsList = new ArrayList<>(); + termsList.add("foo"); + termsList.add("bar"); + + Map template = new HashMap<>(); + Map terms = new HashMap<>(); + + terms.put("message", "${response}"); + template.put("terms", terms); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + TermsQueryBuilder termsQueryBuilder = new TermsQueryBuilder("message", termsList); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + contextVariables.put("response", termsList); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + when(queryRewriteContext.getXContentRegistry()).thenReturn(TEST_XCONTENT_REGISTRY_FOR_QUERY); + TermsQueryBuilder newQuery = (TermsQueryBuilder) templateQueryBuilder.doRewrite(queryRewriteContext); + assertEquals(newQuery, termsQueryBuilder); + assertEquals( + "{\n" + + " \"terms\" : {\n" + + " \"message\" : [\n" + + " \"foo\",\n" + + " \"bar\"\n" + + " ],\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a geo_distance query. + * Verifies that the template is correctly rewritten for a geo_distance query. + */ + public void testDoRewriteWithGeoDistanceQuery() throws IOException { + Map template = new HashMap<>(); + Map geoDistance = new HashMap<>(); + + geoDistance.put("distance", "12km"); + geoDistance.put("pin.location", "${geoPoint}"); + template.put("geo_distance", geoDistance); + + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + GeoPoint geoPoint = new GeoPoint(40, -70); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + Map contextVariables = new HashMap<>(); + contextVariables.put("geoPoint", geoPoint); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + GeoDistanceQueryBuilder expectedQuery = new GeoDistanceQueryBuilder("pin.location"); + expectedQuery.point(geoPoint).distance("12km"); + + QueryBuilder newQuery = templateQueryBuilder.doRewrite(queryRewriteContext); + assertEquals(expectedQuery, newQuery); + assertEquals( + "{\n" + + " \"geo_distance\" : {\n" + + " \"pin.location\" : [\n" + + " -70.0,\n" + + " 40.0\n" + + " ],\n" + + " \"distance\" : 12000.0,\n" + + " \"distance_type\" : \"arc\",\n" + + " \"validation_method\" : \"STRICT\",\n" + + " \"ignore_unmapped\" : false,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a range query. + * Verifies that the template is correctly rewritten for a range query. + */ + public void testDoRewriteWithRangeQuery() throws IOException { + Map template = new HashMap<>(); + Map range = new HashMap<>(); + Map age = new HashMap<>(); + + age.put("gte", "${minAge}"); + age.put("lte", "${maxAge}"); + range.put("age", age); + template.put("range", range); + + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + Map contextVariables = new HashMap<>(); + contextVariables.put("minAge", 25); + contextVariables.put("maxAge", 35); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + RangeQueryBuilder expectedQuery = new RangeQueryBuilder("age"); + expectedQuery.gte(25).lte(35); + + QueryBuilder newQuery = templateQueryBuilder.doRewrite(queryRewriteContext); + assertEquals(expectedQuery, newQuery); + assertEquals( + "{\n" + + " \"range\" : {\n" + + " \"age\" : {\n" + + " \"from\" : 25,\n" + + " \"to\" : 35,\n" + + " \"include_lower\" : true,\n" + + " \"include_upper\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a nested map variable. + * Verifies that the template is correctly rewritten with the nested map variable substituted. + */ + public void testDoRewriteWithNestedMap() throws IOException { + Map template = new HashMap<>(); + Map bool = new HashMap<>(); + List> must = new ArrayList<>(); + Map match = new HashMap<>(); + Map textEntry = new HashMap<>(); + + textEntry.put("text_entry", "${keyword}"); + match.put("match", textEntry); + must.add(match); + bool.put("must", must); + + List> should = new ArrayList<>(); + Map shouldMatch1 = new HashMap<>(); + Map shouldTextEntry1 = new HashMap<>(); + shouldTextEntry1.put("text_entry", "life"); + shouldMatch1.put("match", shouldTextEntry1); + should.add(shouldMatch1); + + Map shouldMatch2 = new HashMap<>(); + Map shouldTextEntry2 = new HashMap<>(); + shouldTextEntry2.put("text_entry", "grace"); + shouldMatch2.put("match", shouldTextEntry2); + should.add(shouldMatch2); + + bool.put("should", should); + bool.put("minimum_should_match", 1); + + Map filter = new HashMap<>(); + Map term = new HashMap<>(); + term.put("play_name", "Romeo and Juliet"); + filter.put("term", term); + bool.put("filter", filter); + + template.put("bool", bool); + + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + Map contextVariables = new HashMap<>(); + contextVariables.put("keyword", "love"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + BoolQueryBuilder expectedQuery = new BoolQueryBuilder().must(new MatchQueryBuilder("text_entry", "love")) + .should(new MatchQueryBuilder("text_entry", "life")) + .should(new MatchQueryBuilder("text_entry", "grace")) + .filter(new TermQueryBuilder("play_name", "Romeo and Juliet")) + .minimumShouldMatch(1); + + QueryBuilder newQuery = templateQueryBuilder.doRewrite(queryRewriteContext); + assertEquals(expectedQuery, newQuery); + assertEquals( + "{\n" + + " \"bool\" : {\n" + + " \"must\" : [\n" + + " {\n" + + " \"match\" : {\n" + + " \"text_entry\" : {\n" + + " \"query\" : \"love\",\n" + + " \"operator\" : \"OR\",\n" + + " \"prefix_length\" : 0,\n" + + " \"max_expansions\" : 50,\n" + + " \"fuzzy_transpositions\" : true,\n" + + " \"lenient\" : false,\n" + + " \"zero_terms_query\" : \"NONE\",\n" + + " \"auto_generate_synonyms_phrase_query\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"filter\" : [\n" + + " {\n" + + " \"term\" : {\n" + + " \"play_name\" : {\n" + + " \"value\" : \"Romeo and Juliet\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"should\" : [\n" + + " {\n" + + " \"match\" : {\n" + + " \"text_entry\" : {\n" + + " \"query\" : \"life\",\n" + + " \"operator\" : \"OR\",\n" + + " \"prefix_length\" : 0,\n" + + " \"max_expansions\" : 50,\n" + + " \"fuzzy_transpositions\" : true,\n" + + " \"lenient\" : false,\n" + + " \"zero_terms_query\" : \"NONE\",\n" + + " \"auto_generate_synonyms_phrase_query\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"match\" : {\n" + + " \"text_entry\" : {\n" + + " \"query\" : \"grace\",\n" + + " \"operator\" : \"OR\",\n" + + " \"prefix_length\" : 0,\n" + + " \"max_expansions\" : 50,\n" + + " \"fuzzy_transpositions\" : true,\n" + + " \"lenient\" : false,\n" + + " \"zero_terms_query\" : \"NONE\",\n" + + " \"auto_generate_synonyms_phrase_query\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"adjust_pure_negative\" : true,\n" + + " \"minimum_should_match\" : \"1\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Tests the doRewrite method with an invalid query type. + * Verifies that an IOException is thrown when an invalid query type is used. + */ + public void testDoRewriteWithInvalidQueryType() throws IOException { + Map template = new HashMap<>(); + template.put("invalid_query_type", new HashMap<>()); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + when(queryRewriteContext.getContextVariables()).thenReturn(new HashMap<>()); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Failed to rewrite template query")); + } + + /** + * Tests the doRewrite method with a malformed JSON query. + * Verifies that an IOException is thrown when the query JSON is malformed. + */ + public void testDoRewriteWithMalformedJson() throws IOException { + Map template = new HashMap<>(); + template.put("malformed_json", "{ this is not valid JSON }"); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + when(queryRewriteContext.getContextVariables()).thenReturn(new HashMap<>()); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Failed to rewrite template query")); + } + + /** + * Tests the doRewrite method with an invalid matchall query. + * Verifies that an IOException is thrown when an invalid matchall query is used. + */ + public void testDoRewriteWithInvalidMatchAllQuery() throws IOException { + Map template = new HashMap<>(); + template.put("matchall_1", new HashMap<>()); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + when(queryRewriteContext.getContextVariables()).thenReturn(new HashMap<>()); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Failed to rewrite template query")); + } + + /** + * Tests the doRewrite method with a missing required field in a query. + * Verifies that an IOException is thrown when a required field is missing. + */ + public void testDoRewriteWithMissingRequiredField() throws IOException { + Map template = new HashMap<>(); + template.put("term", "value");// Missing the required field for term query + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + when(queryRewriteContext.getContextVariables()).thenReturn(new HashMap<>()); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Failed to rewrite template query")); + } + + /** + * Tests the doRewrite method with a malformed variable substitution. + * Verifies that an IOException is thrown when a malformed variable is used. + */ + public void testDoRewriteWithMalformedVariableSubstitution() throws IOException { + + Map template = new HashMap<>(); + Map terms = new HashMap<>(); + + terms.put("message", "${response}"); + template.put("terms", terms); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + contextVariables.put("response", "should be a list but this is a string"); + + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + + assertTrue(exception.getMessage().contains("Failed to rewrite template query")); + } + + /** + * Tests the doRewrite method with a variable not found. + * Verifies that an IOException is thrown when a malformed variable is used. + */ + public void testDoRewriteWithNotFoundVariableSubstitution() throws IOException { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "${response}"); + term.put("message", message); + template.put("term", term); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + contextVariables.put("response1", "foo"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Variable not found")); + } + + /** + * Tests the doRewrite method of TemplateQueryBuilder with a missing bracket variable. + * Verifies that the exception is thrown + */ + public void testDoRewriteWithMissingBracketVariable() throws IOException { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "${response"); + term.put("message", message); + template.put("term", term); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + Map contextVariables = new HashMap<>(); + contextVariables.put("response", "foo"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertTrue(exception.getMessage().contains("Unclosed variable in template")); + } + + /** + * Tests the replaceVariables method when the template is null. + * Verifies that an IllegalArgumentException is thrown with the appropriate error message. + */ + + public void testReplaceVariablesWithNullTemplate() { + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder((Map) null); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + Map contextVariables = new HashMap<>(); + contextVariables.put("response", "foo"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertEquals("Template string cannot be null. A valid template must be provided.", exception.getMessage()); + } + + /** + * Tests the replaceVariables method when the template is empty. + * Verifies that an IllegalArgumentException is thrown with the appropriate error message. + */ + + public void testReplaceVariablesWithEmptyTemplate() { + Map template = new HashMap<>(); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + Map contextVariables = new HashMap<>(); + contextVariables.put("response", "foo"); + when(queryRewriteContext.getContextVariables()).thenReturn(contextVariables); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> templateQueryBuilder.doRewrite(queryRewriteContext) + ); + assertEquals("Template string cannot be empty. A valid template must be provided.", exception.getMessage()); + + } + + /** + * Tests the replaceVariables method when the variables map is null. + * Verifies that the method returns the original template unchanged, + * since a null variables map is treated as no replacement. + */ + public void testReplaceVariablesWithNullVariables() throws IOException { + + Map template = new HashMap<>(); + Map term = new HashMap<>(); + Map message = new HashMap<>(); + + message.put("value", "foo"); + term.put("message", message); + template.put("term", term); + TemplateQueryBuilder templateQueryBuilder = new TemplateQueryBuilder(template); + TermQueryBuilder termQueryBuilder = new TermQueryBuilder("message", "foo"); + + QueryCoordinatorContext queryRewriteContext = mockQueryRewriteContext(); + + when(queryRewriteContext.getContextVariables()).thenReturn(null); + + TermQueryBuilder newQuery = (TermQueryBuilder) templateQueryBuilder.doRewrite(queryRewriteContext); + + assertEquals(newQuery, termQueryBuilder); + assertEquals( + "{\n" + + " \"term\" : {\n" + + " \"message\" : {\n" + + " \"value\" : \"foo\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + newQuery.toString() + ); + } + + /** + * Helper method to create a mock QueryCoordinatorContext for testing. + */ + private QueryCoordinatorContext mockQueryRewriteContext() { + QueryCoordinatorContext queryRewriteContext = mock(QueryCoordinatorContext.class); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + when(queryRewriteContext.getXContentRegistry()).thenReturn(TEST_XCONTENT_REGISTRY_FOR_QUERY); + + return queryRewriteContext; + } +} diff --git a/server/src/test/java/org/opensearch/search/SearchModuleTests.java b/server/src/test/java/org/opensearch/search/SearchModuleTests.java index 81b7ca8aef30b..6514e06fdf49a 100644 --- a/server/src/test/java/org/opensearch/search/SearchModuleTests.java +++ b/server/src/test/java/org/opensearch/search/SearchModuleTests.java @@ -615,7 +615,8 @@ public Optional create(IndexSettings indexSettin "terms_set", "wildcard", "wrapper", - "distance_feature" }; + "distance_feature", + "template" }; // add here deprecated queries to make sure we log a deprecation warnings when they are used private static final String[] DEPRECATED_QUERIES = new String[] { "common", "field_masking_span" }; diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregatorFactoriesTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregatorFactoriesTests.java index c930d27b068f8..a5724d3c34352 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/AggregatorFactoriesTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregatorFactoriesTests.java @@ -45,6 +45,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.env.Environment; +import org.opensearch.index.query.BaseQueryRewriteContext; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; @@ -255,7 +256,7 @@ public void testRewriteAggregation() throws Exception { BucketScriptPipelineAggregationBuilder pipelineAgg = new BucketScriptPipelineAggregationBuilder("const", new Script("1")); AggregatorFactories.Builder builder = new AggregatorFactories.Builder().addAggregator(filterAggBuilder) .addPipelineAggregator(pipelineAgg); - AggregatorFactories.Builder rewritten = builder.rewrite(new QueryRewriteContext(xContentRegistry, null, null, () -> 0L)); + AggregatorFactories.Builder rewritten = builder.rewrite(new BaseQueryRewriteContext(xContentRegistry, null, null, () -> 0L)); assertNotSame(builder, rewritten); Collection aggregatorFactories = rewritten.getAggregatorFactories(); assertEquals(1, aggregatorFactories.size()); @@ -268,7 +269,9 @@ public void testRewriteAggregation() throws Exception { assertThat(rewrittenFilter, instanceOf(TermsQueryBuilder.class)); // Check that a further rewrite returns the same aggregation factories builder - AggregatorFactories.Builder secondRewritten = rewritten.rewrite(new QueryRewriteContext(xContentRegistry, null, null, () -> 0L)); + AggregatorFactories.Builder secondRewritten = rewritten.rewrite( + new BaseQueryRewriteContext(xContentRegistry, null, null, () -> 0L) + ); assertSame(rewritten, secondRewritten); } @@ -277,7 +280,7 @@ public void testRewritePipelineAggregationUnderAggregation() throws Exception { new RewrittenPipelineAggregationBuilder() ); AggregatorFactories.Builder builder = new AggregatorFactories.Builder().addAggregator(filterAggBuilder); - QueryRewriteContext context = new QueryRewriteContext(xContentRegistry, null, null, () -> 0L); + QueryRewriteContext context = new BaseQueryRewriteContext(xContentRegistry, null, null, () -> 0L); AggregatorFactories.Builder rewritten = builder.rewrite(context); CountDownLatch latch = new CountDownLatch(1); context.executeAsyncActions(new ActionListener() { @@ -304,7 +307,7 @@ public void testRewriteAggregationAtTopLevel() throws Exception { FilterAggregationBuilder filterAggBuilder = new FilterAggregationBuilder("titles", new MatchAllQueryBuilder()); AggregatorFactories.Builder builder = new AggregatorFactories.Builder().addAggregator(filterAggBuilder) .addPipelineAggregator(new RewrittenPipelineAggregationBuilder()); - QueryRewriteContext context = new QueryRewriteContext(xContentRegistry, null, null, () -> 0L); + QueryRewriteContext context = new BaseQueryRewriteContext(xContentRegistry, null, null, () -> 0L); AggregatorFactories.Builder rewritten = builder.rewrite(context); CountDownLatch latch = new CountDownLatch(1); context.executeAsyncActions(new ActionListener() { diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/FiltersTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/FiltersTests.java index 56f7f450dbdfb..770f18f781689 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/FiltersTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/FiltersTests.java @@ -36,12 +36,12 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BaseQueryRewriteContext; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.BaseAggregationTestCase; import org.opensearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; @@ -147,12 +147,12 @@ public void testRewrite() throws IOException { // test non-keyed filter that doesn't rewrite AggregationBuilder original = new FiltersAggregationBuilder("my-agg", new MatchAllQueryBuilder()); original.setMetadata(Collections.singletonMap(randomAlphaOfLengthBetween(1, 20), randomAlphaOfLengthBetween(1, 20))); - AggregationBuilder rewritten = original.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + AggregationBuilder rewritten = original.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertSame(original, rewritten); // test non-keyed filter that does rewrite original = new FiltersAggregationBuilder("my-agg", new BoolQueryBuilder()); - rewritten = original.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + rewritten = original.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertNotSame(original, rewritten); assertThat(rewritten, instanceOf(FiltersAggregationBuilder.class)); assertEquals("my-agg", ((FiltersAggregationBuilder) rewritten).getName()); @@ -163,12 +163,12 @@ public void testRewrite() throws IOException { // test keyed filter that doesn't rewrite original = new FiltersAggregationBuilder("my-agg", new KeyedFilter("my-filter", new MatchAllQueryBuilder())); - rewritten = original.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + rewritten = original.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertSame(original, rewritten); // test non-keyed filter that does rewrite original = new FiltersAggregationBuilder("my-agg", new KeyedFilter("my-filter", new BoolQueryBuilder())); - rewritten = original.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + rewritten = original.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertNotSame(original, rewritten); assertThat(rewritten, instanceOf(FiltersAggregationBuilder.class)); assertEquals("my-agg", ((FiltersAggregationBuilder) rewritten).getName()); @@ -180,7 +180,7 @@ public void testRewrite() throws IOException { // test sub-agg filter that does rewrite original = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.BOOLEAN) .subAggregation(new FiltersAggregationBuilder("my-agg", new KeyedFilter("my-filter", new BoolQueryBuilder()))); - rewritten = original.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + rewritten = original.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertNotSame(original, rewritten); assertNotEquals(original, rewritten); assertThat(rewritten, instanceOf(TermsAggregationBuilder.class)); @@ -189,7 +189,7 @@ public void testRewrite() throws IOException { assertThat(subAgg, instanceOf(FiltersAggregationBuilder.class)); assertNotSame(original.getSubAggregations().iterator().next(), subAgg); assertEquals("my-agg", subAgg.getName()); - assertSame(rewritten, rewritten.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L))); + assertSame(rewritten, rewritten.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L))); } public void testRewritePreservesOtherBucket() throws IOException { @@ -197,7 +197,7 @@ public void testRewritePreservesOtherBucket() throws IOException { originalFilters.otherBucket(randomBoolean()); originalFilters.otherBucketKey(randomAlphaOfLength(10)); - AggregationBuilder rewritten = originalFilters.rewrite(new QueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); + AggregationBuilder rewritten = originalFilters.rewrite(new BaseQueryRewriteContext(xContentRegistry(), null, null, () -> 0L)); assertThat(rewritten, instanceOf(FiltersAggregationBuilder.class)); FiltersAggregationBuilder rewrittenFilters = (FiltersAggregationBuilder) rewritten; diff --git a/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java index 90962a5c613f1..4ee1ee61d9586 100644 --- a/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java @@ -47,10 +47,10 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BaseQueryRewriteContext; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.RandomQueryBuilder; import org.opensearch.index.query.Rewriteable; import org.opensearch.script.Script; @@ -737,7 +737,7 @@ private void assertIndicesBoostParseErrorMessage(String restContent, String expe private SearchSourceBuilder rewrite(SearchSourceBuilder searchSourceBuilder) throws IOException { return Rewriteable.rewrite( searchSourceBuilder, - new QueryRewriteContext(xContentRegistry(), writableRegistry(), null, Long.valueOf(1)::longValue) + new BaseQueryRewriteContext(xContentRegistry(), writableRegistry(), null, Long.valueOf(1)::longValue) ); } }