Skip to content

Commit

Permalink
adding tenantID to the request + undeploy request (opensearch-project…
Browse files Browse the repository at this point in the history
…#3425) (opensearch-project#3429)

Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit af96fe0)

Co-authored-by: Dhrubo Saha <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and dhrubo-os authored Jan 24, 2025
1 parent 880b674 commit 76f0f3b
Show file tree
Hide file tree
Showing 20 changed files with 517 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,19 @@ default ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener);
default void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
predict(modelId, null, mlInput, listener);
}

/**
* Do prediction machine learning job
* For additional info on Predict, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#predict
* @param modelId the trained model id
* @param tenantId tenant id
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener);

/**
* Train model then predict with the same data set.
Expand Down Expand Up @@ -352,7 +364,19 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
undeploy(modelIds, nodeIds, null, listener);
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ public class MachineLearningNodeClient implements MachineLearningClient {
Client client;

@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
.builder()
.mlInput(mlInput)
.modelId(modelId)
.dispatchTask(true)
.tenantId(tenantId)
.build();
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
}
Expand Down Expand Up @@ -262,8 +263,8 @@ public void deploy(String modelId, String tenantId, ActionListener<MLDeployModel
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds, tenantId);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> li
listener.onResponse(output);
}

@Override
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
}

@Override
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
Expand Down Expand Up @@ -234,6 +239,11 @@ public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndep
listener.onResponse(undeployModelsResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
Expand Down Expand Up @@ -320,7 +330,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
public void predict_WithAlgoAndInputDataAndListener() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build();
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
machineLearningClient.predict(null, mlInput, dataFrameActionListener);
machineLearningClient.predict(null, null, mlInput, dataFrameActionListener);
verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
assertEquals(output, dataFrameArgumentCaptor.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,32 @@

package org.opensearch.ml.common.transport.undeploy;

import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;

import org.opensearch.Version;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Getter;
import lombok.Setter;

public class MLUndeployModelNodesRequest extends BaseNodesRequest<MLUndeployModelNodesRequest> {

@Getter
private String[] modelIds;
@Getter
@Setter
private String tenantId;

public MLUndeployModelNodesRequest(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.modelIds = in.readOptionalStringArray();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
}

public MLUndeployModelNodesRequest(String[] nodeIds, String[] modelIds) {
Expand All @@ -36,7 +45,11 @@ public MLUndeployModelNodesRequest(DiscoveryNode... nodes) {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Version streamOutputVersion = out.getVersion();
out.writeOptionalStringArray(modelIds);
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.transport.undeploy;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -14,6 +16,7 @@
import java.util.ArrayList;
import java.util.List;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
Expand All @@ -39,24 +42,28 @@ public class MLUndeployModelsRequest extends MLTaskRequest {
private String[] modelIds;
private String[] nodeIds;
boolean async;
private String tenantId;

@Builder
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask) {
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask, String tenantId) {
super(dispatchTask);
this.modelIds = modelIds;
this.nodeIds = nodeIds;
this.async = async;
this.tenantId = tenantId;
}

public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds) {
this(modelIds, nodeIds, false, false);
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, String tenantId) {
this(modelIds, nodeIds, false, false, tenantId);
}

public MLUndeployModelsRequest(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.modelIds = in.readOptionalStringArray();
this.nodeIds = in.readOptionalStringArray();
this.async = in.readBoolean();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
}

@Override
Expand All @@ -68,15 +75,20 @@ public ActionRequestValidationException validate() {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Version streamOutputVersion = out.getVersion();
out.writeOptionalStringArray(modelIds);
out.writeOptionalStringArray(nodeIds);
out.writeBoolean(async);
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

public static MLUndeployModelsRequest parse(XContentParser parser, String modelId) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
List<String> modelIdList = new ArrayList<>();
List<String> nodeIdList = new ArrayList<>();
String tenantId = null;
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
Expand All @@ -94,14 +106,17 @@ public static MLUndeployModelsRequest parse(XContentParser parser, String modelI
nodeIdList.add(parser.text());
}
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
}
}
String[] modelIds = modelIdList == null ? null : modelIdList.toArray(new String[0]);
String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]);
return new MLUndeployModelsRequest(modelIds, nodeIds, false, true);
String[] modelIds = modelIdList.toArray(new String[0]);
String[] nodeIds = nodeIdList.toArray(new String[0]);
return new MLUndeployModelsRequest(modelIds, nodeIds, false, true, tenantId);
}

public static MLUndeployModelsRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package org.opensearch.ml.common.transport.undeploy;

import static org.junit.Assert.*;
import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.function.Consumer;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;

public class MLUndeployModelsRequestTest {

private MLUndeployModelsRequest mlUndeployModelsRequest;

@Before
public void setUp() {
mlUndeployModelsRequest = MLUndeployModelsRequest
.builder()
.modelIds(new String[] { "model1", "model2" })
.nodeIds(new String[] { "node1", "node2" })
.async(true)
.dispatchTask(true)
.tenantId("tenant1")
.build();
}

@Test
public void testValidate() {
MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build();
assertNull(request.validate());
}

@Test
public void testStreamInputVersionBefore_2_19_0() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(VERSION_2_18_0);
mlUndeployModelsRequest.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(VERSION_2_18_0);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(in);

assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds());
assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds());
assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync());
assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask());
assertNull(request.getTenantId());
}

@Test
public void testStreamInputVersionAfter_2_19_0() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(VERSION_2_19_0);
mlUndeployModelsRequest.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(VERSION_2_19_0);
MLUndeployModelsRequest request = new MLUndeployModelsRequest(in);

assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds());
assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds());
assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync());
assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask());
assertEquals(mlUndeployModelsRequest.getTenantId(), request.getTenantId());
}

@Test
public void testWriteToWithNullFields() throws IOException {
MLUndeployModelsRequest request = MLUndeployModelsRequest
.builder()
.modelIds(null)
.nodeIds(null)
.async(true)
.dispatchTask(true)
.build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(VERSION_2_19_0);
request.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(VERSION_2_19_0);
MLUndeployModelsRequest result = new MLUndeployModelsRequest(in);

assertNull(result.getModelIds());
assertNull(result.getNodeIds());
assertEquals(request.isAsync(), result.isAsync());
assertEquals(request.isDispatchTask(), result.isDispatchTask());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLUndeployModelsRequest.fromActionRequest(actionRequest);
}

@Test
public void fromActionRequest_Success_WithMLUndeployModelsRequest() {
MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build();
assertSame(MLUndeployModelsRequest.fromActionRequest(request), request);
}

@Test
public void testParse() throws Exception {
String expectedInputStr = "{\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}";
parseFromJsonString(expectedInputStr, parsedInput -> {
assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds());
assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds());
assertFalse(parsedInput.isAsync());
assertTrue(parsedInput.isDispatchTask());
});
}

@Test
public void testParseWithInvalidField() throws Exception {
String withInvalidFieldInputStr = "{\"invalid_field\":\"void\",\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}";
parseFromJsonString(withInvalidFieldInputStr, parsedInput -> {
assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds());
assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds());
});
}

private void parseFromJsonString(String expectedInputStr, Consumer<MLUndeployModelsRequest> verify) throws Exception {
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
LoggingDeprecationHandler.INSTANCE,
expectedInputStr
);
parser.nextToken();
MLUndeployModelsRequest parsedInput = MLUndeployModelsRequest.parse(parser, null);
verify.accept(parsedInput);
}
}
Loading

0 comments on commit 76f0f3b

Please sign in to comment.