Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Conversation API in MLClient #3475

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ plugins {
dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-memory")
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -553,4 +554,22 @@ default void getConfig(String configId, ActionListener<MLConfig> listener) {
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
* @return the result future
*/
default ActionFuture<CreateConversationResponse> createConversation(String name) {
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
createConversation(name, actionFuture);
return actionFuture;
}

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @param listener action listener
*/
void createConversation(String name, ActionListener<CreateConversationResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -318,6 +321,11 @@ public void getConfig(String configId, String tenantId, ActionListener<MLConfig>
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}

public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down Expand Up @@ -386,6 +394,16 @@ private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseAction
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
}

private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
ActionListener<CreateConversationResponse> listener
) {
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
return conversationResponse;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -59,6 +58,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -107,7 +107,7 @@ public class MachineLearningClientTest {
MLRegisterAgentResponse registerAgentResponse;

@Mock
MLConfigGetResponse configGetResponse;
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
CreateConversationResponse createConversationResponse;

private final String modekId = "test_model_id";
private MLModel mlModel;
Expand Down Expand Up @@ -256,6 +256,11 @@ public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteRe
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}

@Override
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
listener.onResponse(createConversationResponse);
}
};
}

Expand Down Expand Up @@ -554,4 +559,9 @@ public void listTools() {
public void getConfig() {
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
}

@Test
public void createConversation() {
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -219,6 +222,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLConfig> getMlConfigListener;

@Mock
ActionListener<CreateConversationResponse> createConversationResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -1455,6 +1461,25 @@ public void onFailure(Exception e) {
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

public void createConversation() {
String name = "Conversation for a RAG pipeline";
String conversationId = "conversationId";

doAnswer(invocation -> {
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
CreateConversationResponse output = new CreateConversationResponse(conversationId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());

ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);

verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(conversationId, argumentCaptor.getValue().getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import lombok.AllArgsConstructor;

Expand Down Expand Up @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCreateConnectorResponse) {
return (CreateConversationResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new CreateConversationResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
*/
package org.opensearch.ml.memory.action.conversation;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -32,6 +38,13 @@

public class CreateConversationResponseTests extends OpenSearchTestCase {

CreateConversationResponse response;

@Before
public void setup() {
response = new CreateConversationResponse("test-id");
}

public void testCreateConversationResponseStreaming() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("test-id");
assert (response.getId().equals("test-id"));
Expand All @@ -51,4 +64,34 @@ public void testToXContent() throws IOException {
String result = BytesReference.bytes(builder).utf8ToString();
assert (result.equals(expected));
}

@Test
public void fromActionResponseWithCreateConversationResponseSuccess() {
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test
public void fromActionResponseSuccess() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
response.writeTo(out);
}
};
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse);
assertNotSame(response, responseFromActionResponse);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponseIOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
CreateConversationResponse.fromActionResponse(actionResponse);
}
}
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
Loading