Skip to content

Commit

Permalink
merge main and fix conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Pavan Yekbote <[email protected]>
  • Loading branch information
pyek-bot committed Jan 28, 2025
2 parents 48a0428 + 06a6021 commit 835cbf1
Show file tree
Hide file tree
Showing 100 changed files with 4,808 additions and 934 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,19 @@ default ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener);
default void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
predict(modelId, null, mlInput, listener);
}

/**
* Do prediction machine learning job
* For additional info on Predict, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#predict
* @param modelId the trained model id
* @param tenantId tenant id
* @param mlInput ML input
* @param listener a listener to be notified of the result
*/
void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener);

/**
* Train model then predict with the same data set.
Expand Down Expand Up @@ -352,7 +364,19 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
undeploy(modelIds, nodeIds, null, listener);
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
Expand Down Expand Up @@ -450,7 +474,23 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
/**
* Delete agent
* @param agentId The id of the agent to delete
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
}

/**
* Delete agent
* @param agentId The id of the agent to delete
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Get a list of ToolMetadata and return ActionFuture.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ public class MachineLearningNodeClient implements MachineLearningClient {
Client client;

@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
.builder()
.mlInput(mlInput)
.modelId(modelId)
.dispatchTask(true)
.tenantId(tenantId)
.build();
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
}
Expand Down Expand Up @@ -262,8 +263,8 @@ public void deploy(String modelId, String tenantId, ActionListener<MLDeployModel
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds, tenantId);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

Expand Down Expand Up @@ -291,8 +292,8 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId, tenantId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> li
listener.onResponse(output);
}

@Override
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
}

@Override
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
Expand Down Expand Up @@ -234,6 +239,11 @@ public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndep
listener.onResponse(undeployModelsResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
Expand Down Expand Up @@ -281,6 +291,11 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
Expand Down Expand Up @@ -320,7 +335,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
public void predict_WithAlgoAndInputDataAndListener() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build();
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
machineLearningClient.predict(null, mlInput, dataFrameActionListener);
machineLearningClient.predict(null, null, mlInput, dataFrameActionListener);
verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
assertEquals(output, dataFrameArgumentCaptor.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ public void deleteAgent() {

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteAgent(agentId, deleteAgentActionListener);
machineLearningNodeClient.deleteAgent(agentId, null, deleteAgentActionListener);

verify(client).execute(eq(MLAgentDeleteAction.INSTANCE), isA(MLAgentDeleteRequest.class), any());
verify(deleteAgentActionListener).onResponse(argumentCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class CommonValue {
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

// Index mapping paths
public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";
Expand Down
4 changes: 1 addition & 3 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,7 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
}
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
tenantId = input.readOptionalString();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}
}

Expand Down
26 changes: 21 additions & 5 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.agent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand Down Expand Up @@ -63,6 +65,7 @@ public class MLAgent implements ToXContentObject, Writeable {
private Instant lastUpdateTime;
private String appType;
private Boolean isHidden;
private final String tenantId;

@Builder(toBuilder = true)
public MLAgent(
Expand All @@ -76,7 +79,8 @@ public MLAgent(
Instant createdTime,
Instant lastUpdateTime,
String appType,
Boolean isHidden
Boolean isHidden,
String tenantId
) {
this.name = name;
this.type = type;
Expand All @@ -90,6 +94,7 @@ public MLAgent(
this.appType = appType;
// is_hidden field isn't going to be set by user. It will be set by the code.
this.isHidden = isHidden;
this.tenantId = tenantId;
validate();
}

Expand Down Expand Up @@ -155,6 +160,7 @@ public MLAgent(StreamInput input) throws IOException {
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
isHidden = input.readOptionalBoolean();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
validate();
}

Expand All @@ -169,7 +175,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (tools != null && tools.size() > 0) {
if (tools != null && !tools.isEmpty()) {
out.writeBoolean(true);
out.writeInt(tools.size());
for (MLToolSpec tool : tools) {
Expand All @@ -178,7 +184,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
Expand All @@ -197,6 +203,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
out.writeOptionalBoolean(isHidden);
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down Expand Up @@ -236,6 +245,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
builder.endObject();
return builder;
}
Expand All @@ -260,6 +272,7 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
Instant lastUpdateTime = null;
String appType = null;
boolean isHidden = false;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -305,6 +318,9 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
if (parseHidden)
isHidden = parser.booleanValue();
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -324,11 +340,11 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
.lastUpdateTime(lastUpdateTime)
.appType(appType)
.isHidden(isHidden)
.tenantId(tenantId)
.build();
}

public static MLAgent fromStream(StreamInput in) throws IOException {
MLAgent agent = new MLAgent(in);
return agent;
return new MLAgent(in);
}
}
Loading

0 comments on commit 835cbf1

Please sign in to comment.