Skip to content

Commit

Permalink
applying multi-tenancy to task apis, deploy, predict apis (opensearch…
Browse files Browse the repository at this point in the history
…-project#3416) (opensearch-project#3420)

* applying multi-tenancy to task, deploy, predict



* addressed comments



---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Jan 23, 2025
1 parent 89dee61 commit ed36d6e
Show file tree
Hide file tree
Showing 55 changed files with 1,175 additions and 461 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,18 @@ default ActionFuture<MLTask> getTask(String taskId) {
* @param taskId id of the model
* @param listener action listener
*/
void getTask(String taskId, ActionListener<MLTask> listener);
default void getTask(String taskId, ActionListener<MLTask> listener) {
getTask(taskId, null, listener);
}

/**
* Get MLTask and return task in listener
* For more info on get task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-task-information
* @param taskId id of the model
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener action listener
*/
void getTask(String taskId, String tenantId, ActionListener<MLTask> listener);

/**
* Delete the model with modelId.
Expand Down Expand Up @@ -224,7 +235,18 @@ default ActionFuture<DeleteResponse> deleteTask(String taskId) {
* @param taskId id of the task
* @param listener action listener
*/
void deleteTask(String taskId, ActionListener<DeleteResponse> listener);
default void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
deleteTask(taskId, null, listener);
}

/**
* Delete MLTask
* For more info on delete task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-task
* @param taskId id of the task
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener action listener
*/
void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
Expand Down Expand Up @@ -298,7 +320,18 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
default void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
deploy(modelId, null, listener);
}

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener);

/**
* Undeploy models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,27 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, getMLTaskResponseActionListener(listener));
}

@Override
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).tenantId(tenantId).build();

client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, getMLTaskResponseActionListener(listener));
}

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();

client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).tenantId(tenantId).build();

client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

@Override
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
Expand All @@ -242,8 +256,8 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, tenantId, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ public class MachineLearningClientTest {
@Mock
MLConfigGetResponse configGetResponse;

private String modekId = "test_model_id";
private final String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
private MLConfig mlConfig;
private ToolMetadata toolMetadata;
private List<ToolMetadata> toolsList = new ArrayList<>();
private final List<ToolMetadata> toolsList = new ArrayList<>();

@Before
public void setUp() {
Expand Down Expand Up @@ -194,11 +194,21 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
Expand All @@ -214,6 +224,11 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
listener.onResponse(deployModelResponse);
}

@Override
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
Expand Down Expand Up @@ -487,8 +502,8 @@ public void createConnector() {
@Test
public void executeMetricsCorrelation() {
List<float[]> inputData = new ArrayList<>(
Arrays
.asList(
List
.of(
new float[] {
0.89451003f,
4.2006273f,
Expand Down
3 changes: 1 addition & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
}

public static MLModel fromStream(StreamInput in) throws IOException {
MLModel mlModel = new MLModel(in);
return mlModel;
return new MLModel(in);
}

}
24 changes: 20 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.common;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.USER;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -72,6 +74,7 @@ public class MLTask implements ToXContentObject, Writeable {
private boolean async;
@Setter
private Map<String, Object> remoteJob;
private String tenantId;

@Builder(toBuilder = true)
public MLTask(
Expand All @@ -89,7 +92,8 @@ public MLTask(
String error,
User user,
boolean async,
Map<String, Object> remoteJob
Map<String, Object> remoteJob,
String tenantId
) {
this.taskId = taskId;
this.modelId = modelId;
Expand All @@ -106,6 +110,7 @@ public MLTask(
this.user = user;
this.async = async;
this.remoteJob = remoteJob;
this.tenantId = tenantId;
}

public MLTask(StreamInput input) throws IOException {
Expand Down Expand Up @@ -134,9 +139,10 @@ public MLTask(StreamInput input) throws IOException {
this.async = input.readBoolean();
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
if (input.readBoolean()) {
this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue());
this.remoteJob = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
}
}
tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
Expand Down Expand Up @@ -173,6 +179,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down Expand Up @@ -221,12 +230,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (remoteJob != null) {
builder.field(REMOTE_JOB_FIELD, remoteJob);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
return builder.endObject();
}

public static MLTask fromStream(StreamInput in) throws IOException {
MLTask mlTask = new MLTask(in);
return mlTask;
return new MLTask(in);
}

public static MLTask parse(XContentParser parser) throws IOException {
Expand All @@ -245,6 +256,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
User user = null;
boolean async = false;
Map<String, Object> remoteJob = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -305,6 +317,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
case REMOTE_JOB_FIELD:
remoteJob = parser.map();
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -327,6 +342,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.user(user)
.async(async)
.remoteJob(remoteJob)
.tenantId(tenantId)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

package org.opensearch.ml.common.transport.deploy;

import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;

import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -18,6 +21,7 @@
@Data
public class MLDeployModelInput implements Writeable {
private String modelId;
private String tenantId;
private String taskId;
private String modelContentHash;
private Integer nodeCount;
Expand All @@ -26,13 +30,15 @@ public class MLDeployModelInput implements Writeable {
private MLTask mlTask;

public MLDeployModelInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
this.modelId = in.readString();
this.taskId = in.readString();
this.modelContentHash = in.readOptionalString();
this.nodeCount = in.readInt();
this.coordinatingNodeId = in.readString();
this.isDeployToAllNodes = in.readOptionalBoolean();
this.mlTask = new MLTask(in);
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
}

@Builder
Expand All @@ -43,7 +49,8 @@ public MLDeployModelInput(
Integer nodeCount,
String coordinatingNodeId,
Boolean isDeployToAllNodes,
MLTask mlTask
MLTask mlTask,
String tenantId
) {
this.modelId = modelId;
this.taskId = taskId;
Expand All @@ -52,19 +59,24 @@ public MLDeployModelInput(
this.coordinatingNodeId = coordinatingNodeId;
this.isDeployToAllNodes = isDeployToAllNodes;
this.mlTask = mlTask;
this.tenantId = tenantId;
}

public MLDeployModelInput() {}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(modelId);
out.writeString(taskId);
out.writeOptionalString(modelContentHash);
out.writeInt(nodeCount);
out.writeString(coordinatingNodeId);
out.writeOptionalBoolean(isDeployToAllNodes);
mlTask.writeTo(out);
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

}
Loading

0 comments on commit ed36d6e

Please sign in to comment.