diff --git a/client/build.gradle b/client/build.gradle index 8bca27d81b..06af14434a 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -15,7 +15,7 @@ plugins { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' @@ -122,4 +122,5 @@ publishing { } } - +compileJava.dependsOn(':opensearch-ml-common:shadowJar') +delombok.dependsOn(':opensearch-ml-common:shadowJar') diff --git a/common/build.gradle b/common/build.gradle index c8be258615..12a8bf6269 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -6,8 +6,11 @@ //TODO: cleanup gradle config file, some overlap plugins { id 'java' + id 'com.github.johnrengelman.shadow' id 'jacoco' id "io.freefair.lombok" + id 'maven-publish' + id 'signing' } dependencies { @@ -21,6 +24,15 @@ dependencies { compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' compileOnly group: 'org.json', name: 'json', version: '20231013' + + implementation('com.google.guava:guava:32.1.2-jre') { + exclude group: 'com.google.guava', module: 'failureaccess' + exclude group: 'com.google.code.findbugs', module: 'jsr305' + exclude group: 'org.checkerframework', module: 'checker-qual' + exclude group: 'com.google.errorprone', module: 'error_prone_annotations' + exclude group: 'com.google.j2objc', module: 'j2objc-annotations' + exclude group: 'com.google.guava', module: 'listenablefuture' + } } lombok { @@ -53,3 +65,75 @@ jacocoTestCoverageVerification { dependsOn jacocoTestReport } check.dependsOn jacocoTestCoverageVerification + +shadowJar { + destinationDirectory = file("${project.buildDir}/distributions") + archiveClassifier.set(null) + exclude 'META-INF/maven/com.google.guava/**' + exclude 'com/google/thirdparty/**' + relocate 'com.google.common', 'org.opensearch.ml.repackage.com.google.common' // dependency of cron-utils +} + +jar { + enabled false +} + +task sourcesJar(type: Jar) { + archiveClassifier.set 'sources' + from sourceSets.main.allJava +} + +task javadocJar(type: Jar) { + archiveClassifier.set 'javadoc' + from javadoc.destinationDir + dependsOn javadoc +} + +publishing { + repositories { + maven { + name = 'staging' + url = "${rootProject.buildDir}/local-staging-repo" + } + maven { + name = "Snapshots" // optional target repository name + url = "https://aws.oss.sonatype.org/content/repositories/snapshots" + credentials { + username "$System.env.SONATYPE_USERNAME" + password "$System.env.SONATYPE_PASSWORD" + } + } + } + publications { + shadow(MavenPublication) { publication -> + project.shadow.component(publication) + artifact sourcesJar + artifact javadocJar + + pom { + name = "OpenSearch ML Commons Comm" + packaging = "jar" + url = "https://github.com/opensearch-project/ml-commons" + description = "OpenSearch ML Common" + scm { + connection = "scm:git@github.com:opensearch-project/ml-commons.git" + developerConnection = "scm:git@github.com:opensearch-project/ml-commons.git" + url = "git@github.com:opensearch-project/ml-commons.git" + } + licenses { + license { + name = "The Apache License, Version 2.0" + url = "http://www.apache.org/licenses/LICENSE-2.0.txt" + } + } + developers { + developer { + name = "OpenSearch" + url = "https://github.com/opensearch-project/ml-commons" + } + } + } + } + } +} +publishShadowPublicationToMavenLocal.mustRunAfter shadowJar diff --git a/memory/build.gradle b/memory/build.gradle index 83eb747c7e..eb9763b272 100644 --- a/memory/build.gradle +++ b/memory/build.gradle @@ -24,7 +24,7 @@ plugins { } dependencies { - implementation project(":opensearch-ml-common") + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1' implementation "org.opensearch:common-utils:${common_utils_version}" diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 049d142fd6..99577b8229 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -17,10 +17,10 @@ repositories { } dependencies { - compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation project(':opensearch-ml-memory') + compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" @@ -103,6 +103,7 @@ jacocoTestCoverageVerification { dependsOn jacocoTestReport } check.dependsOn jacocoTestCoverageVerification +compileJava.dependsOn(':opensearch-ml-common:shadowJar') spotless { java { diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java similarity index 67% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java index b81682f07e..671f4e548a 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java @@ -3,14 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -26,7 +35,10 @@ public enum MLIndex { MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), - CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION); + CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION), + AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION), + MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING, ML_MEMORY_META_INDEX_SCHEMA_VERSION), + MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING, ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java similarity index 94% rename from plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index d278fa6415..ca5f88be78 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import static org.opensearch.ml.common.CommonValue.META; import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; @@ -62,10 +62,22 @@ public void initMLConnectorIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); } + public void initMemoryMetaIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_META, listener); + } + + public void initMemoryMessageIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MEMORY_MESSAGE, listener); + } + public void initMLConfigIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONFIG, listener); } + public void initMLAgentIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.AGENT, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java similarity index 83% rename from plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java index 1dcf6bdf77..452f836357 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.indices; +package org.opensearch.ml.engine.indices; import java.util.ArrayList; import java.util.List; @@ -35,19 +35,6 @@ public class MLInputDatasetHandler { Client client; - // /** - // * Retrieve DataFrame from DataFrameInputDataset - // * @param mlInputDataset MLInputDataset - // * @return DataFrame - // */ - // public DataFrame parseDataFrameInput(MLInputDataset mlInputDataset) { - // if (!mlInputDataset.getInputDataType().equals(MLInputDataType.DATA_FRAME)) { - // throw new IllegalArgumentException("Input dataset is not DATA_FRAME type."); - // } - // DataFrameInputDataset inputDataset = (DataFrameInputDataset) mlInputDataset; - // return inputDataset.getDataFrame(); - // } - /** * Create DataFrame based on given search query * @param mlInputDataset MLInputDataset diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java new file mode 100644 index 0000000000..05b3185a34 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.memory.Message; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +public class BaseMessage implements Message, ToXContentObject { + + @Getter + @Setter + protected String type; + @Getter + @Setter + protected String content; + + @Builder + public BaseMessage(String type, String content) { + this.type = type; + this.content = content; + } + + @Override + public String toString() { + return type + ": " + content; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("type", type); + builder.field("content", content); + builder.endObject(); + return builder; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java new file mode 100644 index 0000000000..8dcbe050bb --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -0,0 +1,207 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; + +import java.util.Map; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@Getter +public class ConversationIndexMemory implements Memory { + public static final String TYPE = "conversation_index"; + public static final String CONVERSATION_ID = "conversation_id"; + public static final String FINAL_ANSWER = "final_answer"; + public static final String CREATED_TIME = "created_time"; + public static final String MEMORY_NAME = "memory_name"; + public static final String MEMORY_ID = "memory_id"; + public static final String APP_TYPE = "app_type"; + public static int LAST_N_INTERACTIONS = 10; + protected String memoryMetaIndexName; + protected String memoryMessageIndexName; + protected String conversationId; + protected boolean retrieveFinalAnswer = true; + protected final Client client; + private final MLIndicesHandler mlIndicesHandler; + private MLMemoryManager memoryManager; + + public ConversationIndexMemory( + Client client, + MLIndicesHandler mlIndicesHandler, + String memoryMetaIndexName, + String memoryMessageIndexName, + String conversationId, + MLMemoryManager memoryManager + ) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.memoryMetaIndexName = memoryMetaIndexName; + this.memoryMessageIndexName = memoryMessageIndexName; + this.conversationId = conversationId; + this.memoryManager = memoryManager; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public void save(String id, Message message) { + this.save(id, message, ActionListener.wrap(r -> { log.info("saved message into {} memory, session id: {}", TYPE, id); }, e -> { + log.error("Failed to save message to memory", e); + })); + } + + @Override + public void save(String id, Message message, ActionListener listener) { + mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { + if (created) { + IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName); + ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + indexRequest.source(builder); + client.index(indexRequest, listener); + } else { + listener.onFailure(new RuntimeException("Failed to create memory message index")); + } + }, e -> { listener.onFailure(new RuntimeException("Failed to create memory message index", e)); })); + } + + public void save(Message message, String parentId, Integer traceNum, String action) { + this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { + log + .info( + "saved message into memory {}, parent id: {}, trace number: {}, interaction id: {}", + conversationId, + parentId, + traceNum, + r.getId() + ); + }, e -> { log.error("Failed to save interaction", e); })); + } + + public void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) { + ConversationIndexMessage msg = (ConversationIndexMessage) message; + memoryManager + .createInteraction(conversationId, msg.getQuestion(), null, msg.getResponse(), action, null, parentId, traceNum, listener); + } + + @Override + public void getMessages(String id, ActionListener listener) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(memoryMessageIndexName); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.size(10000); + QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(CONVERSATION_ID, id); + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must(sessionIdQueryBuilder); + + if (retrieveFinalAnswer) { + QueryBuilder finalAnswerQueryBuilder = new TermQueryBuilder(FINAL_ANSWER, true); + boolQueryBuilder.must(finalAnswerQueryBuilder); + } + + sourceBuilder.query(boolQueryBuilder); + sourceBuilder.sort(CREATED_TIME, SortOrder.ASC); + searchRequest.source(sourceBuilder); + client.search(searchRequest, listener); + } + + public void getMessages(ActionListener listener) { + memoryManager.getFinalInteractions(conversationId, LAST_N_INTERACTIONS, listener); + } + + @Override + public void clear() { + throw new RuntimeException("clear method is not supported in ConversationIndexMemory"); + } + + @Override + public void remove(String id) { + throw new RuntimeException("remove method is not supported in ConversationIndexMemory"); + } + + public static class Factory implements Memory.Factory { + private Client client; + private MLIndicesHandler mlIndicesHandler; + private String memoryMetaIndexName = ML_MEMORY_META_INDEX; + private String memoryMessageIndexName = ML_MEMORY_MESSAGE_INDEX; + private MLMemoryManager memoryManager; + + public void init(Client client, MLIndicesHandler mlIndicesHandler, MLMemoryManager memoryManager) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.memoryManager = memoryManager; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map == null || map.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating ConversationIndexMemory")); + return; + } + + String memoryId = (String) map.get(MEMORY_ID); + String name = (String) map.get(MEMORY_NAME); + String appType = (String) map.get(APP_TYPE); + create(name, memoryId, appType, listener); + } + + public void create(String name, String memoryId, String appType, ActionListener listener) { + if (Strings.isEmpty(memoryId)) { + memoryManager.createConversation(name, appType, ActionListener.wrap(r -> { + create(r.getId(), listener); + log.debug("Created conversation on memory layer, conversation id: {}", r.getId()); + }, e -> { + log.error("Failed to save interaction", e); + listener.onFailure(e); + })); + } else { + create(memoryId, listener); + } + } + + public void create(String memoryId, ActionListener listener) { + listener + .onResponse( + new ConversationIndexMemory( + client, + mlIndicesHandler, + memoryMetaIndexName, + memoryMessageIndexName, + memoryId, + memoryManager + ) + ); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java new file mode 100644 index 0000000000..2a084ee9b9 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Data; + +@Data +public class ConversationIndexMessage extends BaseMessage { + + private String sessionId; + private String question; + private String response; + private Boolean finalAnswer; + private Instant createdTime; + + @Builder(builderMethodName = "conversationIndexMessageBuilder") + public ConversationIndexMessage(String type, String sessionId, String question, String response, boolean finalAnswer) { + super(type, response); + this.sessionId = sessionId; + this.question = question; + this.response = response; + this.finalAnswer = finalAnswer; + this.createdTime = Instant.now(); + } + + @Override + public String toString() { + return "Human:" + question + "\nAI:" + response; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (sessionId != null) { + builder.field("session_id", sessionId); + } + if (question != null) { + builder.field("question", question); + } + if (response != null) { + builder.field("response", response); + } + if (finalAnswer != null) { + builder.field("final_answer", finalAnswer); + } + builder.field("created_time", createdTime); + builder.endObject(); + return builder; + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java new file mode 100644 index 0000000000..5ca7e2d31a --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java @@ -0,0 +1,194 @@ +package org.opensearch.ml.engine.indices; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.META; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; +import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; + +public class MLIndicesHandlerTest { + + @Mock + Client client; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + IndexMetadata indexMetadata; + + @Mock + MappingMetadata mappingMetadata; + + @Mock + private ThreadPool threadPool; + + Settings settings; + ThreadContext threadContext; + MLIndicesHandler indicesHandler; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + doNothing().when(client).execute(any(), any(), any()); + doNothing().when(client).update(any(), any()); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doNothing().when(indicesAdminClient).create(any(), any()); + doNothing().when(indicesAdminClient).refresh(any(), any()); + doNothing().when(indicesAdminClient).putMapping(any(), any()); + doNothing().when(indicesAdminClient).updateSettings(any(), any()); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + when(metadata.indices()).thenReturn(Map.of(ML_AGENT_INDEX, indexMetadata, ML_MEMORY_META_INDEX, indexMetadata)); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + when(mappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(1)))); + settings = Settings.builder().put("test_key", 10).build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + indicesHandler = new MLIndicesHandler(clusterService, client); + } + + @Test + public void initMemoryMetaIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMetaIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMetaIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_MEMORY_META_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMetaIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMessageIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMessageIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMemoryMessageIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_MEMORY_MESSAGE_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMemoryMessageIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMLAgentIndex() { + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).putMapping(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMLAgentIndex(listener); + + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } + + @Test + public void initMLAgentIndexNoIndex() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(new CreateIndexResponse(true, true, ML_AGENT_INDEX)); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMLAgentIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java new file mode 100644 index 0000000000..b66fd502ed --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/BaseMessageTest.java @@ -0,0 +1,29 @@ +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class BaseMessageTest { + + @Test + public void testToString() { + BaseMessage message = new BaseMessage("test", "test"); + Assert.assertEquals("test: test", message.toString()); + } + + @Test + public void toXContent() throws IOException { + BaseMessage baseMessage = new BaseMessage("test", "test"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + baseMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = BytesReference.bytes(builder).utf8ToString(); + + Assert.assertEquals("{\"type\":\"test\",\"content\":\"test\"}", content); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java new file mode 100644 index 0000000000..e186521d9c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java @@ -0,0 +1,248 @@ +package org.opensearch.ml.engine.memory; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; + +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; + +public class ConversationIndexMemoryTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Mock + Client client; + + @Mock + MLIndicesHandler indicesHandler; + + @Mock + MLMemoryManager memoryManager; + + ConversationIndexMemory indexMemory; + ConversationIndexMemory.Factory memoryFactory; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + indexMemory = new ConversationIndexMemory(client, indicesHandler, "test", "test", "test", memoryManager); + doNothing().when(client).index(any(), any()); + doNothing().when(client).search(any(), any()); + doNothing().when(client).get(any(), any()); + doNothing().when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + doNothing().when(memoryManager).getFinalInteractions(any(), anyInt(), any()); + doNothing().when(memoryManager).createConversation(any(), any(), any()); + doNothing().when(indicesHandler).initMemoryMetaIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + memoryFactory = new ConversationIndexMemory.Factory(); + memoryFactory.init(client, indicesHandler, memoryManager); + } + + @Test + public void getType() { + Assert.assertEquals(indexMemory.getType(), ConversationIndexMemory.TYPE); + } + + @Test + public void save() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + + verify(indicesHandler).initMemoryMessageIndex(any()); + } + + @Test + public void save4() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException()); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + + verify(indicesHandler).initMemoryMessageIndex(any()); + } + + @Test + public void save1() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(8); + listener.onResponse(new CreateInteractionResponse("interaction_id")); + return null; + }).when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + indexMemory.save(new ConversationIndexMessage("test", "123", "question", "response", false), "parent_id", 0, "action"); + + verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void save6() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(8); + listener.onFailure(new RuntimeException()); + return null; + }).when(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + indexMemory.save(new ConversationIndexMessage("test", "123", "question", "response", false), "parent_id", 0, "action"); + + verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void save2() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(Boolean.TRUE); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + return null; + }).when(client).index(any(), any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onResponse(isA(IndexResponse.class)); + } + + @Test + public void save3() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException()); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onFailure(isA(RuntimeException.class)); + } + + @Test + public void save5() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(Boolean.FALSE); + return null; + }).when(indicesHandler).initMemoryMessageIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + return null; + }).when(client).index(any(), any()); + ActionListener actionListener = mock(ActionListener.class); + indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + + verify(actionListener).onFailure(isA(RuntimeException.class)); + } + + @Test + public void getMessages() { + ActionListener listener = mock(ActionListener.class); + indexMemory.getMessages("test_id", listener); + } + + @Test + public void getMessages1() { + ActionListener listener = mock(ActionListener.class); + indexMemory.getMessages(listener); + } + + @Test + public void clear() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("clear method is not supported in ConversationIndexMemory"); + indexMemory.clear(); + } + + @Test + public void remove() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); + indexMemory.remove("test_id"); + } + + @Test + public void factory_create_emptyMap() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(), listener); + + verify(listener).onFailure(isA(IllegalArgumentException.class)); + } + + @Test + public void factory_create() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + } + + @Test + public void factory_create_only_memory_id() { + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_ID, "123"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + } + + @Test + public void factory_create_empty_memory_id() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(new CreateConversationResponse("interaction_id")); + return null; + }).when(memoryManager).createConversation(any(), any(), any()); + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onResponse(isA(ConversationIndexMemory.class)); + verify(memoryManager).createConversation(any(), any(), any()); + } + + @Test + public void factory_create_empty_memory_id_failure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(memoryManager).createConversation(any(), any(), any()); + ActionListener listener = mock(ActionListener.class); + memoryFactory.create(Map.of(MEMORY_NAME, "name", APP_TYPE, "app"), listener); + + verify(listener).onFailure(isA(RuntimeException.class)); + verify(memoryManager).createConversation(any(), any(), any()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java new file mode 100644 index 0000000000..9e91695a5b --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMessageTest.java @@ -0,0 +1,48 @@ +package org.opensearch.ml.engine.memory; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class ConversationIndexMessageTest { + + ConversationIndexMessage message; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .type("test") + .sessionId("123") + .question("question") + .response("response") + .finalAnswer(false) + .build(); + } + + @Test + public void testToString() { + Assert.assertEquals("Human:question\nAI:response", message.toString()); + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + message.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = BytesReference.bytes(builder).utf8ToString(); + + Assert.assertTrue(content.contains("\"session_id\":\"123\"")); + Assert.assertTrue(content.contains("\"question\":\"question\"")); + Assert.assertTrue(content.contains("\"response\":\"response\"")); + Assert.assertTrue(content.contains("\"final_answer\":false")); + Assert.assertTrue(content.contains("\"created_time\":")); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java new file mode 100644 index 0000000000..a21a3ed60d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTest.java @@ -0,0 +1,124 @@ +package org.opensearch.ml.engine.memory; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.index.ConversationMetaIndex; +import org.opensearch.threadpool.ThreadPool; + +public class MLMemoryManagerTest { + + @Mock + Client client; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + ConversationMetaIndex conversationMetaIndex; + + @Mock + private ThreadPool threadPool; + + MLMemoryManager memoryManager; + Settings settings; + ThreadContext threadContext; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + memoryManager = new MLMemoryManager(client); + doNothing().when(client).execute(any(), any(), any()); + doNothing().when(client).update(any(), any()); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doNothing().when(indicesAdminClient).refresh(any(), any()); + doNothing().when(conversationMetaIndex).checkAccess(any(), any()); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + settings = Settings.builder().put("test_key", 10).build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + @Test + public void createConversation() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createConversation("test", "test", actionListener); + } + + @Test + public void createInteraction() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createInteraction("test", "test", "test", "test", "test", Map.of("feedback", "1"), "test", 0, actionListener); + } + + @Test + public void createInteractionNullAdditionalInfo() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.createInteraction("test", "test", "test", "test", "test", null, "test", 0, actionListener); + } + + @Test + public void getFinalInteractions() { + ActionListener> actionListener = mock(ActionListener.class); + memoryManager.getFinalInteractions("test", 1, actionListener); + } + + @Test + public void getTracesIndex() { + ActionListener> actionListener = mock(ActionListener.class); + memoryManager.getTraces("test", actionListener); + } + + @Test + public void getTracesNoIndex() { + ActionListener> actionListener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + memoryManager.getTraces("test", actionListener); + } + + @Test + public void updateInteraction() { + ActionListener actionListener = mock(ActionListener.class); + memoryManager.updateInteraction("test", Map.of("feedback", "1"), actionListener); + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 0934a592b0..f8e01391d2 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -47,7 +47,7 @@ opensearchplugin { dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') - implementation project(':opensearch-ml-common') + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') implementation project(':opensearch-ml-algorithms') implementation project(':opensearch-ml-search-processors') implementation project(':opensearch-ml-memory') diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index e40bacc207..4cadcc936a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.exceptions.MetaDataException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java index 94d4b5a8a7..4e29db680d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java @@ -17,8 +17,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 4aea9bcc23..a81e349238 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -50,9 +50,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index 1227703a21..683536e21e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -34,8 +34,8 @@ import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import lombok.extern.log4j.Log4j2; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index c4cdd9f899..0cf4215a23 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -19,7 +19,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 3a5ea83347..12e37b7b4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 83523729e4..2187a4577e 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -34,8 +34,8 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 2402374a99..20286fd3c5 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -111,8 +111,8 @@ import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.utils.FileUtils; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index c738a941ae..45a9f8631b 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -128,10 +128,10 @@ import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 1498b39d07..3e82e7a20e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 348618773a..92e05a5ba9 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -48,7 +48,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index f1feb4ec32..9e9dea5d22 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -42,7 +42,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 9461c3adaf..fe78e88d1e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -29,7 +29,7 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 711a94171f..88366b17f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -37,8 +37,8 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.engine.MLEngine; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 9fcc89d701..e16400bc56 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -41,8 +41,8 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java index 269ac30d95..f54f034b64 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java @@ -29,8 +29,8 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 0222b4efe1..b40a278289 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -54,9 +54,9 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 6fba3efe59..292183f8ed 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -39,8 +39,8 @@ import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 6c189422c0..3299d1a7c9 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -70,7 +70,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java deleted file mode 100644 index 9acb84633a..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.indices; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.META; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX_SCHEMA_VERSION; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_SCHEMA_VERSION; -import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; - -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ExecutionException; - -import org.junit.Before; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.client.AdminClient; -import org.opensearch.client.Client; -import org.opensearch.client.IndicesAdminClient; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.action.ActionListener; -import org.opensearch.test.OpenSearchIntegTestCase; - -@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) -public class MLIndicesHandlerTests extends OpenSearchIntegTestCase { - ClusterService clusterService; - Client client; - MLIndicesHandler mlIndicesHandler; - - String OLD_ML_MODEL_INDEX_MAPPING_V0 = "{\n" - + " \"properties\": {\n" - + " \"task_id\": { \"type\": \"keyword\" },\n" - + " \"algorithm\": {\"type\": \"keyword\"},\n" - + " \"model_name\" : { \"type\": \"keyword\"},\n" - + " \"model_version\" : { \"type\": \"keyword\"},\n" - + " \"model_content\" : { \"type\": \"binary\"}\n" - + " }\n" - + "}"; - - String OLD_ML_TASK_INDEX_MAPPING_V0 = "{\n" - + " \"properties\": {\n" - + " \"model_id\": {\"type\": \"keyword\"},\n" - + " \"task_type\": {\"type\": \"keyword\"},\n" - + " \"function_name\": {\"type\": \"keyword\"},\n" - + " \"state\": {\"type\": \"keyword\"},\n" - + " \"input_type\": {\"type\": \"keyword\"},\n" - + " \"progress\": {\"type\": \"float\"},\n" - + " \"output_index\": {\"type\": \"keyword\"},\n" - + " \"worker_node\": {\"type\": \"keyword\"},\n" - + " \"create_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"last_update_time\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"error\": {\"type\": \"text\"},\n" - + " \"user\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n" - + " }\n" - + "}";; - - @Before - public void setup() { - clusterService = clusterService(); - client = client(); - mlIndicesHandler = new MLIndicesHandler(clusterService, client); - } - - public void testInitMLTaskIndex() { - ActionListener listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initMLTaskIndex(listener); - } - - public void testInitMLTaskIndexWithExistingIndex() throws ExecutionException, InterruptedException { - CreateIndexRequest request = new CreateIndexRequest(ML_TASK_INDEX).mapping(ML_TASK_INDEX_MAPPING); - client.admin().indices().create(request).get(); - testInitMLTaskIndex(); - } - - public void testInitMLModelIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { - testInitMLIndexIfAbsentWithExistingIndex(ML_MODEL_INDEX, OLD_ML_MODEL_INDEX_MAPPING_V0, ML_MODEL_INDEX_SCHEMA_VERSION); - } - - public void testInitMLTaskIndexIfAbsentWithExistingIndex() throws ExecutionException, InterruptedException, IOException { - testInitMLIndexIfAbsentWithExistingIndex(ML_TASK_INDEX, OLD_ML_TASK_INDEX_MAPPING_V0, ML_TASK_INDEX_SCHEMA_VERSION); - } - - private void testInitMLIndexIfAbsentWithExistingIndex(String indexName, String oldIndexMapping, int schemaVersion) - throws ExecutionException, - InterruptedException, - IOException { - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(oldIndexMapping); - client.admin().indices().create(request).get(); - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertTrue(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - assertNull(getIndexSchemaVersion(indexName)); - ActionListener listener = ActionListener.wrap(r -> { - assertTrue(r); - Integer indexSchemaVersion = getIndexSchemaVersion(indexName); - if (indexSchemaVersion != null) { - assertEquals(schemaVersion, indexSchemaVersion.intValue()); - mlIndicesHandler.shouldUpdateIndex(indexName, 1, ActionListener.wrap(shouldUpdate -> { assertFalse(shouldUpdate); }, e -> { - throw new RuntimeException(e); - })); - } - }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex() { - ActionListener listener = ActionListener.wrap(r -> { assertTrue(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex_Exception() { - Client mockClient = mock(Client.class); - Object[] objects = setUpMockClient(mockClient); - IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; - MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; - String errorMessage = "test exception"; - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new RuntimeException(errorMessage)); - return null; - }).when(adminClient).create(any(), any()); - ActionListener listener = ActionListener.wrap(r -> { throw new RuntimeException("unexpected result"); }, e -> { - assertEquals(errorMessage, e.getMessage()); - }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - - when(mockClient.threadPool()).thenThrow(new RuntimeException(errorMessage)); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - public void testInitMLModelIndexIfAbsentWithNonExistingIndex_FalseAcknowledge() { - Client mockClient = mock(Client.class); - Object[] objects = setUpMockClient(mockClient); - IndicesAdminClient adminClient = (IndicesAdminClient) objects[0]; - MLIndicesHandler mlIndicesHandler = (MLIndicesHandler) objects[1]; - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - CreateIndexResponse response = new CreateIndexResponse(false, false, ML_MODEL_INDEX); - actionListener.onResponse(response); - return null; - }).when(adminClient).create(any(), any()); - ActionListener listener = ActionListener.wrap(r -> { assertFalse(r); }, e -> { throw new RuntimeException(e); }); - mlIndicesHandler.initModelIndexIfAbsent(listener); - } - - private Object[] setUpMockClient(Client mockClient) { - AdminClient admin = spy(client.admin()); - when(mockClient.admin()).thenReturn(admin); - IndicesAdminClient adminClient = spy(client.admin().indices()); - - MLIndicesHandler mlIndicesHandler = new MLIndicesHandler(clusterService, mockClient); - when(admin.indices()).thenReturn(adminClient); - - when(mockClient.threadPool()).thenReturn(client.threadPool()); - - return new Object[] { adminClient, mlIndicesHandler }; - } - - private Integer getIndexSchemaVersion(String indexName) { - IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { - return null; - } - Integer oldVersion = null; - Map indexMapping = indexMetaData.mapping().getSourceAsMap(); - Object meta = indexMapping.get(META); - if (meta != null && meta instanceof Map) { - Map metaMapping = (Map) meta; - Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); - if (schemaVersion instanceof Integer) { - oldVersion = (Integer) schemaVersion; - } - } - return oldVersion; - } -} diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java deleted file mode 100644 index 5ec2ab686c..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.indices; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.lucene.search.TotalHits; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.ExpectedException; -import org.mockito.ArgumentCaptor; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.test.OpenSearchTestCase; - -public class MLInputDatasetHandlerTests extends OpenSearchTestCase { - Client client; - MLInputDatasetHandler mlInputDatasetHandler; - ActionListener listener; - DataFrame dataFrame; - SearchResponse searchResponse; - - @Rule - public ExpectedException expectedEx = ExpectedException.none(); - - @Before - public void setup() { - Map source = new HashMap<>(); - source.put("taskId", "111"); - List> mapList = new ArrayList<>(); - mapList.add(source); - dataFrame = DataFrameBuilder.load(mapList); - client = mock(Client.class); - mlInputDatasetHandler = new MLInputDatasetHandler(client); - listener = spy(new ActionListener() { - @Override - public void onResponse(MLInputDataset inputDataset) {} - - @Override - public void onFailure(Exception e) {} - }); - - } - - @SuppressWarnings("unchecked") - public void testSearchQueryInputDatasetWithHits() { - searchResponse = mock(SearchResponse.class); - BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}"); - SearchHit hit = new SearchHit(1); - hit.sourceRef(bytesArray); - SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); - when(searchResponse.getHits()).thenReturn(hits); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(MLInputDataset.class); - verify(listener, times(1)).onResponse(captor.capture()); - Assert.assertEquals(captor.getAllValues().size(), 1); - } - - @SuppressWarnings("unchecked") - public void testSearchQueryInputDatasetWithoutHits() { - searchResponse = mock(SearchResponse.class); - SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); - when(searchResponse.getHits()).thenReturn(hits); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWithNullHits() { - searchResponse = mock(SearchResponse.class); - when(searchResponse.getHits()).thenReturn(null); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWithNullResponse() { - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[1]; - listener.onResponse(null); - return null; - }).when(client).search(any(), any()); - - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset - .builder() - .indices(Collections.singletonList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); - mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); - verify(listener, times(1)).onFailure(any()); - } - - public void testSearchQueryInputDatasetWrongType() { - expectedEx.expect(IllegalArgumentException.class); - expectedEx.expectMessage("Input dataset is not SEARCH_QUERY type."); - DataFrame testDataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("key1", 2.0D); - } - })); - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(testDataFrame).build(); - mlInputDatasetHandler.parseSearchQueryInput(dataFrameInputDataset, listener); - } - -} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index ccedef9bc1..ce01b44026 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -42,8 +42,8 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 598d2db5a7..2f8ef74f66 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -104,7 +104,7 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 11e9bc3441..9011746797 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -41,7 +41,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 0d0c594458..13de526978 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -55,7 +55,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; @@ -214,7 +214,6 @@ public void testExecuteTask_OnLocalNode() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -237,7 +236,6 @@ public void testExecuteTask_OnLocalNode_QueryInput() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -248,7 +246,6 @@ public void testExecuteTask_OnLocalNode_QueryInput_Failure() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); verify(client, never()).get(any(), any()); } @@ -277,7 +274,6 @@ public void testExecuteTask_OnLocalNode_GetModelFail() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -291,7 +287,6 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client, never()).get(any(), any()); verify(mlTaskManager).remove(anyString()); @@ -305,7 +300,6 @@ public void testExecuteTask_OnLocalNode_NullGetResponse() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); - // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index 69b5f613b8..ab5f82734e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -39,7 +39,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index a40c5c87cf..ff7c963e8a 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -49,7 +49,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index ae397067bc..943bd5740d 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -52,8 +52,8 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.opensearch.ml.indices.MLIndicesHandler; -import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java index 497a7c4229..78141b1781 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java @@ -28,7 +28,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; public class MockHelper { diff --git a/search-processors/build.gradle b/search-processors/build.gradle index ff52eeeddb..394b45c9f2 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -28,12 +28,10 @@ repositories { } dependencies { - + implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' implementation 'org.apache.commons:commons-lang3:3.12.0' - //implementation project(':opensearch-ml-client') - implementation project(':opensearch-ml-common') implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}" // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5