Skip to content

Commit

Permalink
Add Undeploy Models to MLClient
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Nov 30, 2023
1 parent 06c0380 commit f3a5741
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Nullable;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
Expand All @@ -28,6 +29,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -255,7 +257,7 @@ default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlIn

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
Expand All @@ -266,12 +268,33 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);

/**
* Undeploy models
* For additional info on undeploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param nodeIds the node ids. May be null for all nodes.
*/
default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Nullable String[] nodeIds) {
PlainActionFuture<MLUndeployModelsResponse> actionFuture = PlainActionFuture.newFuture();
undeploy(modelIds, nodeIds, actionFuture);
return actionFuture;
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -239,6 +242,12 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
Expand Down Expand Up @@ -323,6 +332,16 @@ private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionList
return actionListener;
}

private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
ActionListener<MLUndeployModelsResponse> listener
) {
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -78,6 +79,9 @@ public class MachineLearningClientTest {
@Mock
MLDeployModelResponse deployModelResponse;

@Mock
MLUndeployModelsResponse undeployModelsResponse;

@Mock
MLCreateConnectorResponse createConnectorResponse;

Expand Down Expand Up @@ -162,6 +166,11 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
Expand Down Expand Up @@ -358,6 +367,11 @@ public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}

@Test
public void undeploy() {
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(new String[] { "modelId" }, null).actionGet());
}

@Test
public void createConnector() {
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -42,6 +44,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.ClusterName;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -104,6 +107,10 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -152,6 +159,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLDeployModelResponse> deployModelActionListener;

@Mock
ActionListener<MLUndeployModelsResponse> undeployModelsActionListener;

@Mock
ActionListener<MLCreateConnectorResponse> createConnectorActionListener;

Expand Down Expand Up @@ -647,6 +657,33 @@ public void deploy() {
assertEquals(status, (argumentCaptor.getValue()).getStatus());
}

@Test
public void undeploy() {
ClusterName clusterName = new ClusterName("clusterName");
String[] modelIds = new String[] { "modelId" };
String[] nodeIds = new String[] { "nodeId" };
doAnswer(invocation -> {
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse(
clusterName,
Collections.emptyList(),
Collections.emptyList()
);
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any());

ArgumentCaptor<MLUndeployModelsResponse> argumentCaptor = ArgumentCaptor.forClass(MLUndeployModelsResponse.class);
machineLearningNodeClient.undeploy(modelIds, nodeIds, undeployModelsActionListener);

verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any());
verify(undeployModelsActionListener).onResponse(argumentCaptor.capture());
assertEquals(clusterName, (argumentCaptor.getValue()).getResponse().getClusterName());
assertTrue((argumentCaptor.getValue()).getResponse().getNodes().isEmpty());
assertFalse((argumentCaptor.getValue()).getResponse().hasFailures());
}

@Test
public void createConnector() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +53,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
return builder;
}

public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLUndeployModelsResponse) {
return (MLUndeployModelsResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUndeployModelsResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e);
}
}
}

0 comments on commit f3a5741

Please sign in to comment.