From 48ad895e0f6beea5f2843da1d403c333f4b5ffda Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 13 Dec 2023 15:37:35 -0800 Subject: [PATCH 1/2] Add CatIndexTool (#1746) * Add CatIndexTool Signed-off-by: Daniel Widdis * Add test coverage Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .gitignore | 3 + .../ml/engine/tools/CatIndexTool.java | 435 ++++++++++++++++++ .../ml/engine/tools/CatIndexToolTests.java | 245 ++++++++++ 3 files changed, 683 insertions(+) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java diff --git a/.gitignore b/.gitignore index 154f424daf..2fc377955d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ .gradle/ build/ .idea/ +.project +.classpath +.settings client/build/ common/build/ ml-algorithms/build/ diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java new file mode 100644 index 0000000000..16cec3870d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -0,0 +1,435 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Spliterators; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.logging.log4j.util.Strings; +import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Table; +import org.opensearch.common.Table.Cell; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.index.IndexSettings; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Getter; +import lombok.Setter; + +@ToolAnnotation(CatIndexTool.TYPE) +public class CatIndexTool implements Tool { + public static final String TYPE = "CatIndexTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information."; + + @Setter + @Getter + private String name = CatIndexTool.TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + @SuppressWarnings("unused") + private ClusterService clusterService; + + public CatIndexTool(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + // TODO: This logic exactly matches the OpenSearch _cat/indices REST action. If code at + // o.o.rest/action/cat/RestIndicesAction.java changes those changes need to be reflected here + // https://github.com/opensearch-project/ml-commons/pull/1582#issuecomment-1796962876 + @SuppressWarnings("unchecked") + List indexList = parameters.containsKey("indices") + ? gson.fromJson(parameters.get("indices"), List.class) + : Collections.emptyList(); + final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); + + final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); + final boolean local = parameters.containsKey("local") ? Boolean.parseBoolean("local") : false; + final TimeValue clusterManagerNodeTimeout = DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; + final boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments") + ? Boolean.parseBoolean(parameters.get("include_unloaded_segments")) + : false; + + final ActionListener internalListener = ActionListener.notifyOnce(ActionListener.wrap(table -> { + // Handle empty table + if (table.getRows().isEmpty()) { + @SuppressWarnings("unchecked") + T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); + listener.onResponse(empty); + return; + } + StringBuilder sb = new StringBuilder( + // Currently using c.value which is short header matching _cat/indices + // May prefer to use c.attr.get("desc") for full description + table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining("\t", "", "\n")) + ); + for (List row : table.getRows()) { + sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining("\t", "", "\n"))); + } + @SuppressWarnings("unchecked") + T response = (T) sb.toString(); + listener.onResponse(response); + }, listener::onFailure)); + + sendGetSettingsRequest( + indices, + indicesOptions, + local, + clusterManagerNodeTimeout, + client, + new ActionListener() { + @Override + public void onResponse(final GetSettingsResponse getSettingsResponse) { + final GroupedActionListener groupedListener = createGroupedListener(4, internalListener); + groupedListener.onResponse(getSettingsResponse); + + // The list of indices that will be returned is determined by the indices returned from the Get Settings call. + // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the + // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so + // force the IndicesOptions for all the sub-requests to be as inclusive as possible. + final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); + + sendIndicesStatsRequest( + indices, + subRequestIndicesOptions, + includeUnloadedSegments, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterStateRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterHealthRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + } + + @Override + public void onFailure(final Exception e) { + internalListener.onFailure(e); + } + } + ); + } + + @Override + public String getType() { + return TYPE; + } + + /** + * We're using the Get Settings API here to resolve the authorized indices for the user. + * This is because the Cluster State and Cluster Health APIs do not filter output based + * on index privileges, so they can't be used to determine which indices are authorized + * or not. On top of this, the Indices Stats API cannot be used either to resolve indices + * as it does not provide information for all existing indices (for example recovering + * indices or non replicated closed indices are not reported in indices stats response). + */ + private void sendGetSettingsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + final GetSettingsRequest request = new GetSettingsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.names(IndexSettings.INDEX_SEARCH_THROTTLED.getKey()); + + client.admin().indices().getSettings(request, listener); + } + + private void sendClusterStateRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + + final ClusterStateRequest request = new ClusterStateRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().state(request, listener); + } + + private void sendClusterHealthRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + + final ClusterHealthRequest request = new ClusterHealthRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().health(request, listener); + } + + private void sendIndicesStatsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean includeUnloadedSegments, + final Client client, + final ActionListener listener + ) { + + final IndicesStatsRequest request = new IndicesStatsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.all(); + request.includeUnloadedSegments(includeUnloadedSegments); + + client.admin().indices().stats(request, listener); + } + + private GroupedActionListener createGroupedListener(final int size, final ActionListener
listener) { + return new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(final Collection responses) { + try { + GetSettingsResponse settingsResponse = extractResponse(responses, GetSettingsResponse.class); + Map indicesSettings = StreamSupport + .stream(Spliterators.spliterator(settingsResponse.getIndexToSettings().entrySet(), 0), false) + .collect(Collectors.toMap(cursor -> cursor.getKey(), cursor -> cursor.getValue())); + + ClusterStateResponse stateResponse = extractResponse(responses, ClusterStateResponse.class); + Map indicesStates = StreamSupport + .stream(stateResponse.getState().getMetadata().spliterator(), false) + .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); + + ClusterHealthResponse healthResponse = extractResponse(responses, ClusterHealthResponse.class); + Map indicesHealths = healthResponse.getIndices(); + + IndicesStatsResponse statsResponse = extractResponse(responses, IndicesStatsResponse.class); + Map indicesStats = statsResponse.getIndices(); + + Table responseTable = buildTable(indicesSettings, indicesHealths, indicesStats, indicesStates); + listener.onResponse(responseTable); + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }, size); + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + /** + * Factory for the {@link CatIndexTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + private ClusterService clusterService; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CatIndexTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + * @param clusterService The OpenSearch cluster service + */ + public void init(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public CatIndexTool create(Map map) { + return new CatIndexTool(client, clusterService); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } + + private Table getTableWithHeader() { + Table table = new Table(); + table.startHeaders(); + // First param is cell.value which is currently returned + // Second param is cell.attr we may want to use attr.desc in the future + table.addCell("health", "alias:h;desc:current health status"); + table.addCell("status", "alias:s;desc:open/close status"); + table.addCell("index", "alias:i,idx;desc:index name"); + table.addCell("uuid", "alias:id,uuid;desc:index uuid"); + table.addCell("pri", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"); + table.addCell("rep", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"); + table.addCell("docs.count", "alias:dc,docsCount;text-align:right;desc:available docs"); + table.addCell("docs.deleted", "alias:dd,docsDeleted;text-align:right;desc:deleted docs"); + table.addCell("store.size", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"); + table.addCell("pri.store.size", "text-align:right;desc:store size of primaries"); + // Above includes all the default fields for cat indices. See RestIndicesAction for a lot more that could be included. + table.endHeaders(); + return table; + } + + private Table buildTable( + final Map indicesSettings, + final Map indicesHealths, + final Map indicesStats, + final Map indicesMetadatas + ) { + final Table table = getTableWithHeader(); + + indicesSettings.forEach((indexName, settings) -> { + if (indicesMetadatas.containsKey(indexName) == false) { + // the index exists in the Get Indices response but is not present in the cluster state: + // it is likely that the index was deleted in the meanwhile, so we ignore it. + return; + } + + final IndexMetadata indexMetadata = indicesMetadatas.get(indexName); + final IndexMetadata.State indexState = indexMetadata.getState(); + final IndexStats indexStats = indicesStats.get(indexName); + + final String health; + final ClusterIndexHealth indexHealth = indicesHealths.get(indexName); + if (indexHealth != null) { + health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); + } else if (indexStats != null) { + health = "red*"; + } else { + health = ""; + } + + final CommonStats primaryStats; + final CommonStats totalStats; + + if (indexStats == null || indexState == IndexMetadata.State.CLOSE) { + primaryStats = new CommonStats(); + totalStats = new CommonStats(); + } else { + primaryStats = indexStats.getPrimaries(); + totalStats = indexStats.getTotal(); + } + table.startRow(); + table.addCell(health); + table.addCell(indexState.toString().toLowerCase(Locale.ROOT)); + table.addCell(indexName); + table.addCell(indexMetadata.getIndexUUID()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfShards()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfReplicas()); + + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getCount()); + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getDeleted()); + + table.addCell(totalStats.getStore() == null ? null : totalStats.getStore().size()); + table.addCell(primaryStats.getStore() == null ? null : primaryStats.getStore().size()); + + table.endRow(); + }); + + return table; + } + + @SuppressWarnings("unchecked") + private static A extractResponse(final Collection responses, Class c) { + return (A) responses.stream().filter(c::isInstance).findFirst().get(); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java new file mode 100644 index 0000000000..cffb0ff338 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -0,0 +1,245 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.CommonStatsFlags; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndexStats.IndexStatsBuilder; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexMetadata.State; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.engine.tools.CatIndexTool.Factory; + +public class CatIndexToolTests { + + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private GetSettingsResponse getSettingsResponse; + @Mock + private IndicesStatsResponse indicesStatsResponse; + @Mock + private ClusterStateResponse clusterStateResponse; + @Mock + private ClusterHealthResponse clusterHealthResponse; + @Mock + private IndexMetadata indexMetadata; + @Mock + private IndexRoutingTable indexRoutingTable; + + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(client.admin()).thenReturn(adminClient); + + when(indexMetadata.getState()).thenReturn(State.OPEN); + when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT); + + when(metadata.index(any(String.class))).thenReturn(indexMetadata); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + + CatIndexTool.Factory.getInstance().init(client, clusterService); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testRunAsyncNoIndices() throws Exception { + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Collections.emptyMap()); + when(indicesStatsResponse.getIndices()).thenReturn(Collections.emptyMap()); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[0])); + + when(clusterHealthResponse.getIndices()).thenReturn(Collections.emptyMap()); + + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.join(); + assertEquals("There were no results searching the indices parameter [null].", future.get()); + } + + @Test + public void testRunAsyncIndexStats() throws Exception { + String indexName = "foo"; + Index index = new Index(indexName, UUIDs.base64UUID()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor + .forClass(ActionListener.class); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Map.of("foo", Settings.EMPTY)); + + int shardId = 0; + ShardId shId = new ShardId(index, shardId); + Path path = Files.createTempDirectory("temp").resolve("indices").resolve(index.getUUID()).resolve(String.valueOf(shardId)); + ShardPath shardPath = new ShardPath(false, path, path, shId); + ShardRouting routing = TestShardRouting.newShardRouting(shId, "node", true, ShardRoutingState.STARTED); + CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL); + IndexStats fooStats = new IndexStatsBuilder(index.getName(), index.getUUID()) + .add(new ShardStats(routing, shardPath, commonStats, null, null, null)) + .build(); + when(indicesStatsResponse.getIndices()).thenReturn(Map.of(indexName, fooStats)); + + when(indexMetadata.getIndex()).thenReturn(index); + when(indexMetadata.getNumberOfShards()).thenReturn(5); + when(indexMetadata.getNumberOfReplicas()).thenReturn(1); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[] { indexMetadata })); + @SuppressWarnings("unchecked") + Iterator iterator = (Iterator) mock(Iterator.class); + when(iterator.hasNext()).thenReturn(false); + when(indexRoutingTable.iterator()).thenReturn(iterator); + ClusterIndexHealth fooHealth = new ClusterIndexHealth(indexMetadata, indexRoutingTable); + when(clusterHealthResponse.getIndices()).thenReturn(Map.of(indexName, fooHealth)); + + // Now make the call + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.orTimeout(10, TimeUnit.SECONDS).join(); + String response = future.get(); + String[] responseRows = response.trim().split("\\n"); + + assertEquals(2, responseRows.length); + String header = responseRows[0]; + String fooRow = responseRows[1]; + assertEquals(header.split("\\t").length, fooRow.split("\\t").length); + assertEquals("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size", header); + assertEquals("red\topen\tfoo\tnull\t5\t1\t0\t0\t0b\t0b", fooRow); + } + + @Test + public void testTool() { + Factory instance = CatIndexTool.Factory.getInstance(); + assertEquals(instance, CatIndexTool.Factory.getInstance()); + assertTrue(instance.getDefaultDescription().contains("tool")); + + Tool tool = instance.create(Collections.emptyMap()); + assertEquals(CatIndexTool.TYPE, tool.getType()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + } +} From 00b5feaf6605e37eca769e32b7dccbf3b2270573 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Fri, 15 Dec 2023 15:29:56 -0800 Subject: [PATCH 2/2] Memory Manager and Update Memory Actions/APIs (#1761) * More memory actions, APS and tests Signed-off-by: Xun Zhang * refactor memory manager and Get Trace actions Signed-off-by: Xun Zhang * updates for some comments Signed-off-by: Xun Zhang * comments updated Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../common/conversation/ActionConstants.java | 9 + .../action/conversation/GetTracesAction.java | 23 ++ .../action/conversation/GetTracesRequest.java | 124 ++++++++ .../conversation/GetTracesResponse.java | 77 +++++ .../GetTracesTransportAction.java | 64 ++++ .../UpdateConversationAction.java | 18 ++ .../UpdateConversationRequest.java | 105 +++++++ .../UpdateConversationTransportAction.java | 68 +++++ .../conversation/UpdateInteractionAction.java | 19 ++ .../UpdateInteractionRequest.java | 110 +++++++ .../UpdateInteractionTransportAction.java | 70 +++++ .../conversation/GetTracesRequestTests.java | 101 +++++++ .../conversation/GetTracesResponseTests.java | 109 +++++++ .../GetTracesTransportActionTests.java | 159 ++++++++++ .../UpdateConversationRequestTests.java | 155 ++++++++++ ...pdateConversationTransportActionTests.java | 135 +++++++++ .../UpdateInteractionRequestTests.java | 172 +++++++++++ ...UpdateInteractionTransportActionTests.java | 134 +++++++++ ml-algorithms/build.gradle | 1 + .../ml/engine/memory/MLMemoryManager.java | 164 +++++++++++ .../engine/memory/MLMemoryManagerTests.java | 277 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 22 +- .../ml/rest/RestMemoryGetTracesAction.java | 37 +++ .../RestMemoryUpdateConversationAction.java | 59 ++++ .../RestMemoryUpdateInteractionAction.java | 60 ++++ .../rest/RestMemoryGetTracesActionTests.java | 64 ++++ .../RestMemoryUpdateConversationTests.java | 165 +++++++++++ ...estMemoryUpdateInteractionActionTests.java | 164 +++++++++++ 28 files changed, 2663 insertions(+), 2 deletions(-) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 8776c618b0..119d5a6659 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -29,6 +29,8 @@ public class ActionConstants { public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations"; /** name of list on interactions in all responses */ public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions"; + /** name of list on traces in all responses */ + public final static String RESPONSE_TRACES_LIST_FIELD = "traces"; /** name of interaction Id field in all responses */ public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id"; @@ -56,20 +58,27 @@ public class ActionConstants { public final static String SUCCESS_FIELD = "success"; private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; + private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction"; /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; /** path for get conversations */ public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; + /** path for update conversations */ + public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update"; /** path for create interaction */ public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; + /** path for get traces */ + public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list"; /** path for delete conversation */ public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete"; /** path for search conversations */ public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + /** path for update interactions */ + public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update"; /** path for get conversation */ public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; /** path for get interaction */ diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java new file mode 100644 index 0000000000..0117df94b5 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action to return the traces associated with an interaction + */ +public class GetTracesAction extends ActionType { + /** Instance of this */ + public static final GetTracesAction INSTANCE = new GetTracesAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get"; + + private GetTracesAction() { + super(NAME, GetTracesResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java new file mode 100644 index 0000000000..9b65f78148 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * ActionRequest for get traces + */ +public class GetTracesRequest extends ActionRequest { + @Getter + private String interactionId; + @Getter + private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; + @Getter + private int from = 0; + + /** + * Constructor + * @param interactionId UID of the interaction to get traces from + */ + public GetTracesRequest(String interactionId) { + this.interactionId = interactionId; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults) { + this.interactionId = interactionId; + this.maxResults = maxResults; + } + + /** + * Constructor + * @param interactionId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + * @param from position of first interaction to retrieve + */ + public GetTracesRequest(String interactionId, int maxResults, int from) { + this.interactionId = interactionId; + this.maxResults = maxResults; + this.from = from; + } + + /** + * Constructor + * @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo + * @throws IOException if there wasn't a GIR in the stream + */ + public GetTracesRequest(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + this.maxResults = in.readInt(); + this.from = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(interactionId); + out.writeInt(maxResults); + out.writeInt(from); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (interactionId == null) { + exception = addValidationError("Traces must be retrieved from an interaction", exception); + } + if (maxResults <= 0) { + exception = addValidationError("The number of traces to retrieve must be positive", exception); + } + if (from < 0) { + exception = addValidationError("The starting position must be nonnegative", exception); + } + + return exception; + } + + /** + * Makes a GetTracesRequest out of a RestRequest + * @param request Rest Request representing a get traces request + * @return a new GetTracesRequest + * @throws IOException if something goes wrong + */ + public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException { + String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { + int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults, from); + } else { + return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); + } + } else { + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetTracesRequest(cid, maxResults); + } else { + return new GetTracesRequest(cid); + } + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java new file mode 100644 index 0000000000..38486f1c1b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NonNull; + +/** + * Action Response for get traces for an interaction + */ +@AllArgsConstructor +public class GetTracesResponse extends ActionResponse implements ToXContentObject { + @Getter + @NonNull + private List traces; + @Getter + private int nextToken; + private boolean hasMoreTokens; + + /** + * Constructor + * @param in stream input; assumes GetTracesResponse.writeTo was called + * @throws IOException if there's not a G.I.R. in the stream + */ + public GetTracesResponse(StreamInput in) throws IOException { + super(in); + traces = in.readList(Interaction::fromStream); + nextToken = in.readInt(); + hasMoreTokens = in.readBoolean(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeList(traces); + out.writeInt(nextToken); + out.writeBoolean(hasMoreTokens); + } + + /** + * Are there more pages in this search results + * @return whether there are more traces in this search + */ + public boolean hasMorePages() { + return hasMoreTokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD); + for (Interaction trace : traces) { + trace.toXContent(builder, params); + } + builder.endArray(); + if (hasMoreTokens) { + builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); + } + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java new file mode 100644 index 0000000000..698136fd95 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.util.List; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + */ + @Inject + public GetTracesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client + ) { + super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new); + this.client = client; + this.cmHandler = cmHandler; + } + + @Override + public void doExecute(Task task, GetTracesRequest request, ActionListener actionListener) { + int maxResults = request.getMaxResults(); + int from = request.getFrom(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + // TODO: check this newStoredContext() method and remove it if it's redundant + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener> al = ActionListener.wrap(tracesList -> { + internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getTraces(request.getInteractionId(), from, maxResults, al); + } catch (Exception e) { + log.error("Failed to get traces for conversation " + request.getInteractionId(), e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java new file mode 100644 index 0000000000..6c8023171e --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class UpdateConversationAction extends ActionType { + public static final UpdateConversationAction INSTANCE = new UpdateConversationAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/update"; + + private UpdateConversationAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java new file mode 100644 index 0000000000..7afec5d0ab --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class UpdateConversationRequest extends ActionRequest { + private String conversationId; + private Map updateContent; + + private static final Set allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD)); + + @Builder + public UpdateConversationRequest(String conversationId, Map updateContent) { + this.conversationId = conversationId; + this.updateContent = filterUpdateContent(updateContent); + } + + public UpdateConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.updateContent = filterUpdateContent(in.readMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + out.writeMap(this.getUpdateContent()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.conversationId == null) { + exception = addValidationError("conversation id can't be null", exception); + } + if (this.updateContent == null) { + exception = addValidationError("Update conversation content can't be null", exception); + } + + return exception; + } + + public static UpdateConversationRequest parse(XContentParser parser, String conversationId) throws IOException { + Map dataAsMap = null; + dataAsMap = parser.map(); + + return UpdateConversationRequest.builder().conversationId(conversationId).updateContent(dataAsMap).build(); + } + + public static UpdateConversationRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof UpdateConversationRequest) { + return (UpdateConversationRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new UpdateConversationRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into UpdateConversationRequest", e); + } + } + + private Map filterUpdateContent(Map updateContent) { + if (updateContent == null) { + return new HashMap<>(); + } + return updateContent + .entrySet() + .stream() + .filter(map -> allowedList.contains(map.getKey())) + .collect(Collectors.toMap(map -> map.getKey(), map -> map.getValue())); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java new file mode 100644 index 0000000000..9f8c42f17b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UpdateConversationTransportAction extends HandledTransportAction { + Client client; + + @Inject + public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); + String conversationId = updateConversationRequest.getConversationId(); + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); + updateRequest.doc(updateConversationRequest.getUpdateContent()); + updateRequest.docAsUpsert(true); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Conversation for conversation id" + conversationId, e); + listener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String conversationId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log.info("Successfully updated the Conversation with ID: {}", conversationId); + actionListener.onResponse(updateResponse); + } else { + log.info("Failed to update the Conversation with ID: {}", conversationId); + actionListener.onResponse(updateResponse); + } + }, exception -> { + log.error("Failed to update ML Conversation with ID " + conversationId, exception); + actionListener.onFailure(exception); + }), context::restore); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java new file mode 100644 index 0000000000..64f7e56846 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class UpdateInteractionAction extends ActionType { + public static final UpdateInteractionAction INSTANCE = new UpdateInteractionAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/update"; + + private UpdateInteractionAction() { + super(NAME, UpdateResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java new file mode 100644 index 0000000000..96ef467590 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequest.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class UpdateInteractionRequest extends ActionRequest { + private String interactionId; + private Map updateContent; + + private static final Set allowedList = new HashSet<>(Arrays.asList(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + + @Builder + public UpdateInteractionRequest(String interactionId, Map updateContent) { + this.interactionId = interactionId; + this.updateContent = filterUpdateContent(updateContent); + } + + public UpdateInteractionRequest(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + this.updateContent = filterUpdateContent(in.readMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.interactionId); + out.writeMap(this.getUpdateContent()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.interactionId == null) { + exception = addValidationError("interaction id can't be null", exception); + } + if (this.updateContent == null) { + exception = addValidationError("Update Interaction content can't be null", exception); + } + + return exception; + } + + public static UpdateInteractionRequest parse(XContentParser parser, String interactionId) throws IOException { + Map dataAsMap = null; + dataAsMap = parser.map(); + + if (dataAsMap == null) { + throw new OpenSearchParseException("Failed to parse UpdateInteractionRequest due to Null update content"); + } + + return UpdateInteractionRequest.builder().interactionId(interactionId).updateContent(dataAsMap).build(); + } + + public static UpdateInteractionRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof UpdateInteractionRequest) { + return (UpdateInteractionRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new UpdateInteractionRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into UpdateInteractionRequest", e); + } + } + + private Map filterUpdateContent(Map updateContent) { + if (updateContent == null) { + return new HashMap<>(); + } + return updateContent + .entrySet() + .stream() + .filter(map -> allowedList.contains(map.getKey())) + .collect(Collectors.toMap(map -> map.getKey(), map -> map.getValue())); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java new file mode 100644 index 0000000000..9abf8571c4 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UpdateInteractionTransportAction extends HandledTransportAction { + Client client; + + @Inject + public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request); + String interactionId = updateInteractionRequest.getInteractionId(); + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); + updateRequest.doc(updateInteractionRequest.getUpdateContent()); + updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Interaction for interaction id " + interactionId, e); + listener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String interactionId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log.info("Successfully updated the interaction with ID: {}", interactionId); + actionListener.onResponse(updateResponse); + } else { + log.info("Failed to update the interaction with ID: {}", interactionId); + actionListener.onResponse(updateResponse); + } + }, exception -> { + log.error("Failed to update ML interaction with ID " + interactionId, exception); + actionListener.onFailure(exception); + }), context::restore); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java new file mode 100644 index 0000000000..0b88bd48c6 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetTracesRequestTests extends OpenSearchTestCase { + + public void testConstructorsAndStreaming() throws IOException { + GetTracesRequest request = new GetTracesRequest("test-iid"); + assert (request.validate() == null); + assert (request.getInteractionId().equals("test-iid")); + assert (request.getFrom() == 0); + assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + + GetTracesRequest req2 = new GetTracesRequest("test-iid2", 3); + assert (req2.validate() == null); + assert (req2.getInteractionId().equals("test-iid2")); + assert (req2.getFrom() == 0); + assert (req2.getMaxResults() == 3); + + GetTracesRequest req3 = new GetTracesRequest("test-iid3", 4, 5); + assert (req3.validate() == null); + assert (req3.getInteractionId().equals("test-iid3")); + assert (req3.getFrom() == 5); + assert (req3.getMaxResults() == 4); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesRequest req4 = new GetTracesRequest(in); + assert (req4.validate() == null); + assert (req4.getInteractionId().equals("test-iid")); + assert (req4.getFrom() == 0); + assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + } + + public void testBadValues_thenFail() { + String nullstr = null; + GetTracesRequest request = new GetTracesRequest(nullstr); + assert (request.validate().validationErrors().get(0).equals("Traces must be retrieved from an interaction")); + assert (request.validate().validationErrors().size() == 1); + + request = new GetTracesRequest("iid", -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The number of traces to retrieve must be positive")); + + request = new GetTracesRequest("iid", 2, -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative")); + } + + public void testFromRestRequest() throws IOException { + Map basic = Map.of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid1"); + Map maxResOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map + .of(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map bothFields = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid4", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); + RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); + RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); + RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); + GetTracesRequest gir1 = GetTracesRequest.fromRestRequest(req1); + GetTracesRequest gir2 = GetTracesRequest.fromRestRequest(req2); + GetTracesRequest gir3 = GetTracesRequest.fromRestRequest(req3); + GetTracesRequest gir4 = GetTracesRequest.fromRestRequest(req4); + + assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null); + assert (gir1.getInteractionId().equals("iid1") && gir2.getInteractionId().equals("iid2")); + assert (gir3.getInteractionId().equals("iid3") && gir4.getInteractionId().equals("iid4")); + assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7); + assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4); + assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java new file mode 100644 index 0000000000..e013bcc518 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetTracesResponseTests extends OpenSearchTestCase { + List traces; + + @Before + public void setup() { + traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ), + new Interaction( + "id1", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 2 + ), + new Interaction( + "id2", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 3 + + ) + ); + } + + public void testGetInteractionsResponseStreaming() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces, 4, true); + assert (response.getTraces().equals(traces)); + assert (response.getNextToken() == 4); + assert (response.hasMorePages()); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetTracesResponse newResp = new GetTracesResponse(in); + assert (newResp.getTraces().equals(traces)); + assert (newResp.getNextToken() == 4); + assert (newResp.hasMorePages()); + } + + public void testToXContent_MoreTokens() throws IOException { + GetTracesResponse response = new GetTracesResponse(traces.subList(0, 1), 2, true); + Interaction trace = response.getTraces().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + System.out.println(result); + String expected = "{\"traces\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":" + + trace.getCreateTime() + + ",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"},\"parent_interaction_id\":\"parent_id\",\"trace_number\":1}],\"next_token\":2}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } + + @Test(expected = NullPointerException.class) + public void testConstructor_NullTraces() { + GetTracesResponse response = new GetTracesResponse(null, 0, false); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java new file mode 100644 index 0000000000..c6aef01097 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportActionTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetTracesTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetTracesRequest request; + GetTracesTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetTracesRequest("test-iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + + this.action = spy(new GetTracesTransportAction(transportService, actionFilters, cmHandler, client)); + } + + public void testGetTraces_noMorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent-id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testGetTraces_MorePages() { + Interaction testTrace = new Interaction( + "test-trace", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testTrace)); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + GetTracesRequest shortPageRequest = new GetTracesRequest("test-trace", 1); + action.doExecute(null, shortPageRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List traces = argCaptor.getValue().getTraces(); + assert (traces.size() == 1); + Interaction trace = traces.get(0); + assert (trace.equals(testTrace)); + assert (argCaptor.getValue().hasMorePages()); + } + + public void testGetTracesFails_thenFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new Exception("Testing Failure")); + return null; + }).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Failure")); + } + + public void testDoExecuteFails_thenFail() { + doThrow(new RuntimeException("Failure in doExecute")).when(cmHandler).getTraces(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in doExecute")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java new file mode 100644 index 0000000000..3b25f1b174 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequestTests.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class UpdateConversationRequestTests { + Map updateContent = new HashMap<>(); + + @Before + public void setUp() { + updateContent.put(META_NAME_FIELD, "new name"); + } + + @Test + public void testConstructor() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest("conversationId", updateContent); + assert (updateConversationRequest.validate() == null); + assert (updateConversationRequest.getConversationId().equals("conversationId")); + assert (updateConversationRequest.getUpdateContent().size() == 1); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateConversationRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateConversationRequest newRequest = new UpdateConversationRequest(in); + assert updateConversationRequest.getConversationId().equals(newRequest.getConversationId()); + assert updateConversationRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + } + + @Test + public void testConstructor_UpdateContentNotAllowed() throws IOException { + Map updateCont = new HashMap<>(); + updateCont.put(META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest("conversationId", updateCont); + assert (updateConversationRequest.validate() == null); + assert (updateConversationRequest.getConversationId().equals("conversationId")); + assert (updateConversationRequest.getUpdateContent().size() == 0); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateConversationRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateConversationRequest newRequest = new UpdateConversationRequest(in); + assert updateConversationRequest.getConversationId().equals(newRequest.getConversationId()); + assert updateConversationRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + assert (newRequest.getUpdateContent().size() == 0); + } + + @Test + public void testConstructor_NullConversationId() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest(null, updateContent); + assert updateConversationRequest.validate().getMessage().equals("Validation Failed: 1: conversation id can't be null;"); + } + + @Test + public void testConstructor_NullUpdateContent() throws IOException { + UpdateConversationRequest updateConversationRequest = new UpdateConversationRequest(null, null); + assert updateConversationRequest.validate().getMessage().equals("Validation Failed: 1: conversation id can't be null;"); + } + + @Test + public void testParse_Success() throws IOException { + String jsonStr = "{\"name\":\"new name\",\"application_type\":\"new type\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.parse(parser, "conversationId"); + assertEquals(updateConversationRequest.getConversationId(), "conversationId"); + assertEquals("new name", updateConversationRequest.getUpdateContent().get("name")); + } + + @Test + public void fromActionRequest_Success() { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest + .builder() + .conversationId("conversationId") + .updateContent(updateContent) + .build(); + assertSame(UpdateConversationRequest.fromActionRequest(updateConversationRequest), updateConversationRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest + .builder() + .conversationId("conversationId") + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateConversationRequest.writeTo(out); + } + }; + UpdateConversationRequest request = UpdateConversationRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateConversationRequest); + assertEquals(updateConversationRequest.getConversationId(), request.getConversationId()); + assertEquals(updateConversationRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + UpdateConversationRequest.fromActionRequest(actionRequest); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java new file mode 100644 index 0000000000..ea713d99bb --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_UPDATED_TIME_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateConversationTransportActionTests extends OpenSearchTestCase { + private UpdateConversationTransportAction transportUpdateConversationAction; + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private Task task; + + @Mock + private UpdateConversationRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String conversationId = "test_conversation_id"; + Map updateContent = Map.of(META_NAME_FIELD, "new name", META_UPDATED_TIME_FIELD, Instant.ofEpochMilli(123)); + when(updateRequest.getConversationId()).thenReturn(conversationId); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + transportUpdateConversationAction = new UpdateConversationTransportAction(transportService, actionFilters, client); + } + + public void test_execute_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Error in Update Request")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_ThrowException() { + doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java new file mode 100644 index 0000000000..4db4b768e8 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionRequestTests.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class UpdateInteractionRequestTests { + + Map updateContent = new HashMap<>(); + + @Before + public void setUp() { + updateContent.put(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + } + + @Test + public void testConstructor() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest("interaction_id", updateContent); + assert updateInteractionRequest.validate() == null; + assert updateInteractionRequest.getInteractionId().equals("interaction_id"); + assert updateInteractionRequest.getUpdateContent().size() == 1; + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateInteractionRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateInteractionRequest newRequest = new UpdateInteractionRequest(in); + assert updateInteractionRequest.getInteractionId().equals(newRequest.getInteractionId()); + assert updateInteractionRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + } + + @Test + public void testConstructor_UpdateContentNotAllowed() throws IOException { + updateContent.put(INTERACTIONS_RESPONSE_FIELD, "response"); + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest("interaction_id", updateContent); + assert (updateInteractionRequest.validate() == null); + assert (updateInteractionRequest.getInteractionId().equals("interaction_id")); + assert (updateInteractionRequest.getUpdateContent().size() == 1); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + updateInteractionRequest.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + UpdateInteractionRequest newRequest = new UpdateInteractionRequest(in); + assert updateInteractionRequest.getInteractionId().equals(newRequest.getInteractionId()); + assert updateInteractionRequest.getUpdateContent().equals(newRequest.getUpdateContent()); + assert (newRequest.getUpdateContent().size() == 1); + } + + @Test + public void testConstructor_NullInteractionId() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest(null, updateContent); + assert updateInteractionRequest.validate().getMessage().equals("Validation Failed: 1: interaction id can't be null;"); + } + + @Test + public void testConstructor_NullUpdateContent() throws IOException { + UpdateInteractionRequest updateInteractionRequest = new UpdateInteractionRequest(null, null); + assert updateInteractionRequest.validate().getMessage().equals("Validation Failed: 1: interaction id can't be null;"); + } + + @Test + public void testParse_Success() throws IOException { + String jsonStr = "{\"additional_info\": {\n" + " \"feedback\": \"thumbs up!\"\n" + " }}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.parse(parser, "interaction_id"); + assertEquals(updateInteractionRequest.getInteractionId(), "interaction_id"); + assertEquals(Map.of("feedback", "thumbs up!"), updateInteractionRequest.getUpdateContent().get(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + } + + @Test + public void testParse_UpdateContentNotAllowed() throws IOException { + String jsonStr = "{\"response\": \"new response!\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.parse(parser, "interaction_id"); + assertEquals(updateInteractionRequest.getInteractionId(), "interaction_id"); + assertEquals(0, updateInteractionRequest.getUpdateContent().size()); + assertNotEquals(null, updateInteractionRequest.getUpdateContent()); + } + + @Test + public void fromActionRequest_Success() { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest + .builder() + .interactionId("interaction_id") + .updateContent(updateContent) + .build(); + assertSame(UpdateInteractionRequest.fromActionRequest(updateInteractionRequest), updateInteractionRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest + .builder() + .interactionId("interaction_id") + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateInteractionRequest.writeTo(out); + } + }; + UpdateInteractionRequest request = UpdateInteractionRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateInteractionRequest); + assertEquals(updateInteractionRequest.getInteractionId(), request.getInteractionId()); + assertEquals(updateInteractionRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + UpdateInteractionRequest.fromActionRequest(actionRequest); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java new file mode 100644 index 0000000000..3dbd16ca64 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateInteractionTransportActionTests extends OpenSearchTestCase { + private UpdateInteractionTransportAction updateInteractionTransportAction; + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private Task task; + + @Mock + private UpdateInteractionRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String interactionId = "test_interaction_id"; + Map updateContent = Map + .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); + when(updateRequest.getInteractionId()).thenReturn(interactionId); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + updateInteractionTransportAction = new UpdateInteractionTransportAction(transportService, actionFilters, client); + } + + public void test_execute_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Error in Update Request")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_ThrowException() { + doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Update Request", argumentCaptor.getValue().getMessage()); + } +} diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 700e251676..e2e37c2212 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -20,6 +20,7 @@ dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') implementation project(':opensearch-ml-common') + implementation project(':opensearch-ml-memory') implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java new file mode 100644 index 0000000000..dc99ef4438 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; + +import com.google.common.base.Preconditions; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Memory manager for Memories. It contains ML memory related operations like create, read interactions etc. + */ +@Log4j2 +@AllArgsConstructor +public class MLMemoryManager { + + private Client client; + + /** + * Create a new Conversation + * @param name the name of the conversation + * @param applicationType the application type that creates this conversation + * @param actionListener action listener to process the response + */ + public void createConversation(String name, String applicationType, ActionListener actionListener) { + try { + client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name, applicationType), actionListener); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param parentIntId the parent interactionId of this interaction + * @param traceNum the trace number for a parent interaction + * @param actionListener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + String parentIntId, + Integer traceNum, + ActionListener actionListener + ) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkNotNull(input); + Preconditions.checkNotNull(response); + // additionalInfo cannot be null as flat object + additionalInfo = (additionalInfo == null) ? new HashMap<>() : additionalInfo; + try { + client + .execute( + CreateInteractionAction.INSTANCE, + new CreateInteractionRequest( + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + parentIntId, + traceNum + ), + actionListener + ); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation that are not traces, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param lastNInteraction Return how many interactions + * @param actionListener get all the final interactions that are not traces + */ + public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1."); + log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); + + ActionListener al = ActionListener.wrap(getInteractionsResponse -> { + actionListener.onResponse(getInteractionsResponse.getInteractions()); + }, e -> { actionListener.onFailure(e); }); + + try { + client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, lastNInteraction), al); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param parentInteractionId the parent interaction id whose traces to get + * @param actionListener get all the trace interactions that are only traces + */ + public void getTraces(String parentInteractionId, ActionListener> actionListener) { + Preconditions.checkNotNull(parentInteractionId); + log.debug("Getting traces for conversationId {}", parentInteractionId); + + ActionListener al = ActionListener.wrap(getTracesResponse -> { + actionListener.onResponse(getTracesResponse.getTraces()); + }, e -> { actionListener.onFailure(e); }); + + try { + client.execute(GetTracesAction.INSTANCE, new GetTracesRequest(parentInteractionId), al); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param interactionId the parent interaction id whose traces to get + * @param actionListener listener for the update response + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener actionListener) { + Preconditions.checkNotNull(interactionId); + Preconditions.checkNotNull(updateContent); + try { + client.execute(UpdateInteractionAction.INSTANCE, new UpdateInteractionRequest(interactionId, updateContent), actionListener); + } catch (Exception exception) { + actionListener.onFailure(exception); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java new file mode 100644 index 0000000000..b3a5f0da56 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -0,0 +1,277 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.ml.memory.action.conversation.GetTracesResponse; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; + +public class MLMemoryManagerTests { + + @Mock + Client client; + + @Mock + MLMemoryManager mlMemoryManager; + + @Mock + ActionListener createConversationResponseActionListener; + + @Mock + ActionListener createInteractionResponseActionListener; + + @Mock + ActionListener> interactionListActionListener; + + @Mock + ActionListener updateResponseActionListener; + + String conversationName; + String applicationType; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + mlMemoryManager = new MLMemoryManager(client); + conversationName = "new conversation"; + applicationType = "ml application"; + } + + @Test + public void testCreateConversation() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateConversationRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateConversationResponse("conversation-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + + verify(client, times(1)) + .execute(eq(CreateConversationAction.INSTANCE), captor.capture(), eq(createConversationResponseActionListener)); + assertEquals(conversationName, captor.getValue().getName()); + assertEquals(applicationType, captor.getValue().getApplicationType()); + } + + @Test + public void testCreateConversationFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.createConversation(conversationName, applicationType, createConversationResponseActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testCreateInteraction() { + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateInteractionRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(new CreateInteractionResponse("interaction-id")); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + verify(client, times(1)) + .execute(eq(CreateInteractionAction.INSTANCE), captor.capture(), eq(createInteractionResponseActionListener)); + assertEquals("conversationId", captor.getValue().getConversationId()); + assertEquals("input", captor.getValue().getInput()); + assertEquals("prompt", captor.getValue().getPromptTemplate()); + assertEquals("response", captor.getValue().getResponse()); + assertEquals("origin", captor.getValue().getOrigin()); + assertEquals(Collections.singletonMap("feedback", "thumbsup"), captor.getValue().getAdditionalInfo()); + assertEquals("parent-id", captor.getValue().getParentIid()); + assertEquals("1", captor.getValue().getTraceNumber().toString()); + } + + @Test + public void testCreateInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .createInteraction( + "conversationId", + "input", + "prompt", + "response", + "origin", + Collections.singletonMap("feedback", "thumbsup"), + "parent-id", + 1, + createInteractionResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetInteractions() { + List interactions = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetInteractionsResponse getInteractionsResponse = new GetInteractionsResponse(interactions, 4, false); + al.onResponse(getInteractionsResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture(), any()); + assertEquals("cid", captor.getValue().getConversationId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetInteractionFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testGetTraces() { + List traces = List + .of( + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "parent_id", + 1 + ) + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetTracesRequest.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + GetTracesResponse getTracesResponse = new GetTracesResponse(traces, 4, false); + al.onResponse(getTracesResponse); + return null; + }).when(client).execute(any(), any(), any()); + + mlMemoryManager.getTraces("iid", interactionListActionListener); + + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(0, captor.getValue().getFrom()); + assertEquals(10, captor.getValue().getMaxResults()); + } + + @Test + public void testGetTracesFails_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager.getTraces("cid", interactionListActionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(interactionListActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } + + @Test + public void testUpdateInteraction() { + Map updateContent = Map + .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onResponse(updateResponse); + return null; + }).when(client).execute(any(), any(), any()); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateInteractionRequest.class); + mlMemoryManager.updateInteraction("iid", updateContent, updateResponseActionListener); + verify(client, times(1)).execute(eq(UpdateInteractionAction.INSTANCE), captor.capture(), any()); + assertEquals("iid", captor.getValue().getInteractionId()); + assertEquals(1, captor.getValue().getUpdateContent().keySet().size()); + assertNotEquals(updateContent, captor.getValue().getUpdateContent()); + } + + @Test + public void testUpdateInteraction_thenFail() { + doThrow(new RuntimeException("Failure in runtime")).when(client).execute(any(), any(), any()); + mlMemoryManager + .updateInteraction( + "iid", + Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")), + updateResponseActionListener + ); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateResponseActionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in runtime")); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 13b4834d59..e986d7e3c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -147,10 +147,16 @@ import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesTransportAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsTransportAction; import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; import org.opensearch.ml.memory.action.conversation.SearchInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionTransportAction; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -189,8 +195,11 @@ import org.opensearch.ml.rest.RestMemoryGetConversationsAction; import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; +import org.opensearch.ml.rest.RestMemoryGetTracesAction; import org.opensearch.ml.rest.RestMemorySearchConversationsAction; import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; +import org.opensearch.ml.rest.RestMemoryUpdateConversationAction; +import org.opensearch.ml.rest.RestMemoryUpdateInteractionAction; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -321,7 +330,10 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), - new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class) + new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), + new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), + new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), + new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) ); } @@ -577,6 +589,9 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); + RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); return ImmutableList .of( restMLStatsAction, @@ -615,7 +630,10 @@ public List getRestHandlers( restSearchConversationsAction, restSearchInteractionsAction, restGetConversationAction, - restGetInteractionAction + restGetInteractionAction, + restMemoryUpdateConversationAction, + restMemoryUpdateInteractionAction, + restMemoryGetTracesAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java new file mode 100644 index 0000000000..12c0815cc3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetTracesAction.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +public class RestMemoryGetTracesAction extends BaseRestHandler { + private final static String GET_TRACES_NAME = "conversational_get_traces"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH)); + } + + @Override + public String getName() { + return GET_TRACES_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetTracesRequest gtRequest = GetTracesRequest.fromRestRequest(request); + return channel -> client.execute(GetTracesAction.INSTANCE, gtRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java new file mode 100644 index 0000000000..c0934056b6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; + +public class RestMemoryUpdateConversationAction extends BaseRestHandler { + private static final String ML_UPDATE_CONVERSATION_ACTION = "ml_update_conversation_action"; + + @Override + public String getName() { + return ML_UPDATE_CONVERSATION_ACTION; + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.PUT, ActionConstants.UPDATE_CONVERSATIONS_REST_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + UpdateConversationRequest updateConversationRequest = getRequest(request); + return restChannel -> client + .execute(UpdateConversationAction.INSTANCE, updateConversationRequest, new RestToXContentListener<>(restChannel)); + } + + @VisibleForTesting + private UpdateConversationRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Failed to update conversation: Request body is empty"); + } + + String conversationId = getParameterId(request, "conversation_id"); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + return UpdateConversationRequest.parse(parser, conversationId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java new file mode 100644 index 0000000000..dafc0352ec --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; + +public class RestMemoryUpdateInteractionAction extends BaseRestHandler { + private static final String ML_UPDATE_INTERACTION_ACTION = "ml_update_interaction_action"; + + @Override + public String getName() { + return ML_UPDATE_INTERACTION_ACTION; + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.PUT, ActionConstants.UPDATE_INTERACTIONS_REST_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + UpdateInteractionRequest updateInteractionRequest = getRequest(request); + return restChannel -> client + .execute(UpdateInteractionAction.INSTANCE, updateInteractionRequest, new RestToXContentListener<>(restChannel)); + } + + @VisibleForTesting + private UpdateInteractionRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Failed to update interaction: Request body is empty"); + } + + String interactionId = getParameterId(request, "interaction_id"); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + return UpdateInteractionRequest.parse(parser, interactionId); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java new file mode 100644 index 0000000000..67a91db6e8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetTracesActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetTracesAction; +import org.opensearch.ml.memory.action.conversation.GetTracesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetTracesActionTests extends OpenSearchTestCase { + + public void testBasics() { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + assert (action.getName().equals("conversational_get_traces")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.GET, ActionConstants.GET_TRACES_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetTracesAction action = new RestMemoryGetTracesAction(); + Map params = Map + .of( + ActionConstants.RESPONSE_INTERACTION_ID_FIELD, + "iid", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetTracesRequest.class); + verify(client, times(1)).execute(eq(GetTracesAction.INSTANCE), argCaptor.capture(), any()); + GetTracesRequest req = argCaptor.getValue(); + assert (req.getInteractionId().equals("iid")); + assert (req.getFrom() == 7); + assert (req.getMaxResults() == 2); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java new file mode 100644 index 0000000000..539527bdf5 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java @@ -0,0 +1,165 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.memory.action.conversation.UpdateConversationAction; +import org.opensearch.ml.memory.action.conversation.UpdateConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMemoryUpdateConversationTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMemoryUpdateConversationAction restMemoryUpdateConversationAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(UpdateConversationAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); + assertNotNull(restMemoryUpdateConversationAction); + } + + public void testGetName() { + String actionName = restMemoryUpdateConversationAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_conversation_action", actionName); + } + + public void testRoutes() { + List routes = restMemoryUpdateConversationAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/memory/conversation/{conversation_id}/_update", route.getPath()); + } + + public void testUpdateConversationRequest() throws Exception { + RestRequest request = getRestRequest(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateConversationRequest.class); + verify(client, times(1)).execute(eq(UpdateConversationAction.INSTANCE), argumentCaptor.capture(), any()); + UpdateConversationRequest updateConversationRequest = argumentCaptor.getValue(); + assertEquals("test_conversationId", updateConversationRequest.getConversationId()); + assertEquals("new name", updateConversationRequest.getUpdateContent().get(META_NAME_FIELD)); + } + + public void testUpdateConnectorRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Failed to update conversation: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + } + + public void testUpdateConnectorRequestWithNullConversationId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain conversation_id"); + RestRequest request = getRestRequestWithNullConversationId(); + restMemoryUpdateConversationAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(META_NAME_FIELD, "new name"); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + params.put("conversation_id", "test_conversationId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("conversation_id", "test_conversationId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullConversationId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(META_NAME_FIELD, "new name"); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java new file mode 100644 index 0000000000..cdfdaa2b3c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction; +import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMemoryUpdateInteractionActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(UpdateInteractionAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); + assertNotNull(restMemoryUpdateInteractionAction); + } + + public void testGetName() { + String actionName = restMemoryUpdateInteractionAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_interaction_action", actionName); + } + + public void testRoutes() { + List routes = restMemoryUpdateInteractionAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/memory/interaction/{interaction_id}/_update", route.getPath()); + } + + public void testUpdateInteractionRequest() throws Exception { + RestRequest request = getRestRequest(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateInteractionRequest.class); + verify(client, times(1)).execute(eq(UpdateInteractionAction.INSTANCE), argumentCaptor.capture(), any()); + UpdateInteractionRequest updateInteractionRequest = argumentCaptor.getValue(); + assertEquals("test_interactionId", updateInteractionRequest.getInteractionId()); + assertEquals(Map.of("feedback", "thumbs up!"), updateInteractionRequest.getUpdateContent().get(INTERACTIONS_ADDITIONAL_INFO_FIELD)); + } + + public void testUpdateInteractionRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Failed to update interaction: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + } + + public void testUpdateInteractionRequestWithNullInteractionId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain interaction_id"); + RestRequest request = getRestRequestWithNullInteractionId(); + restMemoryUpdateInteractionAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + params.put("interaction_id", "test_interactionId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("interaction_id", "test_interactionId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullInteractionId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); + String requestContent = new Gson().toJson(updateContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}