diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index 89cca22e1a..c4fb8a0a5b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -12,9 +12,6 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.bulk.BulkResponse; -import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateRequest; @@ -28,8 +25,12 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; @@ -221,22 +222,10 @@ public void getTraces(String parentInteractionId, ActionListener> listener) { SearchRequest searchRequest = Requests.searchRequest(indexName); - - // Build the query - BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - - // Add the ExistsQueryBuilder for checking null values - ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); - boolQueryBuilder.must(existsQueryBuilder); - - // Add the TermQueryBuilder for another field - TermQueryBuilder termQueryBuilder = QueryBuilders - .termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId); - boolQueryBuilder.must(termQueryBuilder); - + QueryBuilder traceQueryBuilder = buildTraceQueryBuilder(parentInteractionId); // Set the query to the search source SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(boolQueryBuilder); + searchSourceBuilder.query(traceQueryBuilder); searchRequest.source(searchSourceBuilder); searchRequest.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC); @@ -260,6 +249,21 @@ void innerGetTraces(String parentInteractionId, ActionListener } } + private QueryBuilder buildTraceQueryBuilder(String parentInteractionId) { + // Build the query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + + // Add the ExistsQueryBuilder for checking null values + ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD); + boolQueryBuilder.must(existsQueryBuilder); + + // Add the TermQueryBuilder for another field + TermQueryBuilder termQueryBuilder = QueryBuilders + .termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId); + boolQueryBuilder.must(termQueryBuilder); + return boolQueryBuilder; + } + /** * Get the interactions associate with this conversation, sorted by recency * @param interactionId the parent interaction id whose traces to get @@ -296,24 +300,23 @@ public void updateInteraction(String interactionId, Map updateCo * @param listener callback for delete result */ public void deleteInteraction(String interactionId, ActionListener listener) { - BulkRequest bulkRequest = new BulkRequest(indexName); - bulkRequest.add(new DeleteRequest(indexName, interactionId)); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + // interaction itself + boolQueryBuilder.should(QueryBuilders.idsQuery().addIds(interactionId)); + // interaction trace + boolQueryBuilder.should(buildTraceQueryBuilder(interactionId)); - innerGetTraces(interactionId, ActionListener.wrap(traces -> { - traces.forEach(trace -> bulkRequest.add(new DeleteRequest(indexName, trace.getId()))); + DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(indexName); + deleteByQueryRequest.setQuery(boolQueryBuilder); - innerDeleteInteraction(bulkRequest, interactionId, listener); - }, e -> { - // delete interaction only if we can't get trace - innerDeleteInteraction(bulkRequest, interactionId, listener); - })); + innerDeleteInteraction(deleteByQueryRequest, interactionId, listener); } @VisibleForTesting - void innerDeleteInteraction(BulkRequest bulkRequest, String interactionId, ActionListener listener) { + void innerDeleteInteraction(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener listener) { try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) { - ActionListener al = ActionListener.wrap(bulkResponse -> { - if (bulkResponse != null && bulkResponse.hasFailures()) { + ActionListener al = ActionListener.wrap(bulkResponse -> { + if (bulkResponse != null && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty())) { log.info("Failed to delete the interaction with ID: {}", interactionId); listener.onResponse(false); return; @@ -325,7 +328,7 @@ void innerDeleteInteraction(BulkRequest bulkRequest, String interactionId, Actio listener.onFailure(exception); }); // bulk delete interaction and its trace - client.bulk(bulkRequest, al); + client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, al); } catch (Exception e) { log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e); listener.onFailure(e);