diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 53a12a4224..bf07ad5040 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -7,6 +7,7 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; +import org.opensearch.ml.common.controller.MLModelController; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; @@ -54,12 +55,14 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; + public static final String ML_MODEL_CONTROLLER_INDEX = ".plugins-ml-controller"; + public static final Integer ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION = 1; public static final String ML_MAP_RESPONSE_KEY = "response"; public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; @@ -222,6 +225,15 @@ public class CommonValue { + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" + " \"" + + MLModel.IS_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.MODEL_RATE_LIMITER_CONFIG_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + "\" : {\"type\": \"keyword\"},\n" + " \"" @@ -350,6 +362,17 @@ public class CommonValue { + " }\n" + "}"; + public static final String ML_MODEL_CONTROLLER_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModelController.USER_RATE_LIMITER_CONFIG + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + + "}"; + public static final String ML_AGENT_INDEX_MAPPING = "{\n" + " \"_meta\": {\"schema_version\": " + ML_AGENT_INDEX_SCHEMA_VERSION diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 80b488d418..8a2b50d07f 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -17,6 +17,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -50,9 +51,13 @@ public class MLModel implements ToXContentObject { public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_STATE_FIELD = "model_state"; public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes"; - //SHA256 hash value of model content. + // SHA256 hash value of model content. public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; + // Model level quota and throttling control + public static final String IS_ENABLED_FIELD = "is_enabled"; + public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; + public static final String IS_MODEL_CONTROLLER_ENABLED_FIELD = "is_model_controller_enabled"; public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String CREATED_TIME_FIELD = "created_time"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @@ -94,6 +99,9 @@ public class MLModel implements ToXContentObject { private Long modelContentSizeInBytes; private String modelContentHash; private MLModelConfig modelConfig; + private Boolean isEnabled; + private Boolean isModelControllerEnabled; + private MLRateLimiter modelRateLimiterConfig; private Instant createdTime; private Instant lastUpdateTime; private Instant lastRegisteredTime; @@ -131,6 +139,9 @@ public MLModel(String name, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHash, + Boolean isEnabled, + Boolean isModelControllerEnabled, + MLRateLimiter modelRateLimiterConfig, MLModelConfig modelConfig, Instant createdTime, Instant lastUpdateTime, @@ -158,6 +169,9 @@ public MLModel(String name, this.modelState = modelState; this.modelContentSizeInBytes = modelContentSizeInBytes; this.modelContentHash = modelContentHash; + this.isEnabled = isEnabled; + this.isModelControllerEnabled = isModelControllerEnabled; + this.modelRateLimiterConfig = modelRateLimiterConfig; this.modelConfig = modelConfig; this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; @@ -204,6 +218,11 @@ public MLModel(StreamInput input) throws IOException{ modelConfig = new TextEmbeddingModelConfig(input); } } + isEnabled = input.readOptionalBoolean(); + isModelControllerEnabled = input.readOptionalBoolean(); + if (input.readBoolean()) { + modelRateLimiterConfig = new MLRateLimiter(input); + } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); lastRegisteredTime = input.readOptionalInstant(); @@ -258,6 +277,14 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(isEnabled); + out.writeOptionalBoolean(isModelControllerEnabled); + if (modelRateLimiterConfig != null) { + out.writeBoolean(true); + modelRateLimiterConfig.writeTo(out); + } else { + out.writeBoolean(false); + } out.writeOptionalInstant(createdTime); out.writeOptionalInstant(lastUpdateTime); out.writeOptionalInstant(lastRegisteredTime); @@ -321,6 +348,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (modelConfig != null) { builder.field(MODEL_CONFIG_FIELD, modelConfig); } + if (isEnabled != null) { + builder.field(IS_ENABLED_FIELD, isEnabled); + } + if (isModelControllerEnabled != null) { + builder.field(IS_MODEL_CONTROLLER_ENABLED_FIELD, isModelControllerEnabled); + } + if (modelRateLimiterConfig != null) { + builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } @@ -389,6 +425,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws Long modelContentSizeInBytes = null; String modelContentHash = null; MLModelConfig modelConfig = null; + Boolean isEnabled = null; + Boolean isModelControllerEnabled = null; + MLRateLimiter modelRateLimiterConfig = null; Instant createdTime = null; Instant lastUpdateTime = null; Instant lastUploadedTime = null; @@ -474,6 +513,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws modelConfig = TextEmbeddingModelConfig.parse(parser); } break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case IS_MODEL_CONTROLLER_ENABLED_FIELD: + isModelControllerEnabled = parser.booleanValue(); + break; + case MODEL_RATE_LIMITER_CONFIG_FIELD: + modelRateLimiterConfig = MLRateLimiter.parse(parser); + break; case PLANNING_WORKER_NODE_COUNT_FIELD: planningWorkerNodeCount = parser.intValue(); break; @@ -540,6 +588,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .modelContentSizeInBytes(modelContentSizeInBytes) .modelContentHash(modelContentHash) .modelConfig(modelConfig) + .isEnabled(isEnabled) + .isModelControllerEnabled(isModelControllerEnabled) + .modelRateLimiterConfig(modelRateLimiterConfig) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime) diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java b/common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java new file mode 100644 index 0000000000..269a875f57 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.controller; + +import lombok.Builder; +import lombok.Data; +import lombok.Getter; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +@Data +public class MLModelController implements ToXContentObject, Writeable { + + public static final String MODEL_ID_FIELD = "model_id"; // mandatory + public static final String USER_RATE_LIMITER_CONFIG = "user_rate_limiter_config"; + + @Getter + private String modelId; + // The String is the username field where the MLRateLimiter is its corresponding rate limiter config. + private Map userRateLimiterConfig; + + @Builder(toBuilder = true) + public MLModelController(String modelId, Map userRateLimiterConfig) { + this.modelId = modelId; + this.userRateLimiterConfig = userRateLimiterConfig; + } + + public static MLModelController parse(XContentParser parser) throws IOException { + String modelId = null; + Map userRateLimiterConfig = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case USER_RATE_LIMITER_CONFIG: + Map userRateLimiterConfigStringMap = getParameterMap(parser.map()); + userRateLimiterConfigStringMap.forEach((user, rateLimiterString) -> { + try { + XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString); + rateLimiterParser.nextToken(); + MLRateLimiter rateLimiter = MLRateLimiter.parse(rateLimiterParser); + if (!rateLimiter.isEmpty()) { + userRateLimiterConfig.put(user, rateLimiter); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + break; + default: + parser.skipChildren(); + break; + } + } + // Model ID can only be set through RestRequest. + return new MLModelController(modelId, userRateLimiterConfig); + } + + public MLModelController(StreamInput in) throws IOException{ + modelId = in.readString(); + if (in.readBoolean()) { + userRateLimiterConfig = in.readMap(StreamInput::readString, MLRateLimiter::new); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + if (userRateLimiterConfig != null) { + out.writeBoolean(true); + out.writeMap(userRateLimiterConfig, StreamOutput::writeString, (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput)); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + if (userRateLimiterConfig != null) { + builder.field(USER_RATE_LIMITER_CONFIG, userRateLimiterConfig); + } + builder.endObject(); + return builder; + } + + + /** + * Checks if a deployment is required after updating the MLModelController. + * + * @param updateContent The updated MLModelController object. + * @return True if a deployment is required, false otherwise. + */ + public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) { + if (updateContent != null && updateContent.getUserRateLimiterConfig() != null && !updateContent.getUserRateLimiterConfig().isEmpty()) { + Map updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig(); + for (Map.Entry entry : updateUserRateLimiterConfig.entrySet()) { + String newUser = entry.getKey(); + MLRateLimiter newRateLimiter = entry.getValue(); + if (this.userRateLimiterConfig.containsKey(newUser)) { + MLRateLimiter oldRateLimiter = this.userRateLimiterConfig.get(newUser); + if (MLRateLimiter.isDeployRequiredAfterUpdate(oldRateLimiter, newRateLimiter)) { + return true; + } + } else { + if (newRateLimiter.isValid()) { + return true; + } + } + } + } + return false; + } + + public void update(MLModelController updateContent) { + Map updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig(); + if (updateUserRateLimiterConfig != null && !updateUserRateLimiterConfig.isEmpty()) { + updateUserRateLimiterConfig.forEach((user, rateLimiter) -> { + // rateLimiter can't be null due to parsing exception + if (this.userRateLimiterConfig.containsKey(user)) { + this.userRateLimiterConfig.get(user).update(rateLimiter); + } else { + this.userRateLimiterConfig.put(user, rateLimiter); + } + }); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java new file mode 100644 index 0000000000..c132392708 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.controller; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Setter +@Getter +public class MLRateLimiter implements ToXContentObject, Writeable { + public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; + public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; + + private String rateLimitNumber; + private TimeUnit rateLimitUnit; + + @Builder(toBuilder = true) + public MLRateLimiter(String rateLimitNumber, TimeUnit rateLimitUnit) { + this.rateLimitNumber = rateLimitNumber; + this.rateLimitUnit = rateLimitUnit; + } + + public static MLRateLimiter parse(XContentParser parser) throws IOException { + String rateLimitNumber = null; + TimeUnit rateLimitUnit = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case RATE_LIMIT_NUMBER_FIELD: + rateLimitNumber = parser.text(); + break; + case RATE_LIMIT_UNIT_FIELD: + rateLimitUnit = TimeUnit.valueOf(parser.text()); + break; + default: + parser.skipChildren(); + break; + } + } + return new MLRateLimiter(rateLimitNumber, rateLimitUnit); + } + + public MLRateLimiter(StreamInput in) throws IOException{ + this.rateLimitNumber = in.readOptionalString(); + if (in.readBoolean()) { + this.rateLimitUnit = in.readEnum(TimeUnit.class); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(rateLimitNumber); + if (rateLimitUnit != null) { + out.writeBoolean(true); + out.writeEnum(rateLimitUnit); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (rateLimitNumber != null) { + builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + } + if (rateLimitUnit != null) { + builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + } + builder.endObject(); + return builder; + } + + public void update(MLRateLimiter updateContent) { + if (updateContent.getRateLimitNumber() != null) { + this.rateLimitNumber = updateContent.getRateLimitNumber(); + } + if (updateContent.getRateLimitUnit() != null) { + this.rateLimitUnit = updateContent.getRateLimitUnit(); + } + } + + public static MLRateLimiter update(MLRateLimiter rateLimiter, MLRateLimiter updateContent) { + if (rateLimiter == null) { + return updateContent; + } else { + rateLimiter.update(updateContent); + return rateLimiter; + } + } + + /** + * Checks the validity of this incoming update before performing an update operation. + * A valid update indicates the corresponding index will be updated with the current MLRateLimiter config and the update content + * + * @param rateLimiter The existing rate limiter. + * @param updateContent The update content. + * @return true if the update is valid, false otherwise. + */ + public static boolean updateValidityPreCheck(MLRateLimiter rateLimiter, MLRateLimiter updateContent) { + if (updateContent == null) { + return false; + } else if (rateLimiter == null) { + return true; + } else if (updateContent.isEmpty()) { + return false; + } else return (!Objects.equals(updateContent.getRateLimitNumber(), rateLimiter.getRateLimitNumber()) && updateContent.getRateLimitNumber() != null) + || (!Objects.equals(updateContent.getRateLimitUnit(), rateLimiter.getRateLimitUnit()) && updateContent.getRateLimitUnit() != null); + } + + /** + * Checks if we need to deploy this update into ML Cache (if model is deployed) after performing this update operation. + * + * @param rateLimiter The existing rate limiter. + * @param updateContent The update content. + * @return true if the update is valid, false otherwise. + */ + public static boolean isDeployRequiredAfterUpdate(MLRateLimiter rateLimiter, MLRateLimiter updateContent) { + if (!updateValidityPreCheck(rateLimiter, updateContent)) { + return false; + } else { + return updateContent.isValid() + || (rateLimiter.getRateLimitUnit() != null && updateContent.getRateLimitNumber() != null) + || (rateLimiter.getRateLimitNumber() != null && updateContent.getRateLimitUnit() != null); + } + } + + public boolean isValid() { + return (this.rateLimitUnit != null && this.rateLimitNumber != null); + } + + public boolean isEmpty() { + return (this.rateLimitUnit == null && this.rateLimitNumber == null); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java new file mode 100644 index 0000000000..4e99704771 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +public class MLCreateModelControllerAction extends ActionType{ + public static final MLCreateModelControllerAction INSTANCE = new MLCreateModelControllerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/create"; + + private MLCreateModelControllerAction() { + super(NAME, MLCreateModelControllerResponse::new); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java new file mode 100644 index 0000000000..136a3bd373 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLModelController; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLCreateModelControllerRequest extends ActionRequest { + private MLModelController modelControllerInput; + + @Builder + public MLCreateModelControllerRequest(MLModelController modelControllerInput) { + this.modelControllerInput = modelControllerInput; + } + + public MLCreateModelControllerRequest(StreamInput in) throws IOException { + super(in); + this.modelControllerInput = new MLModelController(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + modelControllerInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (modelControllerInput == null) { + exception = addValidationError("Model controller input can't be null", exception); + } + return exception; + } + + public static MLCreateModelControllerRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateModelControllerRequest) { + return (MLCreateModelControllerRequest) actionRequest; + } + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) + { + return new MLCreateModelControllerRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelControllerRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java new file mode 100644 index 0000000000..531aa4daf7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +@Getter +public class MLCreateModelControllerResponse extends ActionResponse implements ToXContentObject { + + public static final String MODEL_ID_FIELD = "model_id"; + public static final String STATUS_FIELD = "status"; + + @Getter + String modelId; + String status; + + public MLCreateModelControllerResponse(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.status = in.readString(); + } + + @Builder + public MLCreateModelControllerResponse(String modelId, String status) { + this.modelId = modelId; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLCreateModelControllerResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateModelControllerResponse) { + return (MLCreateModelControllerResponse) actionResponse; + } + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateModelControllerResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateModelControllerResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java new file mode 100644 index 0000000000..68b1094e0e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +// This action will only be passively called when creating or updating a model controller when the model is deployed. +public class MLDeployModelControllerAction extends ActionType { + public static final MLDeployModelControllerAction INSTANCE = new MLDeployModelControllerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/deploy"; + + private MLDeployModelControllerAction() { super(NAME, MLDeployModelControllerNodesResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java new file mode 100644 index 0000000000..d11e488641 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class MLDeployModelControllerNodeRequest extends TransportRequest { + @Getter + private MLDeployModelControllerNodesRequest deployModelControllerNodesRequest; + + public MLDeployModelControllerNodeRequest(StreamInput in) throws IOException { + super(in); + this.deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest(in); + } + + public MLDeployModelControllerNodeRequest(MLDeployModelControllerNodesRequest request) { + this.deployModelControllerNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + deployModelControllerNodesRequest.writeTo(out); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java new file mode 100644 index 0000000000..9587a4e40b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +@Getter +@Log4j2 +public class MLDeployModelControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map modelControllerDeployStatus; + + public MLDeployModelControllerNodeResponse(DiscoveryNode node, Map modelControllerDeployStatus) { + super(node); + this.modelControllerDeployStatus = modelControllerDeployStatus; + } + + public MLDeployModelControllerNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.modelControllerDeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + + public static MLDeployModelControllerNodeResponse readStats(StreamInput in) throws IOException { + return new MLDeployModelControllerNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + + if (!isModelControllerDeployStatusEmpty()) { + out.writeBoolean(true); + out.writeMap(modelControllerDeployStatus, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject("stats"); + if (!isModelControllerDeployStatusEmpty()) { + for (Map.Entry stat : modelControllerDeployStatus.entrySet()) { + builder.field(stat.getKey(), stat.getValue()); + } + } + builder.endObject(); + return builder; + } + + public boolean isModelControllerDeployStatusEmpty() { + return modelControllerDeployStatus == null || modelControllerDeployStatus.isEmpty(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java new file mode 100644 index 0000000000..ac399828f1 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + package org.opensearch.ml.common.transport.controller; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import java.io.IOException; + +public class MLDeployModelControllerNodesRequest extends BaseNodesRequest { + + @Getter + private String modelId; + + public MLDeployModelControllerNodesRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + } + + public MLDeployModelControllerNodesRequest(String[] nodeIds, String modelId) { + super(nodeIds); + this.modelId = modelId; + } + + public MLDeployModelControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { + super(nodeIds); + this.modelId = modelId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java new file mode 100644 index 0000000000..bbd27c7cca --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class MLDeployModelControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { + + public MLDeployModelControllerNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLDeployModelControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); + } + + public MLDeployModelControllerNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLDeployModelControllerNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + String nodeId; + DiscoveryNode node; + builder.startObject(); + for (MLDeployModelControllerNodeResponse deployStats : getNodes()) { + if (!deployStats.isModelControllerDeployStatusEmpty()) { + node = deployStats.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + deployStats.toXContent(builder, params); + builder.endObject(); + } + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java new file mode 100644 index 0000000000..2e44fffa5c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class MLModelControllerDeleteAction extends ActionType { + public static final MLModelControllerDeleteAction INSTANCE = new MLModelControllerDeleteAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/delete"; + + private MLModelControllerDeleteAction() { super(NAME, DeleteResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java new file mode 100644 index 0000000000..d7709d808d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLModelControllerDeleteRequest extends ActionRequest { + @Getter + String modelId; + + @Builder + public MLModelControllerDeleteRequest(String modelId) { + this.modelId = modelId; + } + + public MLModelControllerDeleteRequest(StreamInput input) throws IOException { + super(input); + this.modelId = input.readString(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeString(modelId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.modelId == null) { + exception = addValidationError("ML model id can't be null", exception); + } + + return exception; + } + + public static MLModelControllerDeleteRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLModelControllerDeleteRequest) { + return (MLModelControllerDeleteRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelControllerDeleteRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLModelControllerDeleteRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java new file mode 100644 index 0000000000..bbae2ac7de --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +public class MLModelControllerGetAction extends ActionType { + public static final MLModelControllerGetAction INSTANCE = new MLModelControllerGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/get"; + + private MLModelControllerGetAction() { super(NAME, MLModelControllerGetResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java new file mode 100644 index 0000000000..d46afd93b3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLModelControllerGetRequest extends ActionRequest { + + String modelId; + boolean returnContent; + + @Builder + public MLModelControllerGetRequest(String modelId, boolean returnContent) { + this.modelId = modelId; + this.returnContent = returnContent; + } + + public MLModelControllerGetRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.returnContent = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.modelId); + out.writeBoolean(returnContent); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.modelId == null) { + exception = addValidationError("ML model id can't be null", exception); + } + + return exception; + } + + public static MLModelControllerGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLModelControllerGetRequest) { + return (MLModelControllerGetRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelControllerGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLModelControllerGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java new file mode 100644 index 0000000000..6c5fe8db09 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.controller.MLModelController; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLModelControllerGetResponse extends ActionResponse implements ToXContentObject { + + @Getter + MLModelController modelController; + + @Builder + public MLModelControllerGetResponse(MLModelController modelController) { + this.modelController = modelController; + } + + public MLModelControllerGetResponse(StreamInput in) throws IOException { + super(in); + modelController = new MLModelController(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + modelController.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException { + return modelController.toXContent(xContentBuilder, params); + } + + public static MLModelControllerGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLModelControllerGetResponse) { + return (MLModelControllerGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelControllerGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLModelControllerGetResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java new file mode 100644 index 0000000000..3be1af7306 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +// This action will only be passively called when deleting a model controller when the model is deployed. +public class MLUndeployModelControllerAction extends ActionType { + public static final MLUndeployModelControllerAction INSTANCE = new MLUndeployModelControllerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/undeploy"; + + private MLUndeployModelControllerAction() { super(NAME, MLUndeployModelControllerNodesResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java new file mode 100644 index 0000000000..0cbd67891b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + + +public class MLUndeployModelControllerNodeRequest extends TransportRequest { + @Getter + private MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest; + + public MLUndeployModelControllerNodeRequest(StreamInput in) throws IOException { + super(in); + this.undeployModelControllerNodesRequest = new MLUndeployModelControllerNodesRequest(in); + } + + public MLUndeployModelControllerNodeRequest(MLUndeployModelControllerNodesRequest request) { + this.undeployModelControllerNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + undeployModelControllerNodesRequest.writeTo(out); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java new file mode 100644 index 0000000000..bf4a1cb8a0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +@Getter +@Log4j2 +public class MLUndeployModelControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map modelControllerUndeployStatus; + + public MLUndeployModelControllerNodeResponse(DiscoveryNode node, Map modelControllerUndeployStatus) { + super(node); + this.modelControllerUndeployStatus = modelControllerUndeployStatus; + } + + public MLUndeployModelControllerNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.modelControllerUndeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + + public static MLUndeployModelControllerNodeResponse readStats(StreamInput in) throws IOException { + return new MLUndeployModelControllerNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + + if (!isModelControllerUndeployStatusEmpty()) { + out.writeBoolean(true); + out.writeMap(modelControllerUndeployStatus, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject("stats"); + if (!isModelControllerUndeployStatusEmpty()) { + for (Map.Entry stat : modelControllerUndeployStatus.entrySet()) { + builder.field(stat.getKey(), stat.getValue()); + } + } + builder.endObject(); + return builder; + } + + public boolean isModelControllerUndeployStatusEmpty() { + return modelControllerUndeployStatus == null || modelControllerUndeployStatus.isEmpty(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java new file mode 100644 index 0000000000..19c638f872 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + package org.opensearch.ml.common.transport.controller; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import java.io.IOException; + +public class MLUndeployModelControllerNodesRequest extends BaseNodesRequest { + + @Getter + private String modelId; + + public MLUndeployModelControllerNodesRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + } + + public MLUndeployModelControllerNodesRequest(String[] nodeIds, String modelId) { + super(nodeIds); + this.modelId = modelId; + } + + public MLUndeployModelControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { + super(nodeIds); + this.modelId = modelId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java new file mode 100644 index 0000000000..36f046d81f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class MLUndeployModelControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { + + public MLUndeployModelControllerNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLUndeployModelControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); + } + + public MLUndeployModelControllerNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLUndeployModelControllerNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + String nodeId; + DiscoveryNode node; + builder.startObject(); + for (MLUndeployModelControllerNodeResponse deployStats : getNodes()) { + if (!deployStats.isModelControllerUndeployStatusEmpty()) { + node = deployStats.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + deployStats.toXContent(builder, params); + builder.endObject(); + } + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java new file mode 100644 index 0000000000..7c48429765 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateModelControllerAction extends ActionType { + public static final MLUpdateModelControllerAction INSTANCE = new MLUpdateModelControllerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/update"; + + private MLUpdateModelControllerAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java new file mode 100644 index 0000000000..7e8abedda5 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLModelController; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString + +public class MLUpdateModelControllerRequest extends ActionRequest { + private MLModelController updateModelControllerInput; + + @Builder + public MLUpdateModelControllerRequest(MLModelController updateModelControllerInput) { + this.updateModelControllerInput = updateModelControllerInput; + } + + public MLUpdateModelControllerRequest(StreamInput in) throws IOException { + super(in); + this.updateModelControllerInput = new MLModelController(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + updateModelControllerInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (updateModelControllerInput == null) { + exception = addValidationError("Update model controller input can't be null", exception); + } + return exception; + } + + public static MLUpdateModelControllerRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLUpdateModelControllerRequest) { + return (MLUpdateModelControllerRequest) actionRequest; + } + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateModelControllerRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse action request to MLCreateModelControllerRequest", e); + } + + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java index 1d99f47d15..e595c58fec 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java @@ -6,9 +6,10 @@ package org.opensearch.ml.common.transport.deploy; import lombok.Getter; -import org.opensearch.transport.TransportRequest; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; import java.io.IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index adfdd4f307..4cada116a6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -32,11 +33,12 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { public static final String MODEL_VERSION_FIELD = "model_version"; // passively set when register model to a new model group public static final String MODEL_NAME_FIELD = "name"; // optional public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String IS_ENABLED_FIELD = "is_enabled"; // optional + public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; // optional public static final String MODEL_CONFIG_FIELD = "model_config"; // optional - public static final String CONNECTOR_FIELD = "connector"; // optional + public static final String UPDATED_CONNECTOR_FIELD = "updated_connector"; // passively set when updating the internal connector public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional - // The field CONNECTOR_UPDATE_CONTENT_FIELD need to be declared because the update of Connector class relies on the MLCreateConnectorInput class - public static final String CONNECTOR_UPDATE_CONTENT_FIELD = "connector_update_content"; + public static final String CONNECTOR_FIELD = "connector"; // optional public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update request @Getter @@ -45,25 +47,29 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private String version; private String name; private String modelGroupId; + private Boolean isEnabled; + private MLRateLimiter modelRateLimiterConfig; private MLModelConfig modelConfig; - private Connector connector; + private Connector updatedConnector; private String connectorId; - private MLCreateConnectorInput connectorUpdateContent; + private MLCreateConnectorInput connector; private Instant lastUpdateTime; @Builder(toBuilder = true) public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, - MLModelConfig modelConfig, Connector connector, String connectorId, - MLCreateConnectorInput connectorUpdateContent, Instant lastUpdateTime) { + Boolean isEnabled, MLRateLimiter modelRateLimiterConfig, MLModelConfig modelConfig, + Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime) { this.modelId = modelId; this.description = description; this.version = version; this.name = name; this.modelGroupId = modelGroupId; + this.isEnabled = isEnabled; + this.modelRateLimiterConfig = modelRateLimiterConfig; this.modelConfig = modelConfig; - this.connector = connector; + this.updatedConnector = updatedConnector; this.connectorId = connectorId; - this.connectorUpdateContent = connectorUpdateContent; + this.connector = connector; this.lastUpdateTime = lastUpdateTime; } @@ -73,15 +79,19 @@ public MLUpdateModelInput(StreamInput in) throws IOException { version = in.readOptionalString(); name = in.readOptionalString(); modelGroupId = in.readOptionalString(); + isEnabled = in.readOptionalBoolean(); + if (in.readBoolean()) { + modelRateLimiterConfig = new MLRateLimiter(in); + } if (in.readBoolean()) { modelConfig = new TextEmbeddingModelConfig(in); } if (in.readBoolean()) { - connector = Connector.fromStream(in); + updatedConnector = Connector.fromStream(in); } connectorId = in.readOptionalString(); if (in.readBoolean()) { - connectorUpdateContent = new MLCreateConnectorInput(in); + connector = new MLCreateConnectorInput(in); } lastUpdateTime = in.readOptionalInstant(); } @@ -102,17 +112,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelGroupId != null) { builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); } + if (isEnabled != null) { + builder.field(IS_ENABLED_FIELD, isEnabled); + } + if (modelRateLimiterConfig != null) { + builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + } if (modelConfig != null) { builder.field(MODEL_CONFIG_FIELD, modelConfig); } - if (connector != null) { - builder.field(CONNECTOR_FIELD, connector); + if (updatedConnector != null) { + builder.field(UPDATED_CONNECTOR_FIELD, updatedConnector); } if (connectorId != null) { builder.field(CONNECTOR_ID_FIELD, connectorId); } - if (connectorUpdateContent != null) { - builder.field(CONNECTOR_UPDATE_CONTENT_FIELD, connectorUpdateContent); + if (connector != null) { + builder.field(CONNECTOR_FIELD, connector); } if (lastUpdateTime != null) { builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); @@ -128,22 +144,29 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(version); out.writeOptionalString(name); out.writeOptionalString(modelGroupId); + out.writeOptionalBoolean(isEnabled); + if (modelRateLimiterConfig != null) { + out.writeBoolean(true); + modelRateLimiterConfig.writeTo(out); + } else { + out.writeBoolean(false); + } if (modelConfig != null) { out.writeBoolean(true); modelConfig.writeTo(out); } else { out.writeBoolean(false); } - if (connector != null) { + if (updatedConnector != null) { out.writeBoolean(true); - connector.writeTo(out); + updatedConnector.writeTo(out); } else { out.writeBoolean(false); } out.writeOptionalString(connectorId); - if (connectorUpdateContent != null) { + if (connector != null) { out.writeBoolean(true); - connectorUpdateContent.writeTo(out); + connector.writeTo(out); } else { out.writeBoolean(false); } @@ -156,10 +179,12 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException String version = null; String name = null; String modelGroupId = null; + Boolean isEnabled = null; + MLRateLimiter modelRateLimiterConfig = null; MLModelConfig modelConfig = null; - Connector connector = null; + Connector updatedConnector = null; String connectorId = null; - MLCreateConnectorInput connectorUpdateContent = null; + MLCreateConnectorInput connector = null; Instant lastUpdateTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -167,35 +192,29 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { - case MODEL_ID_FIELD: - modelId = parser.text(); - break; case DESCRIPTION_FIELD: description = parser.text(); break; case MODEL_NAME_FIELD: name = parser.text(); break; - case MODEL_VERSION_FIELD: - version = parser.text(); - break; case MODEL_GROUP_ID_FIELD: modelGroupId = parser.text(); break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case MODEL_RATE_LIMITER_CONFIG_FIELD: + modelRateLimiterConfig = MLRateLimiter.parse(parser); + break; case MODEL_CONFIG_FIELD: modelConfig = TextEmbeddingModelConfig.parse(parser); break; - case CONNECTOR_FIELD: - connector = Connector.createConnector(parser); - break; case CONNECTOR_ID_FIELD: connectorId = parser.text(); break; - case CONNECTOR_UPDATE_CONTENT_FIELD: - connectorUpdateContent = MLCreateConnectorInput.parse(parser, true); - break; - case LAST_UPDATED_TIME_FIELD: - lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + case CONNECTOR_FIELD: + connector = MLCreateConnectorInput.parse(parser, true); break; default: parser.skipChildren(); @@ -203,6 +222,6 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException } } // Model ID can only be set through RestRequest. Model version can only be set automatically. - return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connector, connectorId, connectorUpdateContent, lastUpdateTime); + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, modelRateLimiterConfig, modelConfig, updatedConnector, connectorId, connector, lastUpdateTime); } } \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java index 3bf3dabd03..d3394191e0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java @@ -69,7 +69,7 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action return new MLRegisterModelGroupRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelMetaRequest", e); + throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelGroupRequest", e); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index e8866bc8e4..ee5c89f1da 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -18,6 +18,7 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -29,7 +30,6 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - import static org.opensearch.ml.common.connector.Connector.createConnector; /** @@ -43,6 +43,8 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; + public static final String IS_ENABLED_FIELD = "is_enabled"; + public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; public static final String URL_FIELD = "url"; public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_CONFIG_FIELD = "model_config"; @@ -60,6 +62,8 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private String modelGroupId; private String version; private String description; + private Boolean isEnabled; + private MLRateLimiter modelRateLimiterConfig; private String url; private String hashValue; private MLModelFormat modelFormat; @@ -84,6 +88,8 @@ public MLRegisterModelInput(FunctionName functionName, String modelGroupId, String version, String description, + Boolean isEnabled, + MLRateLimiter modelRateLimiterConfig, String url, String hashValue, MLModelFormat modelFormat, @@ -114,6 +120,8 @@ public MLRegisterModelInput(FunctionName functionName, this.modelGroupId = modelGroupId; this.version = version; this.description = description; + this.isEnabled = isEnabled; + this.modelRateLimiterConfig = modelRateLimiterConfig; this.url = url; this.hashValue = hashValue; this.modelFormat = modelFormat; @@ -136,6 +144,10 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelGroupId = in.readOptionalString(); this.version = in.readOptionalString(); this.description = in.readOptionalString(); + this.isEnabled = in.readOptionalBoolean(); + if (in.readBoolean()) { + this.modelRateLimiterConfig = new MLRateLimiter(in); + } this.url = in.readOptionalString(); this.hashValue = in.readOptionalString(); if (in.readBoolean()) { @@ -172,6 +184,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); out.writeOptionalString(version); out.writeOptionalString(description); + out.writeOptionalBoolean(isEnabled); + if (modelRateLimiterConfig != null) { + out.writeBoolean(true); + modelRateLimiterConfig.writeTo(out); + } else { + out.writeBoolean(false); + } out.writeOptionalString(url); out.writeOptionalString(hashValue); if (modelFormat != null) { @@ -226,6 +245,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } + if (isEnabled != null) { + builder.field(IS_ENABLED_FIELD, isEnabled); + } + if (modelRateLimiterConfig != null) { + builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + } if (url != null) { builder.field(URL_FIELD, url); } @@ -270,6 +295,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException { FunctionName functionName = null; String modelGroupId = null; + Boolean isEnabled = null; + MLRateLimiter modelRateLimiterConfig = null; String url = null; String hashValue = null; String description = null; @@ -295,6 +322,12 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case MODEL_GROUP_ID_FIELD: modelGroupId = parser.text(); break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case MODEL_RATE_LIMITER_CONFIG_FIELD: + modelRateLimiterConfig = MLRateLimiter.parse(parser); + break; case URL_FIELD: url = parser.text(); break; @@ -345,8 +378,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); - + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -354,6 +386,8 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo String name = null; String modelGroupId = null; String version = null; + Boolean isEnabled = null; + MLRateLimiter modelRateLimiterConfig = null; String url = null; String hashValue = null; String description = null; @@ -389,6 +423,12 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case DESCRIPTION_FIELD: description = parser.text(); break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case MODEL_RATE_LIMITER_CONFIG_FIELD: + modelRateLimiterConfig = MLRateLimiter.parse(parser); + break; case URL_FIELD: url = parser.text(); break; @@ -436,6 +476,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java index 1158dcc843..4d19e14ad3 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java @@ -6,9 +6,9 @@ package org.opensearch.ml.common.transport.sync; import lombok.Getter; -import org.opensearch.transport.TransportRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; import java.io.IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java index ef9d6ec063..8cccab63d6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java @@ -9,7 +9,7 @@ public class MLUpdateModelCacheAction extends ActionType { public static final MLUpdateModelCacheAction INSTANCE = new MLUpdateModelCacheAction(); - public static final String NAME = "cluster:admin/opensearch/ml/models/update_model_cache"; + public static final String NAME = "cluster:admin/opensearch/ml/models/update_cache"; private MLUpdateModelCacheAction() { super(NAME, MLUpdateModelCacheNodesResponse::new);} } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java index 35a642c33c..2c9abd5a5c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponse.java @@ -63,6 +63,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public boolean isModelUpdateStatusEmpty() { - return modelUpdateStatus == null || modelUpdateStatus.size() == 0; + return modelUpdateStatus == null || modelUpdateStatus.isEmpty(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java index 566b838632..a6d77abe3a 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequest.java @@ -16,31 +16,25 @@ public class MLUpdateModelCacheNodesRequest extends BaseNodesRequest backendRoles, + public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, Boolean isEnabled, MLRateLimiter modelRateLimiterConfig, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, Boolean isAddAllBackendRoles, Boolean doesVersionCreateModelGroup, Boolean isHidden) { @@ -80,7 +83,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m throw new IllegalArgumentException("model name is null"); } if (functionName == null) { - this.functionName = functionName.TEXT_EMBEDDING; + this.functionName = FunctionName.TEXT_EMBEDDING; } else { this.functionName = functionName; } @@ -100,6 +103,8 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.modelGroupId = modelGroupId; this.version = version; this.description = description; + this.isEnabled = isEnabled; + this.modelRateLimiterConfig = modelRateLimiterConfig; this.modelFormat = modelFormat; this.modelState = modelState; this.modelContentSizeInBytes = modelContentSizeInBytes; @@ -119,6 +124,10 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ this.modelGroupId = in.readOptionalString(); this.version = in.readOptionalString(); this.description = in.readOptionalString(); + this.isEnabled = in.readOptionalBoolean(); + if (in.readBoolean()) { + modelRateLimiterConfig = new MLRateLimiter(in); + } if (in.readBoolean()) { modelFormat = in.readEnum(MLModelFormat.class); } @@ -147,6 +156,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); out.writeOptionalString(version); out.writeOptionalString(description); + out.writeOptionalBoolean(isEnabled); + if (modelRateLimiterConfig != null) { + out.writeBoolean(true); + modelRateLimiterConfig.writeTo(out); + } else { + out.writeBoolean(false); + } if (modelFormat != null) { out.writeBoolean(true); out.writeEnum(modelFormat); @@ -199,6 +215,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (description != null) { builder.field(MLModel.DESCRIPTION_FIELD, description); } + if (isEnabled != null) { + builder.field(IS_ENABLED_FIELD, isEnabled); + } + if (modelRateLimiterConfig != null) { + builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + } builder.field(MODEL_FORMAT_FIELD, modelFormat); if (modelState != null) { builder.field(MODEL_STATE_FIELD, modelState); @@ -234,6 +256,8 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc String modelGroupId = null; String version = null; String description = null; + Boolean isEnabled = null; + MLRateLimiter modelRateLimiterConfig = null; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; @@ -266,6 +290,12 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case DESCRIPTION_FIELD: description = parser.text(); break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case MODEL_RATE_LIMITER_CONFIG_FIELD: + modelRateLimiterConfig = MLRateLimiter.parse(parser); + break; case MODEL_FORMAT_FIELD: modelFormat = MLModelFormat.from(parser.text()); break; @@ -308,7 +338,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java new file mode 100644 index 0000000000..e9283c4958 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java @@ -0,0 +1,329 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.controller; + + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLModelControllerTest { + private MLRateLimiter rateLimiter; + + private MLModelController modelController; + + private MLModelController modelControllerNull; + + private final String expectedInputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + + "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + + modelControllerNull = MLModelController.builder() + .modelId("testModelId").build(); + + modelController = MLModelControllerGenerator("testUser", rateLimiter); + + } + + @Test + public void readInputStreamSuccess() throws IOException { + readInputStream(modelController, parsedInput -> { + assertEquals("testModelId", parsedInput.getModelId()); + assertEquals(modelController.getUserRateLimiterConfig().get("testUser").getRateLimitNumber(), + parsedInput.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + }); + } + + @Test + public void readInputStreamSuccessWithNullFields() throws IOException { + modelController.setUserRateLimiterConfig(null); + readInputStream(modelController, parsedInput -> { + assertNull(parsedInput.getUserRateLimiterConfig()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(modelController); + assertEquals(expectedInputStr, jsonStr); + } + + + @Test + public void testToXContentIncomplete() throws Exception { + final String expectedIncompleteInputStr = + "{\"model_id\":\"testModelId\"}"; + String jsonStr = serializationWithToXContent(modelControllerNull); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void testToXContentWithNullMLRateLimiterInUserRateLimiterConfig() throws Exception { + // Notice that MLModelController will throw an exception if it parses this output string, check parseWithNullMLRateLimiterInUserRateLimiterConfigFieldWithException test below. + final String expectedOutputStrWithNullField = + "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; + MLModelController modelControllerWithTestUserAndEmptyRateLimiter = MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>(){{put("testUser", null);}}) + .build(); + String jsonStr = serializationWithToXContent(modelControllerWithTestUserAndEmptyRateLimiter); + assertEquals(expectedOutputStrWithNullField, jsonStr); + } + + @Test + public void parseSuccess() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> assertEquals("testModelId", parsedInput.getModelId())); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty different from usual + public void parseWithoutUserRateLimiterConfigFieldWithNoException() throws Exception { + final String expectedIncompleteInputStr = "{\"model_id\":\"testModelId\"}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{}}"; + + testParseFromJsonString(expectedIncompleteInputStr, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty different from usual + public void parseWithNullUserRateLimiterConfigFieldWithNoException() throws Exception { + final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":null}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty different from usual + public void parseWithTestUserAndEmptyRateLimiterFieldWithNoException() throws Exception { + final String expectedInputStrWithEmptyField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + + "{\"testUser\":{}}}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + + "{}}"; + testParseFromJsonString(expectedInputStrWithEmptyField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithNullField() throws Exception { + exceptionRule.expect(IllegalStateException.class); + final String expectedInputStrWithNullField = "{\"model_id\":null,\"user_rate_limiter_config\":" + + "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithIllegalField() throws Exception { + final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter_config\":" + + "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; + + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // This will throw a ParsingException because MLRateLimiter parser cannot parse null field. + public void parseWithNullMLRateLimiterInUserRateLimiterConfigFieldWithException() throws Exception { + exceptionRule.expect(RuntimeException.class); + final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithIllegalRateLimiterFieldWithException() throws Exception { + exceptionRule.expect(RuntimeException.class); + final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter_config\":" + + "{\"testUser\":\"Some illegal content that MLRateLimiter parser cannot parse.\"}}"; + + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testUserRateLimiterConfigUpdate() { + MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().rateLimitNumber("1").build(); + + MLModelController modelControllerWithEmptyUserRateLimiterConfig = MLModelControllerGenerator(); + MLModelController modelControllerWithTestUserAndRateLimiterWithNumber = MLModelControllerGenerator("testUser", rateLimiterWithNumber); + MLModelController modelControllerWithNewUserAndEmptyRateLimiter = MLModelControllerGenerator("newUser"); + + modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerNull); + assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().isEmpty()); + + modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithEmptyUserRateLimiterConfig); + assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().isEmpty()); + + modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithTestUserAndRateLimiterWithNumber); + assertEquals("1", modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + assertNull(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + + modelControllerWithEmptyUserRateLimiterConfig.update(modelController); + assertEquals("1", modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + + modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithNewUserAndEmptyRateLimiter); + assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("newUser").isEmpty()); + } + + @Test + public void testUserRateLimiterConfigIsUpdatable() { + MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().rateLimitNumber("1").build(); + + MLModelController modelControllerWithEmptyUserRateLimiterConfig = MLModelControllerGenerator(); + MLModelController modelControllerWithTestUserAndRateLimiterWithNumber = MLModelControllerGenerator("testUser", rateLimiterWithNumber); + MLModelController modelControllerWithNewUserAndRateLimiterWithNumber = MLModelControllerGenerator("newUser", rateLimiterWithNumber); + MLModelController modelControllerWithNewUserAndEmptyRateLimiter = MLModelControllerGenerator("newUser"); + MLModelController modelControllerWithNewUserAndRateLimiter = MLModelControllerGenerator("newUser", rateLimiter); + + assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(null)); + assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerNull)); + assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithEmptyUserRateLimiterConfig)); + assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndEmptyRateLimiter)); + + assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithTestUserAndRateLimiterWithNumber)); + assertFalse(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithTestUserAndRateLimiterWithNumber)); + assertTrue(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelController)); + assertTrue(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelController)); + + assertFalse(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndRateLimiterWithNumber)); + assertTrue(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndRateLimiter)); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + parser.nextToken(); + MLModelController parsedInput = MLModelController.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLModelController input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLModelController parsedInput = new MLModelController(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLModelController input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } + + private MLModelController MLModelControllerGenerator(String user, MLRateLimiter rateLimiter) { + return MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>(){{put(user, rateLimiter);}}) + .build(); + + } + + private MLModelController MLModelControllerGenerator(String user) { + return MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>(){{put(user, MLRateLimiter.builder().build());}}) + .build(); + + } + + private MLModelController MLModelControllerGenerator() { + return MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>()) + .build(); + + } + + @Ignore + @Test + public void testRateLimiterRemove() { + MLModelController modelControllerWithTestUserAndEmptyRateLimiter = MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>(){{put("testUser", MLRateLimiter.builder().build());}}) + .build(); + + modelController.update(modelControllerWithTestUserAndEmptyRateLimiter); + assertNull(modelController.getUserRateLimiterConfig().get("testUser")); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java new file mode 100644 index 0000000000..7cc85974ab --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.controller; + + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLRateLimiterTest { + + private MLRateLimiter rateLimiter; + + private MLRateLimiter rateLimiterWithNumber; + + private MLRateLimiter rateLimiterWithUnit; + + private MLRateLimiter rateLimiterNull; + + private final String expectedInputStr = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}"; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + rateLimiterWithNumber = MLRateLimiter.builder() + .rateLimitNumber("1") + .build(); + + rateLimiterWithUnit = MLRateLimiter.builder() + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + + rateLimiterNull = MLRateLimiter.builder().build(); + + } + + @Test + public void readInputStreamSuccess() throws IOException { + readInputStream(rateLimiter, parsedInput -> { + assertEquals("1", parsedInput.getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, parsedInput.getRateLimitUnit()); + }); + } + + @Test + public void readInputStreamSuccessWithNullFields() throws IOException { + readInputStream(rateLimiterWithNumber, parsedInput -> { + assertNull(parsedInput.getRateLimitUnit()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(rateLimiter); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContentIncomplete() throws Exception { + final String expectedIncompleteInputStr = "{}"; + + String jsonStr = serializationWithToXContent(rateLimiterNull); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void parseSuccess() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("1", parsedInput.getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, parsedInput.getRateLimitUnit()); + }); + } + + @Test + public void parseWithNullField() throws Exception { + exceptionRule.expect(IllegalStateException.class); + final String expectedInputStrWithNullField = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":null}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithIllegalField() throws Exception { + final String expectedInputStrWithIllegalField = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":" + + "\"MILLISECONDS\",\"illegal_field\":\"This field need to be skipped.\"}"; + + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testisValid() { + assertTrue(rateLimiter.isValid()); + assertFalse(rateLimiterWithNumber.isValid()); + assertFalse(rateLimiterWithUnit.isValid()); + assertFalse(rateLimiterNull.isValid()); + } + + @Test + public void testIsRateLimiterRemovable() { + assertFalse(rateLimiter.isEmpty()); + assertFalse(rateLimiterWithNumber.isEmpty()); + assertFalse(rateLimiterWithUnit.isEmpty()); + assertTrue(rateLimiterNull.isEmpty()); + } + + @Test + public void testRateLimiterUpdate() { + MLRateLimiter updatedRateLimiter = MLRateLimiter.update(rateLimiterNull, rateLimiter); + assertEquals("1", updatedRateLimiter.getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getRateLimitUnit()); + } + + @Test + public void testRateLimiterPartiallyUpdate() { + rateLimiterNull.update(rateLimiterWithNumber); + assertEquals("1", rateLimiterNull.getRateLimitNumber()); + assertNull(rateLimiterNull.getRateLimitUnit()); + rateLimiterNull.update(rateLimiterWithUnit); + assertEquals("1", rateLimiterNull.getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, rateLimiterNull.getRateLimitUnit()); + } + + @Test + public void testRateLimiterUpdateNull() { + MLRateLimiter updatedRateLimiter = MLRateLimiter.update(null, rateLimiter); + assertEquals("1", updatedRateLimiter.getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getRateLimitUnit()); + } + + @Test + public void testRateLimiterIsUpdatable() { + assertFalse(MLRateLimiter.updateValidityPreCheck(rateLimiter, null)); + assertFalse(MLRateLimiter.updateValidityPreCheck(rateLimiter, rateLimiterNull)); + + assertTrue(MLRateLimiter.updateValidityPreCheck(null, rateLimiter)); + assertTrue(MLRateLimiter.updateValidityPreCheck(rateLimiterNull, rateLimiter)); + + assertTrue(MLRateLimiter.updateValidityPreCheck(rateLimiterWithUnit, rateLimiterWithNumber)); + assertTrue(MLRateLimiter.updateValidityPreCheck(rateLimiterWithUnit, rateLimiter)); + assertTrue(MLRateLimiter.updateValidityPreCheck(rateLimiterWithNumber, rateLimiter)); + assertFalse(MLRateLimiter.updateValidityPreCheck(rateLimiter, rateLimiter)); + + assertFalse(MLRateLimiter.updateValidityPreCheck(rateLimiter, rateLimiterWithUnit)); + assertFalse(MLRateLimiter.updateValidityPreCheck(rateLimiter, rateLimiterWithNumber)); + } + + @Test + public void testRateLimiterIsDeployRequiredAfterUpdate() { + MLRateLimiter rateLimiterWithNumber2 = MLRateLimiter.builder() + .rateLimitNumber("2") + .build(); + + MLRateLimiter rateLimiterWithUnit2 = MLRateLimiter.builder() + .rateLimitUnit(TimeUnit.NANOSECONDS) + .build(); + + assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiter, rateLimiterWithNumber2)); + + assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiterNull, rateLimiter)); + + assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiterWithUnit, rateLimiterWithNumber)); + assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiterWithNumber, rateLimiterWithUnit)); + assertFalse(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiterWithUnit, rateLimiterWithUnit2)); + assertFalse(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiterWithNumber, rateLimiterWithNumber2)); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + parser.nextToken(); + MLRateLimiter parsedInput = MLRateLimiter.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLRateLimiter input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLRateLimiter parsedInput = new MLRateLimiter(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLRateLimiter input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } + + @Ignore + @Test + public void testRateLimiterRemove() { + MLRateLimiter updatedRateLimiter = MLRateLimiter.update(rateLimiter, rateLimiterNull); + assertNull(updatedRateLimiter.getRateLimitUnit()); + assertNull(updatedRateLimiter.getRateLimitNumber()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java index d404f49ab4..1cff34b131 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java @@ -25,28 +25,28 @@ public class MLConnectorDeleteRequestTests { @Before public void setUp() { - connectorId = "test-connector-id"; + connectorId = "testConnectorId"; } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() .connectorId(connectorId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlConnectorDeleteRequest.writeTo(bytesStreamOutput); - MLConnectorDeleteRequest parsedConnector = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(parsedConnector.getConnectorId(), connectorId); + MLConnectorDeleteRequest parsedRequest = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedRequest.getConnectorId(), connectorId); } @Test - public void valid_Exception_NullConnectorId() { + public void validWithNullConnectorIdException() { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().build(); ActionRequestValidationException exception = mlConnectorDeleteRequest.validate(); assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); } @Test - public void validate_Success() { + public void validateSuccess() { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() .connectorId(connectorId).build(); ActionRequestValidationException actionRequestValidationException = mlConnectorDeleteRequest.validate(); @@ -54,7 +54,7 @@ public void validate_Success() { } @Test - public void fromActionRequest_Success() { + public void fromActionRequestSuccess() { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() .connectorId(connectorId).build(); ActionRequest actionRequest = new ActionRequest() { @@ -68,13 +68,13 @@ public void writeTo(StreamOutput out) throws IOException { mlConnectorDeleteRequest.writeTo(out); } }; - MLConnectorDeleteRequest parsedConnector = MLConnectorDeleteRequest.fromActionRequest(actionRequest); - assertNotSame(parsedConnector, mlConnectorDeleteRequest); - assertEquals(parsedConnector.getConnectorId(), connectorId); + MLConnectorDeleteRequest parsedRequest = MLConnectorDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(parsedRequest, mlConnectorDeleteRequest); + assertEquals(parsedRequest.getConnectorId(), connectorId); } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -90,7 +90,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test - public void fromActionRequestWithConnectorDeleteRequest_Success() { + public void fromActionRequestWithConnectorDeleteRequestSuccess() { MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() .connectorId(connectorId).build(); MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest.fromActionRequest(mlConnectorDeleteRequest); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java index 53fcce560b..0b663f4cc9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java @@ -26,20 +26,20 @@ public class MLConnectorGetRequestTests { @Before public void setUp() { - connectorId = "test-connector-id"; + connectorId = "testConnectorId"; } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlConnectorGetRequest.writeTo(bytesStreamOutput); - MLConnectorGetRequest parsedConnector = new MLConnectorGetRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(connectorId, parsedConnector.getConnectorId()); + MLConnectorGetRequest parsedRequest = new MLConnectorGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(connectorId, parsedRequest.getConnectorId()); } @Test - public void fromActionRequest_Success() { + public void fromActionRequestSuccess() { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); ActionRequest actionRequest = new ActionRequest() { @Override @@ -58,7 +58,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -74,7 +74,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test - public void fromActionRequestWithMLConnectorGetRequest_Success() { + public void fromActionRequestWithMLConnectorGetRequestSuccess() { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); MLConnectorGetRequest mlConnectorGetRequestFromActionRequest = MLConnectorGetRequest.fromActionRequest(mlConnectorGetRequest); assertSame(mlConnectorGetRequest, mlConnectorGetRequestFromActionRequest); @@ -82,14 +82,14 @@ public void fromActionRequestWithMLConnectorGetRequest_Success() { } @Test - public void validate_Exception_NullConnctorId() { + public void validateWithNullConnectorIdException() { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().build(); ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate(); assertEquals("Validation Failed: 1: ML connector id can't be null;", actionRequestValidationException.getMessage()); } @Test - public void validate_Success() { + public void validateSuccess() { MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate(); assertNull(actionRequestValidationException); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 417f77506c..55d13afeaf 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -35,7 +35,7 @@ public void setUp() { } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); MLConnectorGetResponse response = MLConnectorGetResponse.builder().mlConnector(connector).build(); response.writeTo(bytesStreamOutput); @@ -72,7 +72,7 @@ public void toXContentTest() throws IOException { } @Test - public void fromActionResponseWithMLConnectorGetResponse_Success() { + public void fromActionResponseWithMLConnectorGetResponseSuccess() { MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build(); MLConnectorGetResponse mlConnectorGetResponseFromActionResponse = MLConnectorGetResponse.fromActionResponse(mlConnectorGetResponse); assertSame(mlConnectorGetResponse, mlConnectorGetResponseFromActionResponse); @@ -80,7 +80,7 @@ public void fromActionResponseWithMLConnectorGetResponse_Success() { } @Test - public void fromActionResponse_Success() { + public void fromActionResponseSuccess() { MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build(); ActionResponse actionResponse = new ActionResponse() { @Override @@ -94,7 +94,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionResponse_IOException() { + public void fromActionResponseIOException() { ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index 5310be6582..ed0a1dd439 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -31,6 +31,8 @@ public class MLCreateConnectorRequestTests { private MLCreateConnectorInput mlCreateConnectorInput; + private MLCreateConnectorRequest mlCreateConnectorRequest; + @Before public void setUp(){ ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; @@ -55,11 +57,11 @@ public void setUp(){ .backendRoles(Arrays.asList("role1", "role2")) .addAllBackendRoles(false) .build(); + mlCreateConnectorRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(mlCreateConnectorInput).build(); } @Test - public void writeTo_Success() throws IOException { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(mlCreateConnectorInput).build(); + public void writeToSuccess() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); mlCreateConnectorRequest.writeTo(output); MLCreateConnectorRequest parsedRequest = new MLCreateConnectorRequest(output.bytes().streamInput()); @@ -72,16 +74,12 @@ public void writeTo_Success() throws IOException { } @Test - public void validate_Success() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); - + public void validateSuccess() { assertNull(mlCreateConnectorRequest.validate()); } @Test - public void validate_Exception_NullMLRegisterModelGroupInput() { + public void validateWithNullMLCreateConnectorInputException() { MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() .build(); ActionRequestValidationException exception = mlCreateConnectorRequest.validate(); @@ -89,18 +87,12 @@ public void validate_Exception_NullMLRegisterModelGroupInput() { } @Test - public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); + public void fromActionRequestWithMLCreateConnectorRequestSuccess() { assertSame(MLCreateConnectorRequest.fromActionRequest(mlCreateConnectorRequest), mlCreateConnectorRequest); } @Test - public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); + public void fromActionRequestWithNonMLCreateConnectorRequestSuccess() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -118,7 +110,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java index 8d58047980..7995e47f8f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java @@ -9,30 +9,71 @@ import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; public class MLCreateConnectorResponseTests { @Test public void toXContent() throws IOException { - MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id"); + MLCreateConnectorResponse response = new MLCreateConnectorResponse("testConnectorId"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"connector_id\":\"test_id\"}", content); + assertEquals("{\"connector_id\":\"testConnectorId\"}", content); } @Test public void readFromStream() throws IOException { - MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id"); + MLCreateConnectorResponse response = new MLCreateConnectorResponse("testConnectorId"); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLCreateConnectorResponse response2 = new MLCreateConnectorResponse(output.bytes().streamInput()); - Assert.assertEquals("test_id", response2.getConnectorId()); + assertEquals("testConnectorId", response2.getConnectorId()); + } + + + @Test + public void fromActionResponseWithMLCreateConnectorResponseSuccess() { + MLCreateConnectorResponse response = new MLCreateConnectorResponse("testConnectorId"); + MLCreateConnectorResponse responseFromActionResponse = MLCreateConnectorResponse.fromActionResponse(response); + assertSame(response, responseFromActionResponse); + assertEquals(response.getConnectorId(), responseFromActionResponse.getConnectorId()); + } + + @Test + public void fromActionResponseSuccess() { + MLCreateConnectorResponse response = new MLCreateConnectorResponse("testConnectorId"); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + response.writeTo(out); + } + }; + MLCreateConnectorResponse responseFromActionResponse = MLCreateConnectorResponse.fromActionResponse(actionResponse); + assertNotSame(response, responseFromActionResponse); + assertEquals(response.getConnectorId(), responseFromActionResponse.getConnectorId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLCreateConnectorResponse.fromActionResponse(actionResponse); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java new file mode 100644 index 0000000000..b95eb49bd9 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; + + +public class MLCreateModelControllerRequestTest { + private MLModelController modelControllerInput; + + private MLCreateModelControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + modelControllerInput = MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>() {{ + put("testUser", rateLimiter); + }}) + .build(); + request = MLCreateModelControllerRequest.builder() + .modelControllerInput(modelControllerInput) + .build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLCreateModelControllerRequest parsedRequest = new MLCreateModelControllerRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getModelControllerInput().getModelId()); + assertTrue(parsedRequest.getModelControllerInput().getUserRateLimiterConfig().containsKey("testUser")); + assertEquals("1", parsedRequest.getModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLModelControllerInputException() { + MLCreateModelControllerRequest request = MLCreateModelControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + modelControllerInput.setModelId(null); + MLCreateModelControllerRequest request = MLCreateModelControllerRequest.builder() + .modelControllerInput(modelControllerInput) + .build(); + + assertNull(request.validate()); + assertNull(request.getModelControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLCreateModelControllerRequestSuccess() { + assertSame(MLCreateModelControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLCreateModelControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLCreateModelControllerRequest result = MLCreateModelControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getModelControllerInput().getModelId(), result.getModelControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLCreateModelControllerRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java new file mode 100644 index 0000000000..1c3f4160f2 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLCreateModelControllerResponseTest { + + private MLCreateModelControllerResponse response; + + @Before + public void setup() { + response = new MLCreateModelControllerResponse("testModelId", "Status"); + } + + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + response.writeTo(bytesStreamOutput); + MLCreateModelControllerResponse newResponse = new MLCreateModelControllerResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(response.getModelId(), newResponse.getModelId()); + assertEquals(response.getStatus(), newResponse.getStatus()); + } + + @Test + public void testToXContent() throws IOException { + MLCreateModelControllerResponse response = new MLCreateModelControllerResponse("testModelId", "Status"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"model_id\":\"testModelId\",\"status\":\"Status\"}"; + assertEquals(expected, jsonStr); + } + + @Test + public void fromActionResponseWithMLCreateModelControllerResponseSuccess() { + MLCreateModelControllerResponse responseFromActionResponse = MLCreateModelControllerResponse.fromActionResponse(response); + assertSame(response, responseFromActionResponse); + assertEquals(response.getModelId(), responseFromActionResponse.getModelId()); + } + + @Test + public void fromActionResponseSuccess() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + response.writeTo(out); + } + }; + MLCreateModelControllerResponse responseFromActionResponse = MLCreateModelControllerResponse.fromActionResponse(actionResponse); + assertNotSame(response, responseFromActionResponse); + assertEquals(response.getModelId(), responseFromActionResponse.getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLCreateModelControllerResponse.fromActionResponse(actionResponse); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java new file mode 100644 index 0000000000..3a2a3104a7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.transport.TransportAddress; + +@RunWith(MockitoJUnitRunner.class) +public class MLDeployModelControllerNodeResponseTest { + + @Mock + private DiscoveryNode localNode; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + } + + @Test + public void testSerializationDeserialization() throws IOException { + Map deployModelControllerStatus = Map.of("modelName:version", "response"); + MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, deployModelControllerStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLDeployModelControllerNodeResponse newResponse = new MLDeployModelControllerNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testSerializationDeserializationNullModelUpdateModelCacheStatus() throws IOException { + MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, null); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLDeployModelControllerNodeResponse newResponse = new MLDeployModelControllerNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testReadProfile() throws IOException { + MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, new HashMap<>()); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLDeployModelControllerNodeResponse newResponse = MLDeployModelControllerNodeResponse.readStats(output.bytes().streamInput()); + assertNotEquals(newResponse, response); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java new file mode 100644 index 0000000000..7f30734f96 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +// This test combined MLDeployModelControllerNodesRequestTest and MLDeployModelControllerNodeRequestTest together. +@RunWith(MockitoJUnitRunner.class) +public class MLDeployModelControllerNodesRequestTest { + + @Mock + private DiscoveryNode localNode1; + + @Mock + private DiscoveryNode localNode2; + + private MLDeployModelControllerNodeRequest deployModelControllerNodeRequestWithStringNodeIds; + + private MLDeployModelControllerNodeRequest deployModelControllerNodeRequestWithDiscoveryNodeIds; + + @Before + public void setUp() throws Exception { + + String modelId = "testModelId"; + String[] stringNodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + DiscoveryNode[] discoveryNodeIds = {localNode1, localNode2}; + + deployModelControllerNodeRequestWithStringNodeIds = new MLDeployModelControllerNodeRequest( + new MLDeployModelControllerNodesRequest(stringNodeIds, modelId) + ); + deployModelControllerNodeRequestWithDiscoveryNodeIds = new MLDeployModelControllerNodeRequest( + new MLDeployModelControllerNodesRequest(discoveryNodeIds, modelId) + ); + + } + + @Test + public void testConstructorSerialization1() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + deployModelControllerNodeRequestWithStringNodeIds.writeTo(output); + assertEquals("testModelId", deployModelControllerNodeRequestWithStringNodeIds.getDeployModelControllerNodesRequest().getModelId()); + + } + + @Test + public void testConstructorSerialization2() { + assertEquals(2, deployModelControllerNodeRequestWithDiscoveryNodeIds.getDeployModelControllerNodesRequest().concreteNodes().length); + + } + + @Test + public void testConstructorFromInputStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + deployModelControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLDeployModelControllerNodeRequest parsedNodeRequest = new MLDeployModelControllerNodeRequest(streamInput); + + assertEquals(deployModelControllerNodeRequestWithStringNodeIds.getDeployModelControllerNodesRequest().getModelId(), + parsedNodeRequest.getDeployModelControllerNodesRequest().getModelId()); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java new file mode 100644 index 0000000000..47620d015b --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +@RunWith(MockitoJUnitRunner.class) +public class MLDeployModelControllerNodesResponseTest { + @Mock + private ClusterName clusterName; + @Mock + private DiscoveryNode node1; + @Mock + private DiscoveryNode node2; + + @Before + public void setUp() throws Exception { + clusterName = new ClusterName("clusterName"); + node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + } + + @Test + public void testSerializationDeserialization1() throws IOException { + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, responseList, failuresList); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLDeployModelControllerNodesResponse newResponse = new MLDeployModelControllerNodesResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNodes().size(), response.getNodes().size()); + } + + @Test + public void testToXContent() throws IOException { + List nodes = new ArrayList<>(); + + Map deployModelControllerStatus1 = Map.of("modelId1", "response"); + nodes.add(new MLDeployModelControllerNodeResponse(node1, deployModelControllerStatus1)); + + Map deployModelControllerStatus2 = Map.of("modelId2", "response"); + nodes.add(new MLDeployModelControllerNodeResponse(node2, deployModelControllerStatus2)); + + List failures = new ArrayList<>(); + MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals( + "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", + jsonStr + ); + } + + @Test + public void testNullUpdateModelCacheStatusToXContent() throws IOException { + List nodes = new ArrayList<>(); + nodes.add(new MLDeployModelControllerNodeResponse(node1, null)); + List failures = new ArrayList<>(); + MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals("{}",jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java new file mode 100644 index 0000000000..ae5ef8dd49 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLModelControllerDeleteRequestTest { + + private String modelId; + + private MLModelControllerDeleteRequest request; + + @Before + public void setUp() { + + modelId = "testModelId"; + + request = MLModelControllerDeleteRequest.builder() + .modelId(modelId).build(); + } + + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLModelControllerDeleteRequest parsedRequest = new MLModelControllerDeleteRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedRequest.getModelId(), modelId); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullModelIdException() { + MLModelControllerDeleteRequest request = MLModelControllerDeleteRequest.builder().build(); + + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: ML model id can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { + assertSame(MLModelControllerDeleteRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLModelControllerDeleteRequest result = MLModelControllerDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(result.getModelId(), request.getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLModelControllerDeleteRequest.fromActionRequest(actionRequest); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java new file mode 100644 index 0000000000..f45b790250 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLModelControllerGetRequestTest { + + private String modelId; + + private MLModelControllerGetRequest request; + + @Before + public void setUp() { + + modelId = "testModelId"; + + request = MLModelControllerGetRequest.builder().modelId(modelId).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLModelControllerGetRequest parsedRequest = new MLModelControllerGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(modelId, parsedRequest.getModelId()); + } + + @Test + public void fromActionRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLModelControllerGetRequest requestFromActionRequest = MLModelControllerGetRequest.fromActionRequest(actionRequest); + assertNotSame(request, requestFromActionRequest); + assertEquals(request.getModelId(), requestFromActionRequest.getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLModelControllerGetRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequestWithMLModelControllerGetRequestSuccess() { + MLModelControllerGetRequest requestFromActionRequest = MLModelControllerGetRequest.fromActionRequest(request); + assertSame(request, requestFromActionRequest); + assertEquals(request.getModelId(), requestFromActionRequest.getModelId()); + } + + @Test + public void validateNullModelIdException() { + MLModelControllerGetRequest request = MLModelControllerGetRequest.builder().build(); + ActionRequestValidationException actionRequestValidationException = request.validate(); + assertEquals("Validation Failed: 1: ML model id can't be null;", actionRequestValidationException.getMessage()); + } + + @Test + public void validateSuccess() { + ActionRequestValidationException actionRequestValidationException = request.validate(); + assertNull(actionRequestValidationException); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java new file mode 100644 index 0000000000..af6526638a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; + +public class MLModelControllerGetResponseTest { + + private MLModelController modelController; + + private MLModelControllerGetResponse response; + + @Before + public void setUp() { + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + modelController = MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>() {{ + put("testUser", rateLimiter); + }}) + .build(); + response = MLModelControllerGetResponse.builder().modelController(modelController).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + response.writeTo(bytesStreamOutput); + MLModelControllerGetResponse parsedResponse = new MLModelControllerGetResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response.getModelController(), parsedResponse.getModelController()); + assertEquals(response.getModelController().getModelId(), parsedResponse.getModelController().getModelId()); + assertEquals(response.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitNumber(), parsedResponse.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + assertEquals(response.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitUnit(), parsedResponse.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + } + + @Test + public void toXContentTest() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}",jsonStr); + } + + @Test + public void fromActionResponseWithMLModelControllerGetResponseSuccess() { + MLModelControllerGetResponse responseFromActionResponse = MLModelControllerGetResponse.fromActionResponse(response); + assertSame(response, responseFromActionResponse); + assertEquals(response.getModelController(), responseFromActionResponse.getModelController()); + } + + @Test + public void fromActionResponseSuccess() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + response.writeTo(out); + } + }; + MLModelControllerGetResponse responseFromActionResponse = MLModelControllerGetResponse.fromActionResponse(actionResponse); + assertNotSame(response, responseFromActionResponse); + assertNotEquals(response.getModelController(), responseFromActionResponse.getModelController()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLModelControllerGetResponse.fromActionResponse(actionResponse); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java new file mode 100644 index 0000000000..5f0e045418 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +@RunWith(MockitoJUnitRunner.class) +public class MLUndeployModelControllerNodeResponseTest { + + @Mock + private DiscoveryNode localNode; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + } + + @Test + public void testSerializationDeserialization() throws IOException { + Map undeployModelControllerStatus = Map.of("modelName:version", "response"); + MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, undeployModelControllerStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUndeployModelControllerNodeResponse newResponse = new MLUndeployModelControllerNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testSerializationDeserializationNullModelUpdateModelCacheStatus() throws IOException { + MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, null); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUndeployModelControllerNodeResponse newResponse = new MLUndeployModelControllerNodeResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNode().getId(), response.getNode().getId()); + } + + @Test + public void testReadProfile() throws IOException { + MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, new HashMap<>()); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUndeployModelControllerNodeResponse newResponse = MLUndeployModelControllerNodeResponse.readStats(output.bytes().streamInput()); + assertNotEquals(newResponse, response); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java new file mode 100644 index 0000000000..ad08f76408 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +// This test combined MLUndeployModelControllerNodesRequestTest and MLUndeployModelControllerNodeRequestTest together. +@RunWith(MockitoJUnitRunner.class) +public class MLUndeployModelControllerNodesRequestTest { + + @Mock + private DiscoveryNode localNode1; + + @Mock + private DiscoveryNode localNode2; + + private MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequestWithStringNodeIds; + + private MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequestWithDiscoveryNodeIds; + + @Before + public void setUp() throws Exception { + + String modelId = "testModelId"; + String[] stringNodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + DiscoveryNode[] discoveryNodeIds = {localNode1, localNode2}; + + undeployModelControllerNodeRequestWithStringNodeIds = new MLUndeployModelControllerNodeRequest( + new MLUndeployModelControllerNodesRequest(stringNodeIds, modelId) + ); + undeployModelControllerNodeRequestWithDiscoveryNodeIds = new MLUndeployModelControllerNodeRequest( + new MLUndeployModelControllerNodesRequest(discoveryNodeIds, modelId) + ); + + } + + @Test + public void testConstructorSerialization1() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + undeployModelControllerNodeRequestWithStringNodeIds.writeTo(output); + assertEquals("testModelId", undeployModelControllerNodeRequestWithStringNodeIds.getUndeployModelControllerNodesRequest().getModelId()); + + } + + @Test + public void testConstructorSerialization2() { + assertEquals(2, undeployModelControllerNodeRequestWithDiscoveryNodeIds.getUndeployModelControllerNodesRequest().concreteNodes().length); + + } + + @Test + public void testConstructorFromInputStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + undeployModelControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUndeployModelControllerNodeRequest parsedNodeRequest = new MLUndeployModelControllerNodeRequest(streamInput); + + assertEquals(undeployModelControllerNodeRequestWithStringNodeIds.getUndeployModelControllerNodesRequest().getModelId(), + parsedNodeRequest.getUndeployModelControllerNodesRequest().getModelId()); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java new file mode 100644 index 0000000000..77aa947fa6 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +@RunWith(MockitoJUnitRunner.class) +public class MLUndeployModelControllerNodesResponseTest { + @Mock + private ClusterName clusterName; + @Mock + private DiscoveryNode node1; + @Mock + private DiscoveryNode node2; + + @Before + public void setUp() throws Exception { + clusterName = new ClusterName("clusterName"); + node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + } + + @Test + public void testSerializationDeserialization1() throws IOException { + List responseList = new ArrayList<>(); + List failuresList = new ArrayList<>(); + MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, responseList, failuresList); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLUndeployModelControllerNodesResponse newResponse = new MLUndeployModelControllerNodesResponse(output.bytes().streamInput()); + assertEquals(newResponse.getNodes().size(), response.getNodes().size()); + } + + @Test + public void testToXContent() throws IOException { + List nodes = new ArrayList<>(); + + Map undeployModelControllerStatus1 = Map.of("modelId1", "response"); + nodes.add(new MLUndeployModelControllerNodeResponse(node1, undeployModelControllerStatus1)); + + Map undeployModelControllerStatus2 = Map.of("modelId2", "response"); + nodes.add(new MLUndeployModelControllerNodeResponse(node2, undeployModelControllerStatus2)); + + List failures = new ArrayList<>(); + MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals( + "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", + jsonStr + ); + } + + @Test + public void testNullUpdateModelCacheStatusToXContent() throws IOException { + List nodes = new ArrayList<>(); + nodes.add(new MLUndeployModelControllerNodeResponse(node1, null)); + List failures = new ArrayList<>(); + MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, nodes, failures); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals("{}",jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java new file mode 100644 index 0000000000..e452cfccc9 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; + +public class MLUpdateModelControllerRequestTest { + private MLModelController updateModelControllerInput; + + private MLUpdateModelControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + updateModelControllerInput = MLModelController.builder() + .modelId("testModelId") + .userRateLimiterConfig(new HashMap<>() {{ + put("testUser", rateLimiter); + }}) + .build(); + request = MLUpdateModelControllerRequest.builder() + .updateModelControllerInput(updateModelControllerInput) + .build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUpdateModelControllerRequest parsedRequest = new MLUpdateModelControllerRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getUpdateModelControllerInput().getModelId()); + assertTrue(parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().containsKey("testUser")); + assertEquals("1", parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); + assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLModelControllerInputException() { + MLUpdateModelControllerRequest request = MLUpdateModelControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Update model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + updateModelControllerInput.setModelId(null); + MLUpdateModelControllerRequest request = MLUpdateModelControllerRequest.builder() + .updateModelControllerInput(updateModelControllerInput) + .build(); + + assertNull(request.validate()); + assertNull(request.getUpdateModelControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { + assertSame(MLUpdateModelControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLUpdateModelControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLUpdateModelControllerRequest result = MLUpdateModelControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getUpdateModelControllerInput().getModelId(), result.getUpdateModelControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUpdateModelControllerRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index ed1f5568eb..e283647f38 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -6,14 +6,15 @@ package org.opensearch.ml.common.transport.model; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.junit.Before; @@ -33,6 +34,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.search.SearchModule; @@ -42,33 +44,24 @@ public class MLUpdateModelInputTest { private MLUpdateModelInput updateModelInput; - private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + + "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + - "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\"}}"; - private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; - private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + - "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}}"; - private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}"; + + private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + + "\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + + "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + - "\"test-connector_id\",\"connector_update_content\":{\"description\":\"updated description\",\"version\":\"1\"},\"illegal_field\":\"This field need to be skipped.\"}"; + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -81,7 +74,7 @@ public void setUp() throws Exception { .embeddingDimension(100) .build(); - Connector connector = HttpConnector + Connector updatedConnector = HttpConnector .builder() .name("test") .protocol("http") @@ -110,16 +103,24 @@ public void setUp() throws Exception { .description("updated description") .build(); + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .rateLimitNumber("1") + .rateLimitUnit(TimeUnit.MILLISECONDS) + .build(); + updateModelInput = MLUpdateModelInput.builder() .modelId("test-model_id") .modelGroupId("modelGroupId") .version("2") .name("name") .description("description") + .isEnabled(false) + .modelRateLimiterConfig(rateLimiter) .modelConfig(config) - .connector(connector) + .updatedConnector(updatedConnector) .connectorId("test-connector_id") - .connectorUpdateContent(updateContent) + .connector(updateContent) + .lastUpdateTime(Instant.ofEpochMilli(1)) .build(); } @@ -149,14 +150,8 @@ public void testToXContent() throws Exception { public void testToXContentIncomplete() throws Exception { String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; - updateModelInput.setDescription(null); - updateModelInput.setVersion(null); - updateModelInput.setName(null); - updateModelInput.setModelGroupId(null); - updateModelInput.setModelConfig(null); - updateModelInput.setConnector(null); - updateModelInput.setConnectorId(null); - updateModelInput.setConnectorUpdateContent(null); + updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id").build(); String jsonStr = serializationWithToXContent(updateModelInput); assertEquals(expectedIncompleteInputStr, jsonStr); } @@ -171,6 +166,11 @@ public void parseSuccess() throws Exception { @Test public void parseWithNullFieldWithoutModel() throws Exception { exceptionRule.expect(IllegalStateException.class); + String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + + "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); @@ -182,6 +182,16 @@ public void parseWithNullFieldWithoutModel() throws Exception { @Test public void parseWithIllegalFieldWithoutModel() throws Exception { + String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + + "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java index 68e4491674..6c25f36dfe 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java @@ -12,38 +12,51 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; public class MLModelGroupDeleteRequestTest { private String modelGroupId; + private MLModelGroupDeleteRequest request; + @Before public void setUp() { - modelGroupId = "test_group_id"; + modelGroupId = "testGroupId"; + + request = MLModelGroupDeleteRequest.builder() + .modelGroupId(modelGroupId).build(); } @Test - public void writeTo_Success() throws IOException { - MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.builder() - .modelGroupId(modelGroupId).build(); + public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlModelGroupDeleteRequest.writeTo(bytesStreamOutput); - MLModelGroupDeleteRequest parsedModel = new MLModelGroupDeleteRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(parsedModel.getModelGroupId(), modelGroupId); + request.writeTo(bytesStreamOutput); + MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedRequest.getModelGroupId(), modelGroupId); } @Test - public void validate_Exception_NullModelId() { - MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.builder().build(); + public void validateSuccess() { + assertNull(request.validate()); + } - ActionRequestValidationException exception = mlModelGroupDeleteRequest.validate(); + @Test + public void validateWithNullModelIdException() { + MLModelGroupDeleteRequest request = MLModelGroupDeleteRequest.builder().build(); + + ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML model group id can't be null;", exception.getMessage()); } @Test - public void fromActionRequest_Success() { - MLModelGroupDeleteRequest mlModelDeleteRequest = MLModelGroupDeleteRequest.builder() - .modelGroupId(modelGroupId).build(); + public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { + assertSame(MLModelGroupDeleteRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestSuccess() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -52,16 +65,16 @@ public ActionRequestValidationException validate() { @Override public void writeTo(StreamOutput out) throws IOException { - mlModelDeleteRequest.writeTo(out); + request.writeTo(out); } }; MLModelGroupDeleteRequest result = MLModelGroupDeleteRequest.fromActionRequest(actionRequest); - assertNotSame(result, mlModelDeleteRequest); - assertEquals(result.getModelGroupId(), mlModelDeleteRequest.getModelGroupId()); + assertNotSame(result, request); + assertEquals(result.getModelGroupId(), request.getModelGroupId()); } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java index b1814b44c0..5b8000bdb9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java @@ -34,8 +34,8 @@ public void writeTo_Success() throws IOException { .modelGroupId(modelGroupId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelGroupGetRequest.writeTo(bytesStreamOutput); - MLModelGroupGetRequest parsedModel = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(parsedModel.getModelGroupId(), modelGroupId); + MLModelGroupGetRequest parsedRequest = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedRequest.getModelGroupId(), modelGroupId); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java index 0143c1851e..da10789062 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java @@ -40,7 +40,7 @@ public void setUp() { } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); MLModelGroupGetResponse response = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build(); response.writeTo(bytesStreamOutput); @@ -66,7 +66,7 @@ public void toXContentTest() throws IOException { } @Test - public void fromActionResponseWithMLModelGroupGetResponse_Success() { + public void fromActionResponseWithMLModelGroupGetResponseSuccess() { MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build(); MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(mlModelGroupGetResponse); assertSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse); @@ -74,7 +74,7 @@ public void fromActionResponseWithMLModelGroupGetResponse_Success() { } @Test - public void fromActionResponse_Success() { + public void fromActionResponseSuccess() { MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build(); ActionResponse actionResponse = new ActionResponse() { @Override @@ -88,7 +88,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionResponse_IOException() { + public void fromActionResponseIOException() { ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index 8e27325e47..34be768c63 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -10,7 +10,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Arrays; +import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; @@ -21,23 +21,26 @@ public class MLRegisterModelGroupRequestTest { private MLRegisterModelGroupInput mlRegisterModelGroupInput; + private MLRegisterModelGroupRequest request; + @Before public void setUp(){ - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() + mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder() .name("name") .description("description") - .backendRoles(Arrays.asList("IT")) + .backendRoles(List.of("IT")) .modelAccessMode(AccessMode.RESTRICTED) .isAddAllBackendRoles(true) .build(); - } - @Test - public void writeTo_Success() throws IOException { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() + request = MLRegisterModelGroupRequest.builder() .registerModelGroupInput(mlRegisterModelGroupInput) .build(); + } + + @Test + public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); MLRegisterModelGroupRequest parsedRequest = new MLRegisterModelGroupRequest(bytesStreamOutput.bytes().streamInput()); @@ -49,25 +52,20 @@ public void writeTo_Success() throws IOException { } @Test - public void validate_Success() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); - + public void validateSuccess() { assertNull(request.validate()); } @Test - public void validate_Exception_NullMLRegisterModelGroupInput() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .build(); + public void validateNullMLRegisterModelGroupInputException() { + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); } @Test // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here - public void validate_Exception_NullMLModelName() { + public void validateNullMLModelNameException() { mlRegisterModelGroupInput.setName(null); MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() .registerModelGroupInput(mlRegisterModelGroupInput) @@ -78,18 +76,12 @@ public void validate_Exception_NullMLModelName() { } @Test - public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + public void fromActionRequestWithMLRegisterModelGroupRequestSuccess() { assertSame(MLRegisterModelGroupRequest.fromActionRequest(request), request); } @Test - public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + public void fromActionRequestWithNonMLRegisterModelGroupRequestSuccess() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -107,7 +99,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java index 9299307539..528a0099b5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java @@ -9,42 +9,78 @@ import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; import java.io.IOException; +import java.io.UncheckedIOException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class MLRegisterModelGroupResponseTest { - MLRegisterModelGroupResponse mlRegisterModelGroupResponse; + MLRegisterModelGroupResponse response; @Before public void setup() { - mlRegisterModelGroupResponse = new MLRegisterModelGroupResponse("ModelGroupId", "Status"); + response = new MLRegisterModelGroupResponse("testModelGroupId", "Status"); } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlRegisterModelGroupResponse.writeTo(bytesStreamOutput); + response.writeTo(bytesStreamOutput); MLRegisterModelGroupResponse newResponse = new MLRegisterModelGroupResponse(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlRegisterModelGroupResponse.getModelGroupId(), newResponse.getModelGroupId()); - assertEquals(mlRegisterModelGroupResponse.getStatus(), newResponse.getStatus()); + assertEquals(response.getModelGroupId(), newResponse.getModelGroupId()); + assertEquals(response.getStatus(), newResponse.getStatus()); } @Test public void testToXContent() throws IOException { - MLRegisterModelGroupResponse response = new MLRegisterModelGroupResponse("ModelGroupId", "Status"); + MLRegisterModelGroupResponse response = new MLRegisterModelGroupResponse("testModelGroupId", "Status"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, EMPTY_PARAMS); assertNotNull(builder); String jsonStr = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"model_group_id\":\"ModelGroupId\",\"status\":\"Status\"}"; + final String expected = "{\"model_group_id\":\"testModelGroupId\",\"status\":\"Status\"}"; assertEquals(expected, jsonStr); } + + @Test + public void fromActionResponseWithMLRegisterModelGroupResponseSuccess() { + MLRegisterModelGroupResponse responseFromActionResponse = MLRegisterModelGroupResponse.fromActionResponse(response); + assertSame(response, responseFromActionResponse); + assertEquals(response.getModelGroupId(), responseFromActionResponse.getModelGroupId()); + } + + @Test + public void fromActionResponseSuccess() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + response.writeTo(out); + } + }; + MLRegisterModelGroupResponse responseFromActionResponse = MLRegisterModelGroupResponse.fromActionResponse(actionResponse); + assertNotSame(response, responseFromActionResponse); + assertEquals(response.getModelGroupId(), responseFromActionResponse.getModelGroupId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponseIOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLRegisterModelGroupResponse.fromActionResponse(actionResponse); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java index 483d7c6c85..dd274c0f31 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java @@ -35,7 +35,7 @@ public void setUp(){ } @Test - public void writeTo_Success() throws IOException { + public void writeToSuccess() throws IOException { MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .updateModelGroupInput(mlUpdateModelGroupInput) @@ -52,7 +52,7 @@ public void writeTo_Success() throws IOException { } @Test - public void validate_Success() { + public void validateSuccess() { MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .updateModelGroupInput(mlUpdateModelGroupInput) .build(); @@ -61,7 +61,7 @@ public void validate_Success() { } @Test - public void validate_Exception_NullMLRegisterModelGroupInput() { + public void validateWithNullMLUpdateModelGroupInputException() { MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .build(); ActionRequestValidationException exception = request.validate(); @@ -69,8 +69,8 @@ public void validate_Exception_NullMLRegisterModelGroupInput() { } @Test - // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here - public void validate_Exception_NullMLModelName() { + // MLUpdateModelGroupInput check its parameters when created, so exception is not thrown here + public void validateWithNullMLModelNameException() { mlUpdateModelGroupInput.setName(null); MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .updateModelGroupInput(mlUpdateModelGroupInput) @@ -82,7 +82,7 @@ public void validate_Exception_NullMLModelName() { @Test - public void fromActionRequest_Success_WithMLUpdateModelRequest() { + public void fromActionRequestWithMLUpdateModelGroupRequestSuccess() { MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .updateModelGroupInput(mlUpdateModelGroupInput) .build(); @@ -90,7 +90,7 @@ public void fromActionRequest_Success_WithMLUpdateModelRequest() { } @Test - public void fromActionRequest_Success_WithNonMLUpdateModelRequest() { + public void fromActionRequestWithNonMLUpdateModelGroupRequestSuccess() { MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() .updateModelGroupInput(mlUpdateModelGroupInput) .build(); @@ -111,7 +111,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { + public void fromActionRequestIOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java index dc16986f64..12d135a9b3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java @@ -25,8 +25,13 @@ @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodeRequestTest { + @Mock private DiscoveryNode localNode1; + + @Mock private DiscoveryNode localNode2; + + @Mock private DiscoveryNode localNode3; @Mock @@ -34,31 +39,6 @@ public class MLSyncUpNodeRequestTest { @Before public void setUp() throws Exception { - localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - localNode3 = new DiscoveryNode( - "foo3", - "foo3", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - Map addedWorkerNodes = new HashMap<>(); Map removedWorkerNodes = new HashMap<>(); Map> modelRoutingTable = new HashMap<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java index 434ba2dbef..7323f059f3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java @@ -2,6 +2,7 @@ import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; @@ -20,7 +21,10 @@ @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodesRequestTest { + @Mock private DiscoveryNode localNode1; + + @Mock private DiscoveryNode localNode2; @Test @@ -42,23 +46,6 @@ public void testConstructorSerialization1() throws IOException { @Test public void testConstructorSerialization2() throws IOException { - localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); - MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( new MLUndeployModelNodesRequest(localNode1,localNode2) ); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java index e92faef338..28e61476cc 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java @@ -16,6 +16,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; import java.io.IOException; import java.net.InetAddress; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java index b78cd6f263..a698c139dd 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java @@ -13,6 +13,8 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeRequest; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; import java.io.IOException; import java.net.InetAddress; @@ -30,7 +32,7 @@ public void testConstructorSerialization1() throws IOException { String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodeIds, modelId, true) + new MLUpdateModelCacheNodesRequest(nodeIds, modelId) ); BytesStreamOutput output = new BytesStreamOutput(); @@ -59,7 +61,7 @@ public void testConstructorSerialization2() { ); DiscoveryNode[] nodes = {localNode1, localNode2}; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodes, modelId, true) + new MLUpdateModelCacheNodesRequest(nodes, modelId) ); assertEquals(2, updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().concreteNodes().length); } @@ -70,7 +72,7 @@ public void testConstructorFromInputStream() throws IOException { String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodeIds, modelId, true) + new MLUpdateModelCacheNodesRequest(nodeIds, modelId) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); updateModelCacheNodeRequest.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java index f3d1ac668e..e1fc242d43 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java @@ -19,12 +19,13 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; +import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; import java.io.IOException; import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -38,7 +39,6 @@ public class MLUpdateModelCacheNodesResponseTest { private ClusterName clusterName; private DiscoveryNode node1; private DiscoveryNode node2; - private Map modelWorkerNodeCounts; @Before public void setUp() throws Exception { @@ -59,8 +59,6 @@ public void setUp() throws Exception { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); - modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", 1); } @Test @@ -78,16 +76,10 @@ public void testSerializationDeserialization1() throws IOException { public void testToXContent() throws IOException { List nodes = new ArrayList<>(); - Map updateModelCacheStatus1 = new HashMap<>(); - updateModelCacheStatus1.put("modelId1", "response"); - Map modelWorkerNodeCounts1 = new HashMap<>(); - modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); + Map updateModelCacheStatus1 = Map.of("modelId1", "response"); nodes.add(new MLUpdateModelCacheNodeResponse(node1, updateModelCacheStatus1)); - Map updateModelCacheStatus2 = new HashMap<>(); - updateModelCacheStatus2.put("modelId2", "response"); - Map modelWorkerNodeCounts2 = new HashMap<>(); - modelWorkerNodeCounts2.put("modelId2", new String[]{"mockNode2"}); + Map updateModelCacheStatus2 = Map.of("modelId2", "response"); nodes.add(new MLUpdateModelCacheNodeResponse(node2, updateModelCacheStatus2)); List failures = new ArrayList<>(); @@ -104,8 +96,6 @@ public void testToXContent() throws IOException { @Test public void testNullUpdateModelCacheStatusToXContent() throws IOException { List nodes = new ArrayList<>(); - Map modelWorkerNodeCounts1 = new HashMap<>(); - modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); nodes.add(new MLUpdateModelCacheNodeResponse(node1, null)); List failures = new ArrayList<>(); MLUpdateModelCacheNodesResponse response = new MLUpdateModelCacheNodesResponse(clusterName, nodes, failures); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 82aa8becde..3760b58a4a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,7 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, + "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, false, false, false); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 5fdf55757c..39243a887a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null, null); + "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null, null); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 1472e9bbc9..178228992e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -20,6 +20,8 @@ import java.util.Map; import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.util.TokenBucket; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; @@ -50,6 +52,15 @@ public class AwsConnectorExecutor implements RemoteConnectorExecutor { @Setter @Getter private ScriptService scriptService; + @Setter + @Getter + private TokenBucket modelRateLimiter; + @Setter + @Getter + private Map userRateLimiterMap; + @Setter + @Getter + private Client client; public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { this.connector = (AwsConnector) connector; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 89d8564cf5..d881707195 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -25,6 +25,8 @@ import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.util.EntityUtils; import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.util.TokenBucket; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; @@ -49,6 +51,16 @@ public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { @Getter private ScriptService scriptService; + @Setter + @Getter + private TokenBucket modelRateLimiter; + @Setter + @Getter + private Map userRateLimiterMap; + @Setter + @Getter + private Client client; + public HttpJsonConnectorExecutor(Connector connector) { this.connector = (HttpConnector) connector; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 92c6b263d1..34575b7ce8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -12,8 +12,13 @@ import java.util.List; import java.util.Map; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.TokenBucket; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -77,12 +82,22 @@ default void setScriptService(ScriptService scriptService) {} Connector getConnector(); + TokenBucket getModelRateLimiter(); + + Map getUserRateLimiterMap(); + + Client getClient(); + default void setClient(Client client) {} default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {} default void setClusterService(ClusterService clusterService) {} + default void setModelRateLimiter(TokenBucket modelRateLimiter) {} + + default void setUserRateLimiterMap(Map userRateLimiterMap) {} + default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List tensorOutputs) { Connector connector = getConnector(); @@ -101,7 +116,24 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List parameters, String payload, List tensorOutputs); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 21baad09db..aea6821334 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -9,6 +9,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.TokenBucket; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -34,6 +35,8 @@ public class RemoteModel implements Predictable { public static final String SCRIPT_SERVICE = "script_service"; public static final String CLIENT = "client"; public static final String XCONTENT_REGISTRY = "xcontent_registry"; + public static final String MODEL_RATE_LIMITER = "model_rate_limiter_config"; + public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map"; private RemoteConnectorExecutor connectorExecutor; @@ -57,10 +60,10 @@ public MLOutput predict(MLInput mlInput) { try { return connectorExecutor.executePredict(mlInput); } catch (RuntimeException e) { - log.error("Failed to call remote model", e); + log.error("Failed to call remote model.", e); throw e; } catch (Throwable e) { - log.error("Failed to call remote model", e); + log.error("Failed to call remote model.", e); throw new MLException(e); } } @@ -85,11 +88,13 @@ public void initModel(MLModel model, Map params, Encryptor encry this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); this.connectorExecutor.setClient((Client) params.get(CLIENT)); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); + this.connectorExecutor.setModelRateLimiter((TokenBucket) params.get(MODEL_RATE_LIMITER)); + this.connectorExecutor.setUserRateLimiterMap((Map) params.get(USER_RATE_LIMITER_MAP)); } catch (RuntimeException e) { - log.error("Failed to init remote model", e); + log.error("Failed to init remote model.", e); throw e; } catch (Throwable e) { - log.error("Failed to init remote model", e); + log.error("Failed to init remote model.", e); throw new MLException(e); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java index 671f4e548a..26fabba0c2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java @@ -20,6 +20,9 @@ import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -36,6 +39,7 @@ public enum MLIndex { TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION), + MODEL_CONTROLLER(ML_MODEL_CONTROLLER_INDEX, false, ML_MODEL_CONTROLLER_INDEX_MAPPING, ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION), AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION), MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING, ML_MEMORY_META_INDEX_SCHEMA_VERSION), MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING, ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index ca5f88be78..d385ff8bbf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -74,6 +74,10 @@ public void initMLConfigIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONFIG, listener); } + public void initMLModelControllerIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.MODEL_CONTROLLER, listener); + } + public void initMLAgentIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.AGENT, listener); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index b35f9b0eac..0119169e2a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -29,6 +29,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; @@ -42,6 +45,7 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -57,6 +61,16 @@ public class AwsConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + Settings settings; + + ThreadContext threadContext; + @Mock ScriptService scriptService; @@ -129,6 +143,11 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -169,6 +188,11 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -207,6 +231,11 @@ public void executePredict_RemoteInferenceInput() throws IOException { .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor @@ -252,6 +281,11 @@ public void executePredict_TextDocsInferenceInput() throws IOException { .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); ModelTensorOutput modelTensorOutput = executor diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index a4bc766aa2..ad42c1a3ca 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -28,6 +28,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -41,6 +44,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -48,15 +52,25 @@ public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); + @Mock + ThreadPool threadPool; + @Mock ScriptService scriptService; @Mock CloseableHttpClient httpClient; + @Mock + Client client; + @Mock CloseableHttpResponse response; + Settings settings; + + ThreadContext threadContext; + @Before public void setUp() { MockitoAnnotations.openMocks(this); @@ -101,6 +115,11 @@ public void executePredict_RemoteInferenceInput() throws IOException { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); @@ -142,6 +161,11 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor @@ -180,6 +204,11 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -210,6 +239,11 @@ public void executePredict_TextDocsInput() throws IOException { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); String modelResponse = "{\n" @@ -291,6 +325,11 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); // model takes 2 input docs, but only output 1 embedding @@ -361,6 +400,11 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepS .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); // model takes 2 input docs, but only output 1 embedding diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java new file mode 100644 index 0000000000..ece13bcd3b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java @@ -0,0 +1,274 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerResponse; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class CreateModelControllerTransportAction extends HandledTransportAction { + MLIndicesHandler mlIndicesHandler; + Client client; + MLModelManager mlModelManager; + ClusterService clusterService; + MLModelCacheHelper mlModelCacheHelper; + ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public CreateModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLIndicesHandler mlIndicesHandler, + Client client, + ClusterService clusterService, + ModelAccessControlHelper modelAccessControlHelper, + MLModelCacheHelper mlModelCacheHelper, + MLModelManager mlModelManager + ) { + super(MLCreateModelControllerAction.NAME, transportService, actionFilters, MLCreateModelControllerRequest::new); + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.mlModelCacheHelper = mlModelCacheHelper; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLCreateModelControllerRequest createModelControllerRequest = MLCreateModelControllerRequest.fromActionRequest(request); + MLModelController modelController = createModelControllerRequest.getModelControllerInput(); + String modelId = modelController.getModelId(); + User user = RestActionUtils.getUserContext(client); + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + FunctionName functionName = mlModel.getAlgorithm(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + if (mlModel.getModelState() != MLModelState.DEPLOYING) { + indexAndCreateModelController(mlModel, modelController, wrappedListener); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Creating a model controller during its corresponding model in DEPLOYING state is not allowed, " + + "please either create the model controller after it is deployed or before deploying it. Model ID: " + + modelId, + RestStatus.CONFLICT + ) + ); + log + .error( + "Failed to create a model controller during its corresponding model in DEPLOYING state. Model ID: " + + modelId + ); + } + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model controller, model ID: " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to create the model controller for the model with ID {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + })); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Creating model controller on this operation on the function category " + + functionName.toString() + + " is not supported.", + RestStatus.FORBIDDEN + ) + ); + } + }, + e -> wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model to create the corresponding model controller with the provided model ID: " + modelId, + RestStatus.NOT_FOUND + ) + ) + )); + } catch (Exception e) { + log.error("Failed to create model controller for " + modelId, e); + actionListener.onFailure(e); + } + } + + private void indexAndCreateModelController( + MLModel mlModel, + MLModelController modelController, + ActionListener actionListener + ) { + log.info("Indexing the model controller into system index"); + mlIndicesHandler.initMLModelControllerIndex(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + actionListener.onFailure(new RuntimeException("Failed to create model controller index.")); + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener indexResponseListener = ActionListener.wrap(indexResponse -> { + String modelId = indexResponse.getId(); + MLCreateModelControllerResponse response = new MLCreateModelControllerResponse( + modelId, + indexResponse.getResult().name() + ); + log.info("Model controller for model id {} saved into index, result:{}", modelId, indexResponse.getResult()); + if (indexResponse.getResult() == DocWriteResponse.Result.CREATED) { + mlModelManager.updateModel(modelId, Map.of(MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD, true)); + } + if (mlModelCacheHelper.isModelDeployed(modelId)) { + log.info("Model {} is deployed. Start to deploy the model controller into cache.", modelId); + String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); + MLDeployModelControllerNodesRequest deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest( + targetNodeIds, + modelController.getModelId() + ); + client + .execute( + MLDeployModelControllerAction.INSTANCE, + deployModelControllerNodesRequest, + ActionListener.wrap(nodesResponse -> { + if (nodesResponse != null && isDeployModelControllerSuccessOnAllNodes(nodesResponse)) { + log.info("Successfully create model controller and deploy it into cache with model ID {}", modelId); + actionListener.onResponse(response); + } else { + String[] nodeIds = getDeployModelControllerFailedNodesList(nodesResponse); + log + .error( + "Successfully create model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", + modelId, + Arrays.toString(nodeIds) + ); + actionListener + .onFailure( + new RuntimeException( + "Successfully create model controller index with model ID " + + modelId + + " but deploy model controller to cache was failed on following nodes " + + Arrays.toString(nodeIds) + + ", please retry." + ) + ); + } + }, e -> { + log.error("Failed to deploy model controller for model: {}" + modelId, e); + actionListener.onFailure(e); + }) + ); + } else { + actionListener.onResponse(response); + } + }, actionListener::onFailure); + + IndexRequest indexRequest = new IndexRequest(ML_MODEL_CONTROLLER_INDEX).id(modelController.getModelId()); + indexRequest + .source(modelController.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore)); + } catch (Exception e) { + log.error("Failed to save model controller", e); + actionListener.onFailure(e); + } + }, e -> { + log.error("Failed to init model controller index", e); + actionListener.onFailure(e); + })); + } + + private boolean isDeployModelControllerSuccessOnAllNodes(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { + return deployModelControllerNodesResponse.failures() == null || deployModelControllerNodesResponse.failures().isEmpty(); + } + + private String[] getDeployModelControllerFailedNodesList(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { + if (deployModelControllerNodesResponse == null) { + return getAllNodes(); + } else { + List nodeIds = new ArrayList<>(); + for (FailedNodeException failedNodeException : deployModelControllerNodesResponse.failures()) { + nodeIds.add(failedNodeException.nodeId()); + } + return nodeIds.toArray(new String[0]); + } + } + + private String[] getAllNodes() { + Iterator iterator = clusterService.state().nodes().iterator(); + List nodeIds = new ArrayList<>(); + while (iterator.hasNext()) { + nodeIds.add(iterator.next().getId()); + } + return nodeIds.toArray(new String[0]); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java new file mode 100644 index 0000000000..97c002e3f2 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java @@ -0,0 +1,250 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(level = AccessLevel.PRIVATE) +public class DeleteModelControllerTransportAction extends HandledTransportAction { + Client client; + NamedXContentRegistry xContentRegistry; + ClusterService clusterService; + MLModelManager mlModelManager; + MLModelCacheHelper mlModelCacheHelper; + ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public DeleteModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ClusterService clusterService, + MLModelManager mlModelManager, + MLModelCacheHelper mlModelCacheHelper, + ModelAccessControlHelper modelAccessControlHelper + ) { + super(MLModelControllerDeleteAction.NAME, transportService, actionFilters, MLModelControllerDeleteRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.mlModelManager = mlModelManager; + this.mlModelCacheHelper = mlModelCacheHelper; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLModelControllerDeleteRequest modelControllerDeleteRequest = MLModelControllerDeleteRequest.fromActionRequest(request); + String modelId = modelControllerDeleteRequest.getModelId(); + User user = RestActionUtils.getUserContext(client); + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + mlModelManager + .getModelController( + modelId, + ActionListener + .wrap( + modelController -> deleteModelControllerWithDeployedModel(modelId, wrappedListener), + deleteException -> { + log.error(deleteException); + wrappedListener.onFailure(deleteException); + } + ) + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model controller, model ID: " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to delete the model controller with the provided model id {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + })); + }, e -> { + log + .warn( + "Failed to find corresponding model during deleting the model controller. Now trying to delete the model controller alone. Model ID: " + + modelId + ); + mlModelManager + .getModelController( + modelId, + ActionListener + .wrap(modelController -> deleteModelControllerWithDeployedModel(modelId, wrappedListener), deleteException -> { + log.error(deleteException); + wrappedListener.onFailure(deleteException); + }) + ); + })); + } catch (Exception e) { + log.error("Failed to delete model controller for model" + modelId, e); + actionListener.onFailure(e); + } + } + + // This method is used to handle the condition if we need to undeploy the model controller before deleting it from the index or not. + private void deleteModelControllerWithDeployedModel(String modelId, ActionListener actionListener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (mlModelCacheHelper.isModelDeployed(modelId)) { + log.info("Model has already been deployed in ML cache, need undeploy model controller before sending delete request."); + String[] targetNodeIds = getAllNodes(); + MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest = new MLUndeployModelControllerNodesRequest( + targetNodeIds, + modelId + ); + client + .execute( + MLUndeployModelControllerAction.INSTANCE, + undeployModelControllerNodesRequest, + ActionListener.runBefore(ActionListener.wrap(nodesResponse -> { + if (nodesResponse != null && isUndeployModelControllerSuccessOnAllNodes(nodesResponse)) { + log + .info( + "Successfully undeploy model controller from cache. Start to delete the model controller for model {}", + modelId + ); + deleteModelController(modelId, actionListener); + } else { + String[] nodeIds = getUndeployModelControllerFailedNodesList(nodesResponse); + log + .error( + "Failed to undeploy model controller with model ID {} on following nodes {}, deletion is aborted. Please retry or undeploy the model manually and then perform the deletion.", + modelId, + Arrays.toString(nodeIds) + ); + actionListener + .onFailure( + new RuntimeException( + "Failed to undeploy model controller with model ID " + + modelId + + " on following nodes " + + Arrays.toString(nodeIds) + + ", deletion is aborted. Please retry or undeploy the model manually and then perform the deletion." + ) + ); + } + }, e -> { + log + .error( + "Failed to undeploy model controller from cache and delete the model controller for model {}", + modelId, + e + ); + actionListener.onFailure(e); + }), context::restore) + ); + } else { + deleteModelController(modelId, actionListener); + } + } catch (Exception e) { + log.error("Failed to delete model controller", e); + actionListener.onFailure(e); + } + } + + private void deleteModelController(String modelId, ActionListener actionListener) { + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_CONTROLLER_INDEX, modelId); + client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); + mlModelManager.updateModel(modelId, Map.of(MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD, false)); + actionListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete model controller for model: " + modelId, e); + actionListener.onFailure(e); + } + }); + } + + private boolean isUndeployModelControllerSuccessOnAllNodes( + MLUndeployModelControllerNodesResponse undeployModelControllerNodesResponse + ) { + return undeployModelControllerNodesResponse.failures() == null || undeployModelControllerNodesResponse.failures().isEmpty(); + } + + private String[] getUndeployModelControllerFailedNodesList( + MLUndeployModelControllerNodesResponse undeployModelControllerNodesResponse + ) { + if (undeployModelControllerNodesResponse == null) { + return getAllNodes(); + } else { + List nodeIds = new ArrayList<>(); + for (FailedNodeException failedNodeException : undeployModelControllerNodesResponse.failures()) { + nodeIds.add(failedNodeException.nodeId()); + } + return nodeIds.toArray(new String[0]); + } + } + + private String[] getAllNodes() { + Iterator iterator = clusterService.state().nodes().iterator(); + List nodeIds = new ArrayList<>(); + while (iterator.hasNext()) { + nodeIds.add(iterator.next().getId()); + } + return nodeIds.toArray(new String[0]); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java new file mode 100644 index 0000000000..9744cdc917 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeployModelControllerTransportAction extends + TransportNodesAction { + + private final MLModelManager mlModelManager; + private final ClusterService clusterService; + private final Client client; + private DiscoveryNodeHelper nodeFilter; + private final MLStats mlStats; + private NamedXContentRegistry xContentRegistry; + + private ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public DeployModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLModelManager mlModelManager, + ClusterService clusterService, + ThreadPool threadPool, + Client client, + DiscoveryNodeHelper nodeFilter, + MLStats mlStats, + NamedXContentRegistry xContentRegistry, + ModelAccessControlHelper modelAccessControlHelper + ) { + super( + MLDeployModelControllerAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + MLDeployModelControllerNodesRequest::new, + MLDeployModelControllerNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + MLDeployModelControllerNodeResponse.class + ); + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.client = client; + this.nodeFilter = nodeFilter; + this.mlStats = mlStats; + this.xContentRegistry = xContentRegistry; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected MLDeployModelControllerNodesResponse newResponse( + MLDeployModelControllerNodesRequest request, + List responses, + List failures + ) { + return new MLDeployModelControllerNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected MLDeployModelControllerNodeRequest newNodeRequest(MLDeployModelControllerNodesRequest request) { + return new MLDeployModelControllerNodeRequest(request); + } + + @Override + protected MLDeployModelControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLDeployModelControllerNodeResponse(in); + } + + @Override + protected MLDeployModelControllerNodeResponse nodeOperation(MLDeployModelControllerNodeRequest request) { + return createDeployModelControllerNodeResponse(request.getDeployModelControllerNodesRequest()); + } + + private MLDeployModelControllerNodeResponse createDeployModelControllerNodeResponse( + MLDeployModelControllerNodesRequest deployModelControllerNodesRequest + ) { + String modelId = deployModelControllerNodesRequest.getModelId(); + + Map modelControllerDeployStatus = new HashMap<>(); + modelControllerDeployStatus.put(modelId, "received"); + + String localNodeId = clusterService.localNode().getId(); + + mlModelManager.deployModelControllerWithDeployedModel(modelId, ActionListener.wrap(r -> { + log.info("Successfully deployed model controller for model {} on node {}", modelId, localNodeId); + }, e -> { log.error("Failed to deploy model controller for model {} on node {}", modelId, localNodeId, e); })); + return new MLDeployModelControllerNodeResponse(clusterService.localNode(), modelControllerDeployStatus); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java new file mode 100644 index 0000000000..ddbf69b0f6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetModelControllerTransportAction extends HandledTransportAction { + Client client; + NamedXContentRegistry xContentRegistry; + ClusterService clusterService; + MLModelManager mlModelManager; + ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public GetModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ClusterService clusterService, + MLModelManager mlModelManager, + ModelAccessControlHelper modelAccessControlHelper + ) { + super(MLModelControllerGetAction.NAME, transportService, actionFilters, MLModelControllerGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.mlModelManager = mlModelManager; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLModelControllerGetRequest modelControllerGetRequest = MLModelControllerGetRequest.fromActionRequest(request); + String modelId = modelControllerGetRequest.getModelId(); + FetchSourceContext fetchSourceContext = getFetchSourceContext(modelControllerGetRequest.isReturnContent()); + GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); + User user = RestActionUtils.getUserContext(client); + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelController modelController = MLModelController.parse(parser); + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + wrappedListener + .onResponse(MLModelControllerGetResponse.builder().modelController(modelController).build()); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model controller, model ID: " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to create the model controller for the model with ID {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + })); + }, + e -> wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model to get the corresponding model controller with the provided model ID: " + + modelId, + RestStatus.NOT_FOUND + ) + ) + )); + + } catch (Exception e) { + log.error("Failed to parse model controller with model ID: " + r.getId(), e); + wrappedListener.onFailure(e); + } + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model controller with the provided model ID: " + modelId, + RestStatus.NOT_FOUND + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + log.error("Failed to get model controller index", e); + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model controller", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get model controller " + modelId, e); + wrappedListener.onFailure(e); + } + })); + } catch (Exception e) { + log.error("Failed to get model controller " + modelId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java new file mode 100644 index 0000000000..eb8bc04d8c --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class UndeployModelControllerTransportAction extends + TransportNodesAction { + + private final MLModelManager mlModelManager; + private final ClusterService clusterService; + private final Client client; + private DiscoveryNodeHelper nodeFilter; + private final MLStats mlStats; + private NamedXContentRegistry xContentRegistry; + + private ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public UndeployModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLModelManager mlModelManager, + ClusterService clusterService, + ThreadPool threadPool, + Client client, + DiscoveryNodeHelper nodeFilter, + MLStats mlStats, + NamedXContentRegistry xContentRegistry, + ModelAccessControlHelper modelAccessControlHelper + ) { + super( + MLUndeployModelControllerAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + MLUndeployModelControllerNodesRequest::new, + MLUndeployModelControllerNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + MLUndeployModelControllerNodeResponse.class + ); + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.client = client; + this.nodeFilter = nodeFilter; + this.mlStats = mlStats; + this.xContentRegistry = xContentRegistry; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected MLUndeployModelControllerNodesResponse newResponse( + MLUndeployModelControllerNodesRequest request, + List responses, + List failures + ) { + return new MLUndeployModelControllerNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected MLUndeployModelControllerNodeRequest newNodeRequest(MLUndeployModelControllerNodesRequest request) { + return new MLUndeployModelControllerNodeRequest(request); + } + + @Override + protected MLUndeployModelControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLUndeployModelControllerNodeResponse(in); + } + + @Override + protected MLUndeployModelControllerNodeResponse nodeOperation(MLUndeployModelControllerNodeRequest request) { + return createUndeployModelControllerNodeResponse(request.getUndeployModelControllerNodesRequest()); + } + + private MLUndeployModelControllerNodeResponse createUndeployModelControllerNodeResponse( + MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest + ) { + String modelId = undeployModelControllerNodesRequest.getModelId(); + + Map modelControllerUndeployStatus = new HashMap<>(); + modelControllerUndeployStatus.put(modelId, "received"); + + String localNodeId = clusterService.localNode().getId(); + + mlModelManager.undeployModelController(modelId, ActionListener.wrap(r -> { + log.info("Successfully undeployed model controller for model {} on node {}", modelId, localNodeId); + }, e -> { log.error("Failed to undeploy model controller for model {} on node {}", modelId, localNodeId, e); })); + return new MLUndeployModelControllerNodeResponse(clusterService.localNode(), modelControllerUndeployStatus); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java new file mode 100644 index 0000000000..f5ce76066f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java @@ -0,0 +1,274 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class UpdateModelControllerTransportAction extends HandledTransportAction { + Client client; + MLModelManager mlModelManager; + MLModelCacheHelper mlModelCacheHelper; + ClusterService clusterService; + ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public UpdateModelControllerTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + ModelAccessControlHelper modelAccessControlHelper, + MLModelCacheHelper mlModelCacheHelper, + MLModelManager mlModelManager + ) { + super(MLUpdateModelControllerAction.NAME, transportService, actionFilters, MLUpdateModelControllerRequest::new); + this.client = client; + this.mlModelManager = mlModelManager; + this.clusterService = clusterService; + this.mlModelCacheHelper = mlModelCacheHelper; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLUpdateModelControllerRequest updateModelControllerRequest = MLUpdateModelControllerRequest.fromActionRequest(request); + MLModelController updateModelControllerInput = updateModelControllerRequest.getUpdateModelControllerInput(); + String modelId = updateModelControllerInput.getModelId(); + User user = RestActionUtils.getUserContext(client); + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + FunctionName functionName = mlModel.getAlgorithm(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + mlModelManager.getModelController(modelId, ActionListener.wrap(modelController -> { + boolean isDeployRequiredAfterUpdate = modelController + .isDeployRequiredAfterUpdate(updateModelControllerInput); + modelController.update(updateModelControllerInput); + updateModelController(mlModel, modelController, isDeployRequiredAfterUpdate, wrappedListener); + }, e -> { + if (mlModel.getIsModelControllerEnabled() == null || !mlModel.getIsModelControllerEnabled()) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Model controller haven't been created for the model. Consider calling create model controller api instead. Model ID: " + + modelId, + RestStatus.CONFLICT + ) + ); + log.error("Model controller haven't been created for the model: " + modelId, e); + } else { + log.error(e); + wrappedListener.onFailure(e); + } + })); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model controller, model ID: " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to create the model controller for the model with ID {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + })); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Creating model controller on this operation on the function category " + + functionName.toString() + + " is not supported.", + RestStatus.FORBIDDEN + ) + ); + } + }, + e -> wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model to create the corresponding model controller with the provided model ID: " + modelId, + RestStatus.NOT_FOUND + ) + ) + )); + } catch (Exception e) { + log.error("Failed to create model controller for " + modelId, e); + actionListener.onFailure(e); + } + } + + private void updateModelController( + MLModel mlModel, + MLModelController modelController, + boolean isDeployRequiredAfterUpdate, + ActionListener actionListener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String modelId = mlModel.getModelId(); + ActionListener updateResponseListener = ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + log + .info( + "Model controller for model {} successfully updated to index, result: {}", + modelId, + updateResponse.getResult() + ); + if (mlModelCacheHelper.isModelDeployed(modelId) && isDeployRequiredAfterUpdate) { + log + .info( + "Model {} is deployed and the user rate limiter config is constructable. Start to deploy the model controller into cache.", + modelId + ); + String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); + MLDeployModelControllerNodesRequest deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest( + targetNodeIds, + modelId + ); + client + .execute( + MLDeployModelControllerAction.INSTANCE, + deployModelControllerNodesRequest, + ActionListener.wrap(nodesResponse -> { + if (nodesResponse != null && isDeployModelControllerSuccessOnAllNodes(nodesResponse)) { + log.info("Successfully update model controller and deploy it into cache with model ID {}", modelId); + actionListener.onResponse(updateResponse); + } else { + String[] nodeIds = getDeployModelControllerFailedNodesList(nodesResponse); + log + .error( + "Successfully update model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", + modelId, + Arrays.toString(nodeIds) + ); + actionListener + .onFailure( + new RuntimeException( + "Successfully update model controller index with model ID " + + modelId + + " but deploy model controller to cache was failed on following nodes " + + Arrays.toString(nodeIds) + + ", please retry." + ) + ); + } + }, e -> { + log.error("Failed to deploy model controller for model: {}" + modelId, e); + actionListener.onFailure(e); + }) + ); + } else { + actionListener.onResponse(updateResponse); + } + } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + // The update response returned an unexpected status may indicate a failed update + log + .warn( + "Update model controller for model {} got a result status other than update, result status: {}", + modelId, + updateResponse.getResult() + ); + actionListener.onResponse(updateResponse); + } else { + log.error("Failed to update model controller with model ID: " + modelId); + actionListener.onFailure(new RuntimeException("Failed to update model controller with model ID: " + modelId)); + } + }, actionListener::onFailure); + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_CONTROLLER_INDEX, modelId); + updateRequest.doc(modelController.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, ActionListener.runBefore(updateResponseListener, context::restore)); + } catch (Exception e) { + log.error("Failed to update model controller.", e); + actionListener.onFailure(e); + } + } + + private boolean isDeployModelControllerSuccessOnAllNodes(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { + return deployModelControllerNodesResponse.failures() == null || deployModelControllerNodesResponse.failures().isEmpty(); + } + + private String[] getDeployModelControllerFailedNodesList(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { + if (deployModelControllerNodesResponse == null) { + return getAllNodes(); + } else { + List nodeIds = new ArrayList<>(); + for (FailedNodeException failedNodeException : deployModelControllerNodesResponse.failures()) { + nodeIds.add(failedNodeException.nodeId()); + } + return nodeIds.toArray(new String[0]); + } + } + + private String[] getAllNodes() { + Iterator iterator = clusterService.state().nodes().iterator(); + List nodeIds = new ArrayList<>(); + while (iterator.hasNext()) { + nodeIds.add(iterator.next().getId()); + } + return nodeIds.toArray(new String[0]); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 984dbdd451..4c06a0391c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.models; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; @@ -16,6 +17,7 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; @@ -215,6 +217,7 @@ private void deleteModel(String modelId, ActionListener actionLi @Override public void onResponse(DeleteResponse deleteResponse) { deleteModelChunks(modelId, deleteResponse, actionListener); + deleteModelController(modelId); } @Override @@ -222,12 +225,50 @@ public void onFailure(Exception e) { log.error("Failed to delete model meta data for model: " + modelId, e); if (e instanceof ResourceNotFoundException) { deleteModelChunks(modelId, null, actionListener); + deleteModelController(modelId); } actionListener.onFailure(e); } }); } + /** + * Delete the model controller for a model after the model is deleted from the ML index. + * + * @param modelId model ID + */ + private void deleteModelController(String modelId, ActionListener actionListener) { + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_CONTROLLER_INDEX, modelId); + client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); + actionListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete model controller for model: " + modelId, e); + actionListener.onFailure(e); + } + }); + } + + /** + * Delete the model controller for a model after the model is deleted from the ML index with build-in listener. + * + * @param modelId model ID + */ + private void deleteModelController(String modelId) { + deleteModelController(modelId, ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getResult() == DocWriteResponse.Result.DELETED) { + log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); + } else { + log.warn("The deletion of model controller for model {} returned with result: {}", modelId, deleteResponse.getResult()); + } + }, e -> log.error("Failed to re-deploy the model controller for model: " + modelId, e))); + } + private Boolean isModelNotDeployed(MLModelState mlModelState) { return !mlModelState.equals(MLModelState.LOADED) && !mlModelState.equals(MLModelState.LOADING) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index aa71c59570..ccfb206aa9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -45,12 +46,12 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; import org.opensearch.ml.engine.MLEngine; @@ -124,20 +125,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { if (!isModelDeploying(mlModel.getModelState())) { - boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); FunctionName functionName = mlModel.getAlgorithm(); // TODO: Support update as well as model/user level throttling in all other DLModel categories if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { if (isSuperAdmin) { - updateRemoteOrTextEmbeddingModel( - modelId, - updateModelInput, - mlModel, - user, - wrappedListener, - isModelDeployed - ); + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); } else { wrappedListener .onFailure( @@ -151,14 +144,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { - updateRemoteOrTextEmbeddingModel( - modelId, - updateModelInput, - mlModel, - user, - wrappedListener, - isModelDeployed - ); + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); } else { wrappedListener .onFailure( @@ -188,7 +174,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener, - boolean isModelDeployed + ActionListener wrappedListener ) { String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; - + boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); + // This flag is used to decide if we need to re-deploy the predictor(model) when updating the model cache. + // If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter needs update, we + // need to perform a re-deploy. + boolean isPredictorUpdate = (updateModelInput.getConnector() != null) + || (newConnectorId != null) + || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()); + if (MLRateLimiter.updateValidityPreCheck(mlModel.getModelRateLimiterConfig(), updateModelInput.getModelRateLimiterConfig())) { + MLRateLimiter updatedRateLimiterConfig = MLRateLimiter + .update(mlModel.getModelRateLimiterConfig(), updateModelInput.getModelRateLimiterConfig()); + updateModelInput.setModelRateLimiterConfig(updatedRateLimiterConfig); + // An un-constructable updatedRateLimiterConfig does not require predictor to be re-deployed. + isPredictorUpdate = isPredictorUpdate || (updatedRateLimiterConfig.isValid()); + } + // This flag is used to decide if we need to update the model cache + boolean isUpdateModelCache = isPredictorUpdate && isModelDeployed; if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { - if (newConnectorId == null && updateModelInput.getConnectorUpdateContent() == null) { + if (newConnectorId == null && updateModelInput.getConnector() == null) { updateModelWithRegisteringToAnotherModelGroup( modelId, newModelGroupId, user, updateModelInput, wrappedListener, - isModelDeployed + isUpdateModelCache ); } else { wrappedListener @@ -242,12 +242,12 @@ private void updateRemoteOrTextEmbeddingModel( } else { // mlModel.getAlgorithm() == REMOTE if (newConnectorId == null) { - if (updateModelInput.getConnectorUpdateContent() != null) { + if (updateModelInput.getConnector() != null) { Connector connector = mlModel.getConnector(); - connector.update(updateModelInput.getConnectorUpdateContent(), mlEngine::encrypt); + connector.update(updateModelInput.getConnector(), mlEngine::encrypt); connector.validateConnectorURL(trustedConnectorEndpointsRegex); - updateModelInput.setConnector(connector); - updateModelInput.setConnectorUpdateContent(null); + updateModelInput.setUpdatedConnector(connector); + updateModelInput.setConnector(null); } updateModelWithRegisteringToAnotherModelGroup( modelId, @@ -255,7 +255,7 @@ private void updateRemoteOrTextEmbeddingModel( user, updateModelInput, wrappedListener, - isModelDeployed + isUpdateModelCache ); } else { updateModelWithNewStandAloneConnector( @@ -266,7 +266,7 @@ private void updateRemoteOrTextEmbeddingModel( user, updateModelInput, wrappedListener, - isModelDeployed + isUpdateModelCache ); } } @@ -280,7 +280,7 @@ private void updateModelWithNewStandAloneConnector( User user, MLUpdateModelInput updateModelInput, ActionListener wrappedListener, - boolean isModelDeployed + boolean isUpdateModelCache ) { if (Strings.hasLength(mlModel.getConnectorId())) { connectorAccessControlHelper.validateConnectorAccess(client, newConnectorId, ActionListener.wrap(hasNewConnectorPermission -> { @@ -291,7 +291,7 @@ private void updateModelWithNewStandAloneConnector( user, updateModelInput, wrappedListener, - isModelDeployed + isUpdateModelCache ); } else { wrappedListener @@ -323,13 +323,9 @@ private void updateModelWithRegisteringToAnotherModelGroup( User user, MLUpdateModelInput updateModelInput, ActionListener wrappedListener, - boolean isModelDeployed + boolean isUpdateModelCache ) { UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); - // This flag is used to decide if we need to re-deploy the predictor(model) when performing the in-place update - boolean isPredictorUpdate = (updateModelInput.getConnector() != null || updateModelInput.getConnectorId() != null); - // This flag is used to decide if we need to perform an in-place update - boolean isUpdateModelCache = isModelDeployed && isPredictorUpdate; if (newModelGroupId != null) { modelAccessControlHelper .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { @@ -342,8 +338,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( updateModelInput, newModelGroupResponse, wrappedListener, - isUpdateModelCache, - isPredictorUpdate + isUpdateModelCache ); }, exception -> wrappedListener @@ -370,7 +365,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( wrappedListener.onFailure(exception); })); } else { - updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache, isPredictorUpdate); + updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); } } @@ -379,8 +374,7 @@ private void updateRequestConstructor( UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, ActionListener wrappedListener, - boolean isUpdateModelCache, - boolean isPredictorUpdate + boolean isUpdateModelCache ) { try { updateModelInput.setLastUpdateTime(Instant.now()); @@ -389,11 +383,7 @@ private void updateRequestConstructor( updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( - targetNodeIds, - modelId, - isPredictorUpdate - ); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); client .update( updateRequest, @@ -415,8 +405,7 @@ private void updateRequestConstructor( MLUpdateModelInput updateModelInput, GetResponse newModelGroupResponse, ActionListener wrappedListener, - boolean isUpdateModelCache, - boolean isPredictorUpdate + boolean isUpdateModelCache ) { Map newModelGroupSourceMap = newModelGroupResponse.getSourceAsMap(); String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); @@ -429,18 +418,13 @@ private void updateRequestConstructor( newModelGroupResponse.getPrimaryTerm(), Integer.parseInt(updatedVersion) ); - updateModelGroupRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); try { updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest( - targetNodeIds, - modelId, - isPredictorUpdate - ); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); client.update(updateModelGroupRequest, ActionListener.wrap(r -> { client .update( @@ -485,25 +469,25 @@ private ActionListener getUpdateResponseListenerWithUpdateModelC return ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { client.execute(MLUpdateModelCacheAction.INSTANCE, mlUpdateModelCacheNodesRequest, ActionListener.wrap(r -> { - if (isUpdateModelCacheSuccessOnAllNodes(modelId, r)) { + if (r != null && isUpdateModelCacheSuccessOnAllNodes(r)) { log.info("Successfully updated ML model cache with model ID {}", modelId); wrappedListener.onResponse(updateResponse); } else { - String[] nodeIds = getUpdateModelCacheFailedNodesList(modelId, r); + String[] nodeIds = getUpdateModelCacheFailedNodesList(r); log .error( - "Successfully update ML model index with model ID {} but update model cache was failed on following nodes {}, maybe retry?", + "Successfully update ML model index with model ID {} but update model cache was failed on following nodes {}, please retry or redeploy model manually.", modelId, Arrays.toString(nodeIds) ); wrappedListener .onFailure( new RuntimeException( - "Successfully update ML model index with model ID" + "Successfully update ML model index with model ID " + modelId - + "but update model cache was failed on following nodes " + + " but update model cache was failed on following nodes " + Arrays.toString(nodeIds) - + ", maybe retry?" + + ", please retry or redeploy model manually." ) ); } @@ -513,7 +497,12 @@ private ActionListener getUpdateResponseListenerWithUpdateModelC })); } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { // The update response returned an unexpected status may indicate a failed update - log.warn("Model id:{} failed update with result {}", modelId, updateResponse.getResult()); + log + .warn( + "Update model for model {} got a result status other than update, result status: {}", + modelId, + updateResponse.getResult() + ); wrappedListener.onResponse(updateResponse); } else { log.error("Failed to update ML model: " + modelId); @@ -531,7 +520,12 @@ private ActionListener getUpdateResponseListener(String modelId, log.info("Successfully update ML model with model ID {}", modelId); wrappedListener.onResponse(updateResponse); } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.warn("Model id:{} failed update with result {}", modelId, updateResponse.getResult()); + log + .warn( + "Update model for model {} got a result status other than update, result status: {}", + modelId, + updateResponse.getResult() + ); wrappedListener.onResponse(updateResponse); } else { log.error("Failed to update ML model: " + modelId); @@ -589,30 +583,17 @@ private String[] getAllNodes() { return nodeIds.toArray(new String[0]); } - private boolean isUpdateModelCacheSuccessOnAllNodes(String modelId, MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { - if (updateModelCacheNodesResponse == null) { - return false; - } else { - for (MLUpdateModelCacheNodeResponse mlUpdateModelCacheNodeResponse : updateModelCacheNodesResponse.getNodes()) { - if (mlUpdateModelCacheNodeResponse.isModelUpdateStatusEmpty() - || !Objects.equals(mlUpdateModelCacheNodeResponse.getModelUpdateStatus().get(modelId), "success")) { - return false; - } - } - return true; - } + private boolean isUpdateModelCacheSuccessOnAllNodes(MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { + return updateModelCacheNodesResponse.failures() == null || updateModelCacheNodesResponse.failures().isEmpty(); } - private String[] getUpdateModelCacheFailedNodesList(String modelId, MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { + private String[] getUpdateModelCacheFailedNodesList(MLUpdateModelCacheNodesResponse updateModelCacheNodesResponse) { if (updateModelCacheNodesResponse == null) { return getAllNodes(); } else { List nodeIds = new ArrayList<>(); - for (MLUpdateModelCacheNodeResponse mlUpdateModelCacheNodeResponse : updateModelCacheNodesResponse.getNodes()) { - if (mlUpdateModelCacheNodeResponse.isModelUpdateStatusEmpty() - || !Objects.equals(mlUpdateModelCacheNodeResponse.getModelUpdateStatus().get(modelId), "success")) { - nodeIds.add(mlUpdateModelCacheNodeResponse.getNode().getId()); - } + for (FailedNodeException failedNodeException : updateModelCacheNodesResponse.failures()) { + nodeIds.add(failedNodeException.nodeId()); } return nodeIds.toArray(new String[0]); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 63be5e2423..3c5321697c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.prediction; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -14,6 +15,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -55,8 +57,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction { log.error("Failed to Validate Access for ModelId " + modelId, e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java index fb0b6e5717..64bffeb5e2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportAction.java @@ -106,20 +106,15 @@ private MLUpdateModelCacheNodeResponse createUpdateModelCacheNodeResponse( MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest ) { String modelId = mlUpdateModelCacheNodesRequest.getModelId(); - boolean isPredictorUpdate = mlUpdateModelCacheNodesRequest.isPredictorUpdate(); Map modelUpdateStatus = new HashMap<>(); modelUpdateStatus.put(modelId, "received"); String localNodeId = clusterService.localNode().getId(); - mlModelManager.updateModelCache(modelId, isPredictorUpdate, ActionListener.wrap(r -> { - modelUpdateStatus.replace(modelId, "success"); + mlModelManager.updateModelCache(modelId, ActionListener.wrap(r -> { log.info("Successfully performing in-place update model {} on node {}", modelId, localNodeId); - }, e -> { - modelUpdateStatus.replace(modelId, "failed"); - log.error("Failed to perform in-place update model for model {} on node {}", modelId, localNodeId); - })); + }, e -> { log.error("Failed to perform in-place update model for model {} on node {}", modelId, localNodeId); })); return new MLUpdateModelCacheNodeResponse(clusterService.localNode(), modelUpdateStatus); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 5fd7d71ce0..a54247e359 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -7,12 +7,14 @@ import java.util.DoubleSummaryStatistics; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.stream.DoubleStream; +import org.opensearch.common.util.TokenBucket; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -33,6 +35,9 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) TokenBucket modelRateLimiter; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Map userRateLimiterMap; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean isModelEnabled; private final Set targetWorkerNodes; private final Set workerNodes; private MLModel modelInfo; @@ -157,6 +162,9 @@ public void clear() { if (executor != null) { executor.close(); } + isModelEnabled = null; + modelRateLimiter = null; + userRateLimiterMap = null; } public void addModelInferenceDuration(double duration, long maxRequestCount) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 553ffeb664..570db1bc42 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -15,8 +16,11 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import org.opensearch.OpenSearchStatusException; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.TokenBucket; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; @@ -75,6 +79,165 @@ public synchronized void setModelState(String modelId, MLModelState state) { getExistingModelCache(modelId).setModelState(state); } + /** + * Set a rate limiter to enable model level throttling + * @param modelId model id + * @param rateLimiter rate limiter + */ + public synchronized void setModelRateLimiter(String modelId, TokenBucket rateLimiter) { + log.debug("Setting the rate limiter for Model {}", modelId); + getExistingModelCache(modelId).setModelRateLimiter(rateLimiter); + } + + /** + * Get the current rate limiter for the model. + * + * @param modelId model id + */ + public TokenBucket getModelRateLimiter(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getModelRateLimiter(); + } + + /** + * Remove the rate limiter from cache to disable model level throttling + * @param modelId model id + */ + public synchronized void removeModelRateLimiter(String modelId) { + log.debug("Removing the rate limiter for Model {}", modelId); + getExistingModelCache(modelId).setModelRateLimiter(null); + } + + /** + * Set the user rate limiter map for a single user to enable user level throttling. + * + * @param modelId model id + * @param user user + * @param rateLimiter rate limiter + */ + public synchronized void setUserRateLimiterMap(String modelId, String user, TokenBucket rateLimiter) { + log.debug("Setting the user level rate limiter for Model {}", modelId); + Map userRateLimiterMap = new HashMap<>() { + { + put(user, rateLimiter); + } + }; + getExistingModelCache(modelId).setUserRateLimiterMap(userRateLimiterMap); + } + + /** + * Set the user rate limiter map to enable user level throttling. + * + * @param modelId model id + * @param userRateLimiterMap a map with user's name and its corresponding rate limiter + */ + public synchronized void setUserRateLimiterMap(String modelId, Map userRateLimiterMap) { + log.debug("Setting the user level rate limiter for Model {}", modelId); + getExistingModelCache(modelId).setUserRateLimiterMap(userRateLimiterMap); + } + + /** + * Update the user rate limiter map with the user rate limiter map. + * If the user rate limiter map doesn't exist for the model, consider calling setUserRateLimiterMap instead. + * + * @param modelId model id + * @param updateUserRateLimiterMap a map with user's name and its corresponding rate limiter + */ + public synchronized void updateUserRateLimiterMap(String modelId, Map updateUserRateLimiterMap) { + log.debug("Updating the user level rate limiter for Model {}", modelId); + Map userRateLimiterMap = getExistingModelCache(modelId).getUserRateLimiterMap(); + if (userRateLimiterMap != null) { + userRateLimiterMap.putAll(updateUserRateLimiterMap); + } else { + throw new OpenSearchStatusException( + "Model controller doesn't exist for the model. Consider calling create model controller api instead. Model ID: " + modelId, + RestStatus.CONFLICT + ); + } + } + + /** + * Update the user rate limiter map for a single user. + * If the user rate limiter map doesn't exist for the model, consider calling setUserRateLimiterMap instead. + * + * @param modelId model id + * @param user user + * @param rateLimiter rate limiter + */ + public synchronized void updateUserRateLimiterMap(String modelId, String user, TokenBucket rateLimiter) { + log.debug("Updating the user level rate limiter for Model {}", modelId); + Map userRateLimiterMap = getExistingModelCache(modelId).getUserRateLimiterMap(); + if (userRateLimiterMap != null) { + userRateLimiterMap.put(user, rateLimiter); + } else { + throw new OpenSearchStatusException( + "Model controller doesn't exist for the model. Consider calling create model controller api instead. Model ID: " + modelId, + RestStatus.CONFLICT + ); + } + } + + /** + * Remove the user rate limiter map from cache to disable user level throttling. + * + * @param modelId model id + */ + public synchronized void removeUserRateLimiterMap(String modelId) { + log.debug("Removing the user level rate limiter for Model {}", modelId); + getExistingModelCache(modelId).setUserRateLimiterMap(null); + } + + /** + * Get the current user and its corresponding rate limiter map for the model + * + * @param modelId model id + */ + public Map getUserRateLimiterMap(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getUserRateLimiterMap(); + } + + /** + * Get the rate limiter for a specific user for the model + * + * @param modelId model id + */ + public TokenBucket getUserRateLimiter(String modelId, String user) { + Map userRateLimiterMap = getUserRateLimiterMap(modelId); + if (userRateLimiterMap == null) { + return null; + } + return userRateLimiterMap.get(user); + } + + /** + * Set a quota flag to control if the model can still receive request + * @param modelId model id + * @param isModelEnabled quota flag + */ + public synchronized void setIsModelEnabled(String modelId, Boolean isModelEnabled) { + log.debug("Setting the quota flag for Model {}", modelId); + getExistingModelCache(modelId).setIsModelEnabled(isModelEnabled); + } + + /** + * Get the current quota flag condition for the model + * @param modelId model id + */ + public Boolean getIsModelEnabled(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getIsModelEnabled(); + } + /** * Set memory size estimation CPU/GPU * @param modelId model id diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 04fa8cb75d..3214bf4417 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -8,6 +8,7 @@ import static org.opensearch.common.xcontent.XContentType.JSON; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -25,7 +26,9 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.MODEL_RATE_LIMITER; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.USER_RATE_LIMITER_MAP; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY; import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_HELPER; @@ -49,6 +52,7 @@ import java.time.Instant; import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -61,6 +65,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; +import org.apache.commons.lang3.BooleanUtils; import org.apache.logging.log4j.util.Strings; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; @@ -75,6 +80,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.TokenBucket; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -94,6 +100,8 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -124,7 +132,6 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Files; @@ -279,6 +286,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .version(version) .modelGroupId(mlRegisterModelMetaInput.getModelGroupId()) .description(mlRegisterModelMetaInput.getDescription()) + .modelRateLimiterConfig(mlRegisterModelMetaInput.getModelRateLimiterConfig()) .modelFormat(mlRegisterModelMetaInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(mlRegisterModelMetaInput.getModelConfig()) @@ -508,6 +516,7 @@ private void indexRemoteModel( .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) + .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -530,7 +539,7 @@ private void indexRemoteModel( String modelId = modelMetaRes.getId(); mlTask.setModelId(modelId); log.info("create new model meta doc {} for upload task {}", modelId, taskId); - mlTaskManager.updateMLTask(taskId, ImmutableMap.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); + mlTaskManager.updateMLTask(taskId, Map.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); if (registerModelInput.isDeployModel()) { deployModelAfterRegistering(registerModelInput, modelId); } @@ -571,6 +580,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) + .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -591,7 +601,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St String modelId = modelMetaRes.getId(); mlTask.setModelId(modelId); log.info("create new model meta doc {} for upload task {}", modelId, taskId); - mlTaskManager.updateMLTask(taskId, ImmutableMap.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); + mlTaskManager.updateMLTask(taskId, Map.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); if (registerModelInput.isDeployModel()) { deployModelAfterRegistering(registerModelInput, modelId); } @@ -636,6 +646,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .algorithm(functionName) .version(version) .description(registerModelInput.getDescription()) + .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(registerModelInput.getModelConfig()) @@ -718,6 +729,7 @@ private void registerModel( .algorithm(functionName) .version(version) .modelFormat(registerModelInput.getModelFormat()) + .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) @@ -779,12 +791,7 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa modelHelper.downloadPrebuiltModelConfig(taskId, registerModelInput, ActionListener.wrap(mlRegisterModelInput -> { mlTask.setFunctionName(mlRegisterModelInput.getFunctionName()); mlTaskManager - .updateMLTask( - taskId, - ImmutableMap.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()), - TIMEOUT_IN_MILLIS, - false - ); + .updateMLTask(taskId, Map.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()), TIMEOUT_IN_MILLIS, false); registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion); }, e -> { log.error("Failed to register prebuilt model", e); @@ -817,7 +824,7 @@ private void updateModelRegisterStateAsDone( ) { FunctionName functionName = registerModelInput.getFunctionName(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); - Map updatedFields = ImmutableMap + Map updatedFields = Map .of( MLModel.MODEL_STATE_FIELD, MLModelState.REGISTERED, @@ -832,7 +839,7 @@ private void updateModelRegisterStateAsDone( ); log.info("Model registered successfully, model id: {}, task id: {}", modelId, taskId); updateModel(modelId, updatedFields, ActionListener.wrap(updateResponse -> { - mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, COMPLETED, MODEL_ID_FIELD, modelId), TIMEOUT_IN_MILLIS, true); + mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED, MODEL_ID_FIELD, modelId), TIMEOUT_IN_MILLIS, true); if (registerModelInput.isDeployModel()) { deployModelAfterRegistering(registerModelInput, modelId); } @@ -903,51 +910,10 @@ private void handleException(FunctionName functionName, String taskId, Exception mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } - Map updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); + Map updated = Map.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } - public synchronized void updateModelCache(String modelId, boolean isPredictorUpdate, ActionListener listener) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - getModel(modelId, ActionListener.wrap(mlModel -> { - if (isPredictorUpdate) { - assert FunctionName.REMOTE == mlModel.getAlgorithm() - : "In-place update is only supported on REMOTE models at this time."; - Map params = ImmutableMap - .of( - ML_ENGINE, - mlEngine, - SCRIPT_SERVICE, - scriptService, - CLIENT, - client, - XCONTENT_REGISTRY, - xContentRegistry, - CLUSTER_SERVICE, - clusterService - ); - if (mlModel.getConnector() != null) { - Predictable predictable = mlEngine.deploy(mlModel, params); - modelCacheHelper.setPredictor(modelId, predictable); - wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); - log.info("Completed in-place update internal connector for the model {}", modelId); - } else { - getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { - mlModel.setConnector(connector); - Predictable predictable = mlEngine.deploy(mlModel, params); - modelCacheHelper.setPredictor(modelId, predictable); - wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); - log.info("Completed in-place update stand-alone connector for the model {}", modelId); - }, wrappedListener::onFailure)); - } - wrappedListener.onResponse("successfully performed in-place update for the model " + modelId); - log.info("Completed in-place update for the model {}", modelId); - } - }, wrappedListener::onFailure)); - } - } - /** * Read model chunks from model index. Concat chunks into a whole model file, then load * into memory. @@ -984,42 +950,42 @@ public void deployModel( listener.onFailure(new IllegalArgumentException("Exceed max local model per node limit")); return; } + int eligibleNodeCount = workerNodes.size(); modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { + modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); if (FunctionName.REMOTE == mlModel.getAlgorithm() || (!FunctionName.isDLModel(mlModel.getAlgorithm()) && mlModel.getAlgorithm() != FunctionName.METRICS_CORRELATION)) { // deploy remote model or model trained by built-in algorithm like kmeans - Map params = ImmutableMap - .of( - ML_ENGINE, - mlEngine, - SCRIPT_SERVICE, - scriptService, - CLIENT, - client, - XCONTENT_REGISTRY, - xContentRegistry, - CLUSTER_SERVICE, - clusterService - ); // deploy remote model with internal connector or model trained by built-in algorithm like kmeans - if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { - setupPredictable(modelId, mlModel, params); - wrappedListener.onResponse("successful"); + if (BooleanUtils.isTrue(mlModel.getIsModelControllerEnabled())) { + getModelController(modelId, ActionListener.wrap(modelController -> { + setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + log.info("Successfully redeployed model controller for model " + modelId); + log.info("Trying to deploy remote model with model controller configured."); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); + }, e -> { + log + .error( + "Trying to deploy remote model with exceptions in re-deploying its model controller. Model ID: " + + modelId, + e + ); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); + })); return; + } else { + log.info("Trying to deploy remote or built-in model without model controller configured."); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); } - log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - getConnector(client, mlModel.getConnectorId(), ActionListener.wrap(connector -> { - mlModel.setConnector(connector); - setupPredictable(modelId, mlModel, params); - wrappedListener.onResponse("successful"); - log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); - }, wrappedListener::onFailure)); return; } + + setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); + deployModelControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {// read model chunks @@ -1031,8 +997,7 @@ public void deployModel( return; } log.debug("Model content matches original hash value, continue deploying"); - Map params = ImmutableMap - .of(MODEL_ZIP_FILE, modelZipFile, MODEL_HELPER, modelHelper, ML_ENGINE, mlEngine); + Map params = Map.of(MODEL_ZIP_FILE, modelZipFile, MODEL_HELPER, modelHelper, ML_ENGINE, mlEngine); if (FunctionName.METRICS_CORRELATION.equals(mlModel.getAlgorithm())) { MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { @@ -1045,7 +1010,6 @@ public void deployModel( mlExecutable.close(); wrappedListener.onFailure(e); } - } else { Predictable predictable = mlEngine.deploy(mlModel, params); try { @@ -1079,6 +1043,63 @@ public void deployModel( } } + private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener wrappedListener) { + String modelId = mlModel.getModelId(); + setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); + if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { + setupParamsAndPredictable(modelId, mlModel); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + wrappedListener.onResponse("successful"); + return; + } + log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); + getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupParamsAndPredictable(modelId, mlModel); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + wrappedListener.onResponse("successful"); + log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); + }, wrappedListener::onFailure)); + } + + private void setupParamsAndPredictable(String modelId, MLModel mlModel) { + Map params = setUpParameterMap(modelId); + Predictable predictable = mlEngine.deploy(mlModel, params); + modelCacheHelper.setPredictor(modelId, predictable); + } + + private Map setUpParameterMap(String modelId) { + TokenBucket modelRateLimiter = getModelRateLimiter(modelId); + Map userRateLimiterMap = getUserRateLimiterMap(modelId); + + Map params = new HashMap<>(); + params.put(ML_ENGINE, mlEngine); + params.put(SCRIPT_SERVICE, scriptService); + params.put(CLIENT, client); + params.put(XCONTENT_REGISTRY, xContentRegistry); + params.put(CLUSTER_SERVICE, clusterService); + + if (modelRateLimiter == null && userRateLimiterMap == null) { + log.info("Setting up basic ML predictor parameters."); + return Collections.unmodifiableMap(params); + } else if (modelRateLimiter != null && userRateLimiterMap == null) { + params.put(MODEL_RATE_LIMITER, modelRateLimiter); + log.info("Setting up basic ML predictor parameters with model level throttling."); + return Collections.unmodifiableMap(params); + } else if (modelRateLimiter == null) { + params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); + log.info("Setting up basic ML predictor parameters with user level throttling."); + return Collections.unmodifiableMap(params); + } else { + params.put(MODEL_RATE_LIMITER, modelRateLimiter); + params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); + log.info("Setting up basic ML predictor parameters with both model and user level throttling."); + return Collections.unmodifiableMap(params); + } + } + private void handleDeployModelException(String modelId, FunctionName functionName, ActionListener listener, Exception e) { if (!(e instanceof MLLimitExceededException) @@ -1091,16 +1112,276 @@ private void handleDeployModelException(String modelId, FunctionName functionNam listener.onFailure(e); } - private void setupPredictable(String modelId, MLModel mlModel, Map params) { - Predictable predictable = mlEngine.deploy(mlModel, params); - modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); - modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + public synchronized void updateModelCache(String modelId, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + getModel(modelId, ActionListener.wrap(mlModel -> { + int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; + modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); + setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); + if (mlModel.getAlgorithm() == FunctionName.REMOTE) { + if (mlModel.getConnector() != null) { + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); + log.info("Completed the model cache update for the remote model {}", modelId); + } else { + getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); + log.info("Completed the model cache update for the remote model {}", modelId); + }, wrappedListener::onFailure)); + } + } + wrappedListener.onResponse("Successfully updated model cache for the model " + modelId); + log.info("Completed the model cache update for the model {}", modelId); + }, wrappedListener::onFailure)); + } catch (Exception e) { + log.error("Failed to updated model cache for the model " + modelId, e); + listener.onFailure(e); + } } /** - * Get model from model index. + * Deploy the model controller with a model id. This method should be called AFTER a model is deployed. + * If you want to implement similar behavior during model deploy, deployModelControllerWithDeployingModel is the one supposed be called. + * + * @param modelId ml model ID + * @param listener action listener + */ + public synchronized void deployModelControllerWithDeployedModel(String modelId, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!modelCacheHelper.isModelDeployed(modelId)) { + throw new OpenSearchStatusException( + "The model of this model controller has not deployed yet, please deploy the model first.", + RestStatus.CONFLICT + ); + } + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + getModel(modelId, ActionListener.wrap(mlModel -> { + getModelController(modelId, ActionListener.wrap(modelController -> { + int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; + setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + if (mlModel.getAlgorithm() == FunctionName.REMOTE) { + if (mlModel.getConnector() != null) { + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); + log.info("Deployed model controller for the remote model {}", modelId); + } else { + getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); + log.info("Deployed model controller for the remote model {}", modelId); + }, wrappedListener::onFailure)); + } + return; + } + wrappedListener.onResponse("Successfully deployed model controller for the model " + modelId); + log.info("Deployed model controller for the model {}", modelId); + }, wrappedListener::onFailure)); + }, wrappedListener::onFailure)); + } catch (Exception e) { + log.error("Failed to deploy model controller for the model " + modelId, e); + listener.onFailure(e); + } + } + + /** + * Undploy the model controller for a model. + * Usually this method is called during deleting the model controller. + * + * @param modelId ml model ID + * @param listener action listener + */ + public synchronized void undeployModelController(String modelId, ActionListener listener) { + if (modelCacheHelper.isModelDeployed(modelId)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + getModel(modelId, ActionListener.wrap(mlModel -> { + removeUserRateLimiterMap(modelId); + if (mlModel.getAlgorithm() == FunctionName.REMOTE) { + if (mlModel.getConnector() != null) { + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully undeployed model controller for the remote model " + modelId); + log.info("Undeployed model controller for the remote model {}", modelId); + } else { + getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + mlModel.setConnector(connector); + setupParamsAndPredictable(modelId, mlModel); + wrappedListener.onResponse("Successfully undeployed model controller for the remote model " + modelId); + log.info("Undeployed model controller for the remote model {}", modelId); + }, wrappedListener::onFailure)); + } + return; + } + wrappedListener.onResponse("Successfully undeployed model controller for the model " + modelId); + log.info("Undeployed model controller for the model {}", modelId); + }, wrappedListener::onFailure)); + } catch (Exception e) { + log.error("Failed to undeploy model controller for the model " + modelId, e); + listener.onFailure(e); + } + } else if (isModelRunningOnNode(modelId)) { + log + .error( + "Failed to undeploy model controller due to model is in ML cache but with a state other than deployed. Please check model: " + + modelId, + new RuntimeException() + ); + listener + .onFailure( + new RuntimeException( + "Failed to undeploy model controller due to model is in ML cache but with a state other than deployed. Please check model: " + + modelId + ) + ); + } else { + log.info("Successfully deployed model controller from cache due to model not exist in cache. Model ID: " + modelId); + listener.onResponse("Successfully deployed model controller from cache due to model not exist in cache. Model ID: " + modelId); + } + } + + /** + * Deploy the model controller for a model during model is deploying. + * + * @param mlModel ml model + * @param listener action listener + */ + private synchronized void deployModelControllerWithDeployingModel( + MLModel mlModel, + Integer eligibleNodeCount, + ActionListener listener + ) { + String modelId = mlModel.getModelId(); + FetchSourceContext fetchContext = new FetchSourceContext(true); + GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelController modelController = MLModelController.parse(parser); + setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + log.info("Successfully redeployed model controller for model " + modelId); + listener.onResponse("Successfully redeployed model controller for model " + modelId); + } catch (Exception e) { + log.error("Failed to parse ml task" + r.getId(), e); + listener.onFailure(e); + } + } else if (mlModel.getIsModelControllerEnabled() == null || !mlModel.getIsModelControllerEnabled()) { + // Not going to respond the failure here due to the model deploy can still work well + listener + .onResponse( + "The model " + + modelId + + " is expected not having a model controller. Please use the create model controller api to create one if this is unexpected." + ); + log + .warn( + "The model " + + modelId + + " is expected not having a model controller. Please use the create model controller api to create one if this is unexpected." + ); + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find model controller", RestStatus.NOT_FOUND)); + } + }, listener::onFailure)); + } + + /** + * Deploy the model controller for a model during model is deploying with build-in listener. + * Usually this method is called when re-deploying a previous un-deployed model with the model controller. * + * @param mlModel ml model + */ + public void deployModelControllerWithDeployingModel(MLModel mlModel, Integer eligibleNodeCount) { + if (mlModel.getModelState() != MLModelState.DEPLOYING) { + throw new OpenSearchStatusException( + "This method should only be called when model is in DEPLOYING state, but the model is in state: " + mlModel.getModelState(), + RestStatus.CONFLICT + ); + } + deployModelControllerWithDeployingModel(mlModel, eligibleNodeCount, ActionListener.wrap(response -> { + if (response.startsWith("Successfully")) { + log.debug(response, mlModel.getModelId()); + } else if (response.startsWith("Failed")) { + log.error(response); + } else { + log.info(response); + } + }, e -> log.error("Failed to re-deploy the model controller for model: " + mlModel.getModelId(), e))); + } + + private void setupModelRateLimiter(String modelId, Integer eligibleNodeCount, MLRateLimiter modelRateLimiter) { + if (modelRateLimiter != null) { + modelCacheHelper.setModelRateLimiter(modelId, rateLimiterConstructor(eligibleNodeCount, modelRateLimiter)); + } else { + modelCacheHelper.removeModelRateLimiter(modelId); + } + } + + private void setupUserRateLimiterMap(String modelId, Integer eligibleNodeCount, Map userRateLimiterConfig) { + if (userRateLimiterConfig != null && !userRateLimiterConfig.isEmpty()) { + Map userRateLimiterMap = new HashMap<>(); + userRateLimiterConfig + .forEach((user, rateLimiter) -> userRateLimiterMap.put(user, rateLimiterConstructor(eligibleNodeCount, rateLimiter))); + modelCacheHelper.setUserRateLimiterMap(modelId, userRateLimiterMap); + } else { + modelCacheHelper.removeUserRateLimiterMap(modelId); + } + } + + private void removeUserRateLimiterMap(String modelId) { + modelCacheHelper.removeUserRateLimiterMap(modelId); + } + + /** + * Construct a TokenBucket object from its rate limiter config. + * + * @param eligibleNodeCount eligible node count + * @param modelRateLimiter model rate limiter config + * @return a TokenBucket object to enable throttling + */ + private TokenBucket rateLimiterConstructor(Integer eligibleNodeCount, MLRateLimiter modelRateLimiter) { + if (modelRateLimiter.isValid()) { + double rateLimitNumber = Double.parseDouble(modelRateLimiter.getRateLimitNumber()); + TimeUnit rateLimitUnit = modelRateLimiter.getRateLimitUnit(); + log + .info( + "Initializing the rate limiter with setting {} per {} (TPS limit {}), evenly distributed on {} nodes", + rateLimitNumber, + rateLimitUnit, + rateLimitNumber / rateLimitUnit.toSeconds(1), + eligibleNodeCount + ); + return new TokenBucket(System::nanoTime, rateLimitNumber / rateLimitUnit.toNanos(1) / eligibleNodeCount, rateLimitNumber); + } + return null; + } + + /** + * Get model-level rate limiter with model id. + * + * @param modelId model id + * @return a TokenBucket object to enable model-level throttling + */ + public TokenBucket getModelRateLimiter(String modelId) { + return modelCacheHelper.getModelRateLimiter(modelId); + } + + /** + * Get model-level rate limiter with model id. + * + * @param modelId model id + * @return a map with user's name and its corresponding rate limiter object to track user-level throttling + */ + public Map getUserRateLimiterMap(String modelId) { + return modelCacheHelper.getUserRateLimiterMap(modelId); + } + + /** + * Get model from model index. + * * @param modelId model id * @param listener action listener */ @@ -1110,7 +1391,7 @@ public void getModel(String modelId, ActionListener listener) { /** * Get model from model index with includes/excludes filter. - * + * * @param modelId model id * @param includes fields included * @param excludes fields excluded @@ -1139,7 +1420,38 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio }, listener::onFailure)); } - private void getConnector(Client client, String connectorId, ActionListener listener) { + /** + * Get model controller from model controller index. + * + * @param modelId model id + * @param listener action listener + */ + public void getModelController(String modelId, ActionListener listener) { + FetchSourceContext fetchContext = new FetchSourceContext(true); + GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelController modelController = MLModelController.parse(parser); + listener.onResponse(modelController); + } catch (Exception e) { + log.error("Failed to parse ml task" + r.getId(), e); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find model controller", RestStatus.NOT_FOUND)); + } + }, listener::onFailure)); + } + + /** + * Get connector from connector index. + * + * @param connectorId connector id + * @param listener action listener + */ + private void getConnector(String connectorId, ActionListener listener) { GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { @@ -1163,6 +1475,12 @@ private void getConnector(Client client, String connectorId, ActionListener listener) throws InterruptedException { String modelId = mlModelMeta.getModelId(); String modelName = mlModelMeta.getName(); @@ -1206,7 +1524,7 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste /** * Update model with build-in listener. - * + * * @param modelId model id * @param updatedFields updated fields */ @@ -1222,7 +1540,7 @@ public void updateModel(String modelId, Map updatedFields) { /** * Update model. - * + * * @param modelId model id * @param updatedFields updated fields * @param listener action listener @@ -1250,7 +1568,8 @@ public void updateModel(String modelId, Map updatedFields, Actio } /** - * Get model chunk id + * Get model chunk id. + * * @param modelId model id * @param chunkNumber model chunk number * @return model chunk id @@ -1261,7 +1580,7 @@ public String getModelChunkId(String modelId, Integer chunkNumber) { /** * Add model worker node to cache. - * + * * @param modelId model id * @param nodeIds node ids */ diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 410fdbae26..c9c9a65231 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; @@ -44,6 +45,12 @@ import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; +import org.opensearch.ml.action.controller.CreateModelControllerTransportAction; +import org.opensearch.ml.action.controller.DeleteModelControllerTransportAction; +import org.opensearch.ml.action.controller.DeployModelControllerTransportAction; +import org.opensearch.ml.action.controller.GetModelControllerTransportAction; +import org.opensearch.ml.action.controller.UndeployModelControllerTransportAction; +import org.opensearch.ml.action.controller.UpdateModelControllerTransportAction; import org.opensearch.ml.action.deploy.TransportDeployModelAction; import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; @@ -106,6 +113,12 @@ import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -181,9 +194,11 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLCreateModelControllerAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; +import org.opensearch.ml.rest.RestMLDeleteModelControllerAction; import org.opensearch.ml.rest.RestMLDeleteModelGroupAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; @@ -191,6 +206,7 @@ import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetModelAction; +import org.opensearch.ml.rest.RestMLGetModelControllerAction; import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLPredictionAction; @@ -209,6 +225,7 @@ import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateConnectorAction; import org.opensearch.ml.rest.RestMLUpdateModelAction; +import org.opensearch.ml.rest.RestMLUpdateModelControllerAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -363,6 +380,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), + new ActionHandler<>(MLCreateModelControllerAction.INSTANCE, CreateModelControllerTransportAction.class), + new ActionHandler<>(MLModelControllerGetAction.INSTANCE, GetModelControllerTransportAction.class), + new ActionHandler<>(MLDeployModelControllerAction.INSTANCE, DeployModelControllerTransportAction.class), + new ActionHandler<>(MLUpdateModelControllerAction.INSTANCE, UpdateModelControllerTransportAction.class), + new ActionHandler<>(MLModelControllerDeleteAction.INSTANCE, DeleteModelControllerTransportAction.class), + new ActionHandler<>(MLUndeployModelControllerAction.INSTANCE, UndeployModelControllerTransportAction.class), new ActionHandler<>(MLAgentGetAction.INSTANCE, GetAgentTransportAction.class), new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), @@ -414,6 +437,11 @@ public Collection createComponents( .put(MLClusterLevelStat.ML_CONNECTOR_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONNECTOR_INDEX))); stats.put(MLClusterLevelStat.ML_CONFIG_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONFIG_INDEX))); stats.put(MLClusterLevelStat.ML_TASK_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_TASK_INDEX))); + stats + .put( + MLClusterLevelStat.ML_MODEL_CONTROLLER_INDEX_STATUS, + new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_MODEL_CONTROLLER_INDEX)) + ); stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); // node level stats @@ -655,6 +683,10 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); + RestMLCreateModelControllerAction restMLCreateModelControllerAction = new RestMLCreateModelControllerAction(); + RestMLGetModelControllerAction restMLGetModelControllerAction = new RestMLGetModelControllerAction(); + RestMLUpdateModelControllerAction restMLUpdateModelControllerAction = new RestMLUpdateModelControllerAction(); + RestMLDeleteModelControllerAction restMLDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(); RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); @@ -700,6 +732,10 @@ public List getRestHandlers( restSearchInteractionsAction, restGetConversationAction, restGetInteractionAction, + restMLCreateModelControllerAction, + restMLGetModelControllerAction, + restMLUpdateModelControllerAction, + restMLDeleteModelControllerAction, restMLGetAgentAction, restMLDeleteAgentAction, restMemoryUpdateConversationAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java new file mode 100644 index 0000000000..b6f180e49a --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMLCreateModelControllerAction extends BaseRestHandler { + + public final static String ML_CREATE_MODEL_CONTROLLER_ACTION = "ml_create_model_controller_action"; + + /** + * Constructor + */ + public RestMLCreateModelControllerAction() {} + + @Override + public String getName() { + return ML_CREATE_MODEL_CONTROLLER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCreateModelControllerRequest createModelControllerRequest = getRequest(request); + return channel -> { + client.execute(MLCreateModelControllerAction.INSTANCE, createModelControllerRequest, new RestToXContentListener<>(channel)); + }; + } + + /** + * Creates a MLCreateModelControllerRequest from a RestRequest + * + * @param request RestRequest + * @return MLCreateModelControllerRequest + */ + private MLCreateModelControllerRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Create model controller request has empty body"); + } + // Model ID can only be set here. + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelController modelControllerInput = MLModelController.parse(parser); + modelControllerInput.setModelId(modelId); + return new MLCreateModelControllerRequest(modelControllerInput); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java new file mode 100644 index 0000000000..51208a82f0 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete ML Model. + */ +public class RestMLDeleteModelControllerAction extends BaseRestHandler { + private static final String ML_DELETE_MODEL_CONTROLLER_ACTION = "ml_delete_model_controller_action"; + + public void RestMLDeleteModelControllerAction() {} + + @Override + public String getName() { + return ML_DELETE_MODEL_CONTROLLER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID) + ) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String modelId = request.param(PARAMETER_MODEL_ID); + + MLModelControllerDeleteRequest mlModelControllerDeleteRequest = new MLModelControllerDeleteRequest(modelId); + return channel -> client + .execute(MLModelControllerDeleteAction.INSTANCE, mlModelControllerDeleteRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java new file mode 100644 index 0000000000..d4946ec2e1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.returnContent; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetModelControllerAction extends BaseRestHandler { + private static final String ML_GET_MODEL_CONTROLLER_ACTION = "ml_get_model_controller_action"; + + /** + * Constructor + */ + public RestMLGetModelControllerAction() {} + + @Override + public String getName() { + return ML_GET_MODEL_CONTROLLER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLModelControllerGetRequest modelControllerGetRequest = getRequest(request); + return channel -> client + .execute(MLModelControllerGetAction.INSTANCE, modelControllerGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLModelControllerGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLModelControllerGetRequest + */ + @VisibleForTesting + MLModelControllerGetRequest getRequest(RestRequest request) throws IOException { + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + boolean returnContent = returnContent(request); + + return new MLModelControllerGetRequest(modelId, returnContent); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java index 0e473fbb98..4c765732a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.Locale; +import org.opensearch.OpenSearchParseException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; @@ -50,16 +51,16 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } /** - * Creates a MLUploadModelMetaRequest from a RestRequest + * Creates a MLRegisterModelGroupRequest from a RestRequest * * @param request RestRequest - * @return MLUploadModelMetaRequest + * @return MLRegisterModelGroupRequest */ @VisibleForTesting MLRegisterModelGroupRequest getRequest(RestRequest request) throws IOException { boolean hasContent = request.hasContent(); if (!hasContent) { - throw new IOException("Model group request has empty body"); + throw new OpenSearchParseException("Model group request has empty body"); } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java index fd5c828201..5a40ae8c47 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -57,7 +57,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException { if (!request.hasContent()) { - throw new OpenSearchParseException("Model update request has empty body"); + throw new OpenSearchParseException("Update model request has empty body"); } String modelId = getParameterId(request, PARAMETER_MODEL_ID); @@ -66,7 +66,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); try { MLUpdateModelInput input = MLUpdateModelInput.parse(parser); - if (input.getConnectorId() != null && input.getConnectorUpdateContent() != null) { + if (input.getConnectorId() != null && input.getConnector() != null) { throw new OpenSearchStatusException( "Model cannot have both stand-alone connector and internal connector. Please check your update input body.", RestStatus.BAD_REQUEST @@ -75,7 +75,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException // Model ID can only be set here. Model version as well as connector field can only be set automatically. input.setModelId(modelId); input.setVersion(null); - input.setConnector(null); + input.setUpdatedConnector(null); return new MLUpdateModelRequest(input); } catch (IllegalStateException e) { throw new OpenSearchParseException(e.getMessage()); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java new file mode 100644 index 0000000000..b64d3e37e7 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMLUpdateModelControllerAction extends BaseRestHandler { + + public final static String ML_UPDATE_MODEL_CONTROLLER_ACTION = "ml_update_model_controller_action"; + + /** + * Constructor + */ + public RestMLUpdateModelControllerAction() {} + + @Override + public String getName() { + return ML_UPDATE_MODEL_CONTROLLER_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLUpdateModelControllerRequest updateModelControllerRequest = getRequest(request); + return channel -> { + client.execute(MLUpdateModelControllerAction.INSTANCE, updateModelControllerRequest, new RestToXContentListener<>(channel)); + }; + } + + /** + * Creates a MLUpdateModelControllerRequest from a RestRequest + * + * @param request RestRequest to parse + * @return MLUpdateModelControllerRequest + * @throws IOException if an error occurs while parsing the request + */ + private MLUpdateModelControllerRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Update model controller request has empty body"); + } + // Model ID can only be set here. + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelController modelControllerInput = MLModelController.parse(parser); + modelControllerInput.setModelId(modelId); + return new MLUpdateModelControllerRequest(modelControllerInput); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java index b918c3cd4c..b07b876825 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java @@ -14,6 +14,7 @@ public enum MLClusterLevelStat { ML_CONNECTOR_INDEX_STATUS, ML_CONFIG_INDEX_STATUS, ML_TASK_INDEX_STATUS, + ML_MODEL_CONTROLLER_INDEX_STATUS, ML_MODEL_COUNT, ML_CONNECTOR_COUNT; diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java new file mode 100644 index 0000000000..f19a7eda1f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java @@ -0,0 +1,413 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerResponse; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class CreateModelControllerTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + IndexResponse indexResponse; + + @Mock + private MLModelManager mlModelManager; + + @Mock + MLModelCacheHelper mlModelCacheHelper; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + MLModel mlModel; + + @Mock + MLDeployModelControllerNodesResponse mlDeployModelControllerNodesResponse; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CreateModelControllerTransportAction createModelControllerTransportAction; + MLCreateModelControllerRequest createModelControllerRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; + + createModelControllerTransportAction = spy( + new CreateModelControllerTransportAction( + transportService, + actionFilters, + mlIndicesHandler, + client, + clusterService, + modelAccessControlHelper, + mlModelCacheHelper, + mlModelManager + ) + ); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + + MLModelController modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); + + createModelControllerRequest = MLCreateModelControllerRequest.builder().modelControllerInput(modelController).build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(isA(IndexRequest.class), isA(ActionListener.class)); + when(indexResponse.getId()).thenReturn("testModelId"); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(clusterService.getSettings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(false); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(nodes); + when(mlModelManager.getWorkerNodes("testModelId", FunctionName.REMOTE)).thenReturn(targetNodeIds); + } + + @Test + public void testCreateModelControllerSuccess() { + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + } + + @Test + public void testCreateModelControllerWithTextEmbeddingModelSuccess() { + when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + } + + @Test + public void testCreateModelControllerWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model controller, model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testCreateModelControllerWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testCreateModelControllerWithModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find model to create the corresponding model controller with the provided model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testCreateModelControllerWithModelStateDeploying() { + when(mlModel.getModelState()).thenReturn(MLModelState.DEPLOYING); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it. Model ID: testModelId", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testCreateModelControllerWithModelFunctionUnsupported() { + when(mlModel.getAlgorithm()).thenReturn(FunctionName.METRICS_CORRELATION); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testCreateModelControllerWithIndexCreatedFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(false); + return null; + }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to create model controller index.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testCreateModelControllerWithIndexCreatedOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testCreateModelControllerWithIndexResponseUpdated() { + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + } + + @Test + public void testCreateModelControllerWithDeploySuccessNullFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(null); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + } + + @Test + public void testCreateModelControllerWithUndeploySuccessEmptyFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + } + + @Test + public void testCreateModelControllerWithUndeploySuccessPartiallyFailures() { + List failures = List + .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(failures); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully create model controller index with model ID testModelId " + + "but deploy model controller to cache was failed on following nodes [foo1], please retry.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testCreateModelControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully create model controller index with model ID testModelId " + + "but deploy model controller to cache was failed on following nodes [foo1, foo2], please retry.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testCreateModelControllerWithUndeployOtherException() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RuntimeException("Exception occurred. Please check log for more details.") + ); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + + createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java new file mode 100644 index 0000000000..5da46c6479 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java @@ -0,0 +1,374 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteModelControllerTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + DeleteResponse deleteResponse; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + private MLModelManager mlModelManager; + + @Mock + MLModelCacheHelper mlModelCacheHelper; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + MLModel mlModel; + + @Mock + MLModelController mlModelController; + + @Mock + MLUndeployModelControllerNodesResponse mlUndeployModelControllerNodesResponse; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + DeleteModelControllerTransportAction deleteModelControllerTransportAction; + MLModelControllerDeleteRequest mlModelControllerDeleteRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + + deleteModelControllerTransportAction = spy( + new DeleteModelControllerTransportAction( + transportService, + actionFilters, + client, + xContentRegistry, + clusterService, + mlModelManager, + mlModelCacheHelper, + modelAccessControlHelper + ) + ); + + mlModelControllerDeleteRequest = MLModelControllerDeleteRequest.builder().modelId("testModelId").build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mlModelController); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(clusterService.getSettings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(false); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(nodes); + } + + @Test + public void testDeleteModelControllerSuccess() { + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + @Test + public void testDeleteModelControllerWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model controller, model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testDeleteModelControllerWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteModelControllerWithGetModelNotFoundSuccess() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + @Test + public void testDeleteModelControllerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(client).delete(any(), any()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteModelControllerWithGetModelControllerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteModelControllerWithGetModelNotFoundWithGetModelControllerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteModelControllerWithUndeploySuccessNullFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(null); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + @Test + public void testDeleteModelControllerWithUndeploySuccessEmptyFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + @Test + public void testDeleteModelControllerWithUndeploySuccessPartiallyFailures() { + List failures = List + .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(failures); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to undeploy model controller with model ID testModelId on following nodes [foo1], deletion is aborted. Please retry or undeploy the model manually and then perform the deletion.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testDeleteModelControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to undeploy model controller with model ID testModelId on following nodes [foo1, foo2], deletion is aborted. Please retry or undeploy the model manually and then perform the deletion.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testDeleteModelControllerWithUndeployOtherException() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RuntimeException("Exception occurred. Please check log for more details.") + ); + return null; + }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + + deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java new file mode 100644 index 0000000000..4f37e1e354 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportService; + +@RunWith(MockitoJUnitRunner.class) +public class DeployModelControllerTransportActionTests { + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private Client client; + + @Mock + private DiscoveryNodeHelper nodeFilter; + + @Mock + private MLStats mlStats; + + @Mock + NamedXContentRegistry xContentRegistry; + + private DeployModelControllerTransportAction action; + + private DiscoveryNode localNode; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Before + public void setUp() throws Exception { + action = new DeployModelControllerTransportAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + null, + client, + nodeFilter, + mlStats, + xContentRegistry, + modelAccessControlHelper + ); + + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse("successful"); + return null; + }).when(mlModelManager).deployModelControllerWithDeployedModel(any(), any()); + } + + @Test + public void testNewResponses() { + final MLDeployModelControllerNodesRequest nodesRequest = new MLDeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + Map modelDeployModelControllerStatusMap = new HashMap<>(); + modelDeployModelControllerStatusMap.put("modelName:version", "response"); + MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse( + localNode, + modelDeployModelControllerStatusMap + ); + final List responses = List.of(response); + final List failures = new ArrayList<>(); + MLDeployModelControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response1); + } + + @Test + public void testNewNodeRequest() { + final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLDeployModelControllerNodeRequest deployModelControllerNodeRequest = action.newNodeRequest(request); + assertNotNull(deployModelControllerNodeRequest); + } + + @Test + public void testNewNodeStreamRequest() throws IOException { + Map deployModelControllerStatus = new HashMap<>(); + deployModelControllerStatus.put("modelId1", "response"); + MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, deployModelControllerStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final MLDeployModelControllerNodeResponse deployModelControllerNodeResponse = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(deployModelControllerNodeResponse); + } + + @Test + public void testNodeOperation() { + final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLDeployModelControllerNodeResponse response = action.nodeOperation(new MLDeployModelControllerNodeRequest(request)); + assertNotNull(response); + } + + @Test + public void testNodeOperationException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Test exception")); + return null; + }).when(mlModelManager).deployModelControllerWithDeployedModel(any(), any()); + final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLDeployModelControllerNodeResponse response = action.nodeOperation(new MLDeployModelControllerNodeRequest(request)); + assertNotNull(response); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java new file mode 100644 index 0000000000..9e8ff54f10 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java @@ -0,0 +1,270 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.utils.MockHelper.mock_client_get_NotExist; + +import java.io.IOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetModelControllerTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + ClusterService clusterService; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + MLModelManager mlModelManager; + + @Mock + MLModel mlModel; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + GetModelControllerTransportAction getModelControllerTransportAction; + MLModelControllerGetRequest mlModelControllerGetRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + getModelControllerTransportAction = spy( + new GetModelControllerTransportAction( + transportService, + actionFilters, + client, + xContentRegistry, + clusterService, + mlModelManager, + modelAccessControlHelper + ) + ); + mlModelControllerGetRequest = MLModelControllerGetRequest.builder().modelId("testModelId").build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + GetResponse getResponse = prepareModelControllerGetResponse(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + @Test + public void testGetModelControllerSuccess() { + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + verify(actionListener).onResponse(any(MLModelControllerGetResponse.class)); + } + + @Test + public void testGetModelControllerWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model controller, model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetModelControllerWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetModelControllerWithGetModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find model to get the corresponding model controller with the provided model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetModelControllerWithGetModelOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find model to get the corresponding model controller with the provided model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetModelControllerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(client).get(any(), any()); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetModelControllerNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model controller with the provided model ID: testModelId", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetModelControllerClientFailedToGetThreadPool() { + mock_client_get_NotExist(client); + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model controller with the provided model ID: testModelId", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetModelControllerIndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Failed to find model controller")); + return null; + }).when(client).get(any(), any()); + + getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model controller", argumentCaptor.getValue().getMessage()); + } + + public GetResponse prepareModelControllerGetResponse() throws IOException { + + MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + + MLModelController modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); + + XContentBuilder content = modelController.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java new file mode 100644 index 0000000000..b18218cbd0 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportService; + +@RunWith(MockitoJUnitRunner.class) +public class UndeployModelControllerTransportActionTests { + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private Client client; + + @Mock + private DiscoveryNodeHelper nodeFilter; + + @Mock + private MLStats mlStats; + + @Mock + NamedXContentRegistry xContentRegistry; + + private UndeployModelControllerTransportAction action; + + private DiscoveryNode localNode; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Before + public void setUp() throws Exception { + action = new UndeployModelControllerTransportAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + null, + client, + nodeFilter, + mlStats, + xContentRegistry, + modelAccessControlHelper + ); + + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse("successful"); + return null; + }).when(mlModelManager).undeployModelController(any(), any()); + } + + @Test + public void testNewResponses() { + final MLUndeployModelControllerNodesRequest nodesRequest = new MLUndeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + Map modelUndeployModelControllerStatusMap = new HashMap<>(); + modelUndeployModelControllerStatusMap.put("modelName:version", "response"); + MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse( + localNode, + modelUndeployModelControllerStatusMap + ); + final List responses = List.of(response); + final List failures = new ArrayList<>(); + MLUndeployModelControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response1); + } + + @Test + public void testNewNodeRequest() { + final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequest = action.newNodeRequest(request); + assertNotNull(undeployModelControllerNodeRequest); + } + + @Test + public void testNewNodeStreamRequest() throws IOException { + Map undeployModelControllerStatus = new HashMap<>(); + undeployModelControllerStatus.put("modelId1", "response"); + MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse( + localNode, + undeployModelControllerStatus + ); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final MLUndeployModelControllerNodeResponse undeployModelControllerNodeResponse = action + .newNodeResponse(output.bytes().streamInput()); + assertNotNull(undeployModelControllerNodeResponse); + } + + @Test + public void testNodeOperation() { + final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLUndeployModelControllerNodeResponse response = action.nodeOperation(new MLUndeployModelControllerNodeRequest(request)); + assertNotNull(response); + } + + @Test + public void testNodeOperationException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Test exception")); + return null; + }).when(mlModelManager).undeployModelController(any(), any()); + final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + "testModelId" + ); + final MLUndeployModelControllerNodeResponse response = action.nodeOperation(new MLUndeployModelControllerNodeRequest(request)); + assertNotNull(response); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java new file mode 100644 index 0000000000..64fd99ee4a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java @@ -0,0 +1,463 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.controller; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateModelControllerTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + UpdateResponse updateResponse; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + private MLModelManager mlModelManager; + + @Mock + MLModelCacheHelper mlModelCacheHelper; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + MLModel mlModel; + + @Mock + MLDeployModelControllerNodesResponse mlDeployModelControllerNodesResponse; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + MLModelController modelController; + MLModelController updatedModelController; + UpdateModelControllerTransportAction updateModelControllerTransportAction; + MLUpdateModelControllerRequest updateModelControllerRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; + + updateModelControllerTransportAction = spy( + new UpdateModelControllerTransportAction( + transportService, + actionFilters, + client, + clusterService, + modelAccessControlHelper, + mlModelCacheHelper, + mlModelManager + ) + ); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + + modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); + + MLRateLimiter updateRateLimiter = MLRateLimiter.builder().rateLimitNumber("2").rateLimitUnit(TimeUnit.NANOSECONDS).build(); + + updatedModelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + { + put("newUser", updateRateLimiter); + } + }).build(); + + updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(updatedModelController).build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); + when(mlModel.getModelId()).thenReturn("testModelId"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(modelController); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(), any()); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(clusterService.getSettings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(false); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(nodes); + when(mlModelManager.getWorkerNodes("testModelId", FunctionName.REMOTE)).thenReturn(targetNodeIds); + } + + @Test + public void testUpdateModelControllerSuccess() { + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithTextEmbeddingModelSuccess() { + when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model controller, model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelControllerWithModelControllerEnabledNull() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsModelControllerEnabled()).thenReturn(null); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Model controller haven't been created for the model. Consider calling create model controller api instead. Model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithModelControllerNotEnabled() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsModelControllerEnabled()).thenReturn(false); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Model controller haven't been created for the model. Consider calling create model controller api instead. Model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithModelControllerEnabledNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsModelControllerEnabled()).thenReturn(true); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelControllerWithModelFunctionUnsupported() { + when(mlModel.getAlgorithm()).thenReturn(FunctionName.METRICS_CORRELATION); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void tesUpdateModelControllerWithGetModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find model to create the corresponding model controller with the provided model ID: testModelId", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithUpdateResponseNoop() { + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithNullUpdateResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).update(any(), any()); + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to update model controller with model ID: testModelId", argumentCaptor.getValue().getMessage()); + + } + + @Test + public void testUpdateModelControllerWithDeploySuccessNullFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(null); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithDeployNotRequiredAfterUpdateSuccess() { + updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(modelController).build(); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithModelNDeployedAndDeployNotRequiredAfterUpdateSuccess() { + updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(modelController).build(); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(false); + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithUndeploySuccessEmptyFailures() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelControllerWithUndeploySuccessPartiallyFailures() { + List failures = List + .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployModelControllerNodesResponse); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + when(mlDeployModelControllerNodesResponse.failures()).thenReturn(failures); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully update model controller index with model ID testModelId " + + "but deploy model controller to cache was failed on following nodes [foo1], please retry.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully update model controller index with model ID testModelId " + + "but deploy model controller to cache was failed on following nodes [foo1, foo2], please retry.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelControllerWithUndeployOtherException() { + when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RuntimeException("Exception occurred. Please check log for more details.") + ); + return null; + }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + + updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 1f09877c92..529163cc5c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -752,7 +752,10 @@ public void testUpdateModelStateDeployingException() { transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Model is deploying, please wait for it complete. model ID test_model_id", argumentCaptor.getValue().getMessage()); + assertEquals( + "Model is deploying. Please wait for the model to complete deployment. model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); } @Test @@ -1163,7 +1166,7 @@ private MLUpdateModelRequest prepareRemoteRequest(String remoteRequestType) thro .name("updated_test_name") .description("updated_test_description") .modelGroupId("updated_test_model_group_id") - .connectorUpdateContent(updateContent) + .connector(updateContent) .build(); return MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); default: diff --git a/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java index ec659b872c..4d81511ce8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/update_cache/UpdateModelCacheTransportActionTests.java @@ -105,18 +105,17 @@ public void setUp() throws Exception { when(clusterService.localNode()).thenReturn(localNode); when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onResponse("successful"); return null; - }).when(mlModelManager).updateModelCache(any(), any(Boolean.class), any()); + }).when(mlModelManager).updateModelCache(any(), any()); } @Test public void testNewResponses() { final MLUpdateModelCacheNodesRequest nodesRequest = new MLUpdateModelCacheNodesRequest( new String[] { "nodeId1", "nodeId2" }, - "testModelId", - true + "testModelId" ); Map modelUpdateModelCacheStatusMap = new HashMap<>(); modelUpdateModelCacheStatusMap.put("modelName:version", "response"); @@ -131,8 +130,7 @@ public void testNewResponses() { public void testNewNodeRequest() { final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( new String[] { "nodeId1", "nodeId2" }, - "testModelId", - true + "testModelId" ); final MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = action.newNodeRequest(request); assertNotNull(updateModelCacheNodeRequest); @@ -153,8 +151,7 @@ public void testNewNodeStreamRequest() throws IOException { public void testNodeOperation() { final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( new String[] { "nodeId1", "nodeId2" }, - "testModelId", - true + "testModelId" ); final MLUpdateModelCacheNodeResponse response = action.nodeOperation(new MLUpdateModelCacheNodeRequest(request)); assertNotNull(response); @@ -163,14 +160,13 @@ public void testNodeOperation() { @Test public void testNodeOperationException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Test exception")); return null; - }).when(mlModelManager).updateModelCache(any(), any(Boolean.class), any()); + }).when(mlModelManager).updateModelCache(any(), any()); final MLUpdateModelCacheNodesRequest request = new MLUpdateModelCacheNodesRequest( new String[] { "nodeId1", "nodeId2" }, - "testModelId", - true + "testModelId" ); final MLUpdateModelCacheNodeResponse response = action.nodeOperation(new MLUpdateModelCacheNodeRequest(request)); assertNotNull(response); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 2f8ef74f66..60338e4ccd 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -548,6 +548,27 @@ public void testRegisterModel_ClientFailedToGetThreadPool() { } public void testDeployModel_FailedToGetModel() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); @@ -569,6 +590,27 @@ public void testDeployModel_FailedToGetModel() { } public void testDeployModel_NullGetModelResponse() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); @@ -589,6 +631,27 @@ public void testDeployModel_NullGetModelResponse() { } public void testDeployModel_GetModelResponse_NotExist() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); @@ -609,6 +672,28 @@ public void testDeployModel_GetModelResponse_NotExist() { } public void testDeployModel_GetModelResponse_wrong_hash_value() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + modelChunk0 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))).build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); @@ -638,6 +723,27 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { } public void testDeployModel_GetModelResponse_FailedToDeploy() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); @@ -685,6 +791,27 @@ public void testDeployModel_ExceedMaxDeployedModel() { } public void testDeployModel_ThreadPoolException() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); @@ -817,6 +944,27 @@ public void test_addModelWorkerNodes_success() { } private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.DEPLOYING) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + String[] nodes = new String[] { "node1", "node2" }; + mlTask.setWorkerNodes(List.of(nodes)); when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(false); when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(threadPool.executor(DEPLOY_THREAD_POOL)).thenReturn(taskExecutorService); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java new file mode 100644 index 0000000000..9e358b8ceb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLCreateModelControllerActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLCreateModelControllerAction restMLCreateModelControllerAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + restMLCreateModelControllerAction = new RestMLCreateModelControllerAction(); + doAnswer(invocation -> { + invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLCreateModelControllerAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testConstructor() { + RestMLCreateModelControllerAction CreateModelAction = new RestMLCreateModelControllerAction(); + assertNotNull(CreateModelAction); + } + + @Test + public void testGetName() { + String actionName = restMLCreateModelControllerAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_create_model_controller_action", actionName); + } + + @Test + public void testRoutes() { + List routes = restMLCreateModelControllerAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + } + + @Test + public void testCreateModelControllerRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLCreateModelControllerAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateModelControllerRequest.class); + verify(client, times(1)).execute(eq(MLCreateModelControllerAction.INSTANCE), argumentCaptor.capture(), any()); + MLModelController createModelControllerInput = argumentCaptor.getValue().getModelControllerInput(); + assertEquals("testModelId", createModelControllerInput.getModelId()); + } + + @Test + public void testCreateModelControllerRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Create model controller request has empty body"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLCreateModelControllerAction.handleRequest(request, channel, client); + } + + @Test + public void testCreateModelControllerRequestWithNullModelId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain model_id"); + RestRequest request = getRestRequestWithNullModelId(); + restMLCreateModelControllerAction.handleRequest(request, channel, client); + } + + @Test + public void testCreateModelControllerRequestWithNullField() throws Exception { + exceptionRule.expect(ParsingException.class); + exceptionRule.expectMessage("expecting token of type [START_OBJECT] but found [VALUE_NULL]"); + RestRequest request = getRestRequestWithNullField(); + restMLCreateModelControllerAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + Map params = Map.of("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = Map.of("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullModelId() { + RestRequest.Method method = RestRequest.Method.POST; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullField() { + RestRequest.Method method = RestRequest.Method.POST; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":null}}"; + Map params = new HashMap<>(); + params.put("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java new file mode 100644 index 0000000000..867ec2ce33 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLDeleteModelControllerActionTests extends OpenSearchTestCase { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private RestMLDeleteModelControllerAction restMLDeleteModelControllerAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLModelControllerDeleteAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteModelControllerAction mlDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); + assertNotNull(mlDeleteModelControllerAction); + } + + public void testGetName() { + String actionName = restMLDeleteModelControllerAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_model_controller_action", actionName); + } + + public void testRoutes() { + List routes = restMLDeleteModelControllerAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteModelControllerAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelControllerDeleteRequest.class); + verify(client, times(1)).execute(eq(MLModelControllerDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getModelId(); + assertEquals(taskId, "testModelId"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "testModelId"); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java new file mode 100644 index 0000000000..7597aed17d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetModelControllerActionTests extends OpenSearchTestCase { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private RestMLGetModelControllerAction restMLGetModelControllerAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLGetModelControllerAction = new RestMLGetModelControllerAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLModelControllerGetAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetModelControllerAction mlGetModelControllerAction = new RestMLGetModelControllerAction(); + assertNotNull(mlGetModelControllerAction); + } + + public void testGetName() { + String actionName = restMLGetModelControllerAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_model_controller_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetModelControllerAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetModelControllerAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelControllerGetRequest.class); + verify(client, times(1)).execute(eq(MLModelControllerGetAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getModelId(); + assertEquals(taskId, "testModelId"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "testModelId"); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java index 57a222abad..221848c597 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java @@ -12,7 +12,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.io.IOException; import java.util.List; import java.util.Map; @@ -22,6 +21,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -104,7 +104,7 @@ public void testRegisterModelGroupRequest() throws Exception { } public void testRegisterModelGroupRequestWithEmptyContent() throws Exception { - exceptionRule.expect(IOException.class); + exceptionRule.expect(OpenSearchParseException.class); exceptionRule.expectMessage("Model group request has empty body"); RestRequest request = getRestRequestWithEmptyContent(); restMLRegisterModelGroupAction.handleRequest(request, channel, client); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index 1c6a3d2ae7..c3a21bde1f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -25,11 +25,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -68,11 +66,6 @@ public void setup() { when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); - return null; - }).when(client).execute(eq(MLUpdateConnectorAction.INSTANCE), any(), any()); } @Override @@ -103,6 +96,11 @@ public void testRoutes() { } public void testUpdateConnectorRequest() throws Exception { + doAnswer(invocation -> { + invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateConnectorAction.INSTANCE), any(), any()); + RestRequest request = getRestRequest(); restMLUpdateConnectorAction.handleRequest(request, channel, client); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateConnectorRequest.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index 10e78ecf2d..c11c7e3fb8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -115,7 +115,7 @@ public void testUpdateModelRequest() throws Exception { @Test public void testUpdateModelRequestWithEmptyContent() throws Exception { exceptionRule.expect(OpenSearchParseException.class); - exceptionRule.expectMessage("Model update request has empty body"); + exceptionRule.expectMessage("Update model request has empty body"); RestRequest request = getRestRequestWithEmptyContent(); restMLUpdateModelAction.handleRequest(request, channel, client); } @@ -166,7 +166,7 @@ public void testUpdateModelRequestWithConnectorUpdateContent() throws Exception assertEquals("testModelName", updateModelInput.getName()); assertEquals( "{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}", - toJsonString(updateModelInput.getConnectorUpdateContent()) + toJsonString(updateModelInput.getConnector()) ); } @@ -242,7 +242,7 @@ private RestRequest getRestRequestWithConnectorIDAndConnectorUpdateContent() { "This is test description", "connector_id", "testConnectorID", - "connector_update_content", + "connector", updateContent ); String requestContent = new Gson().toJson(modelContent); @@ -282,7 +282,7 @@ private RestRequest getRestRequestWithConnectorUpdateContent() { .description("updated description") .build(); final Map modelContent = Map - .of("name", "testModelName", "description", "This is test description", "connector_update_content", updateContent); + .of("name", "testModelName", "description", "This is test description", "connector", updateContent); String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); params.put("model_id", "test_modelId"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java new file mode 100644 index 0000000000..caadbab1fb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLUpdateModelControllerActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateModelControllerAction restMLUpdateModelControllerAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + restMLUpdateModelControllerAction = new RestMLUpdateModelControllerAction(); + doAnswer(invocation -> { + invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateModelControllerAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testConstructor() { + RestMLUpdateModelControllerAction UpdateModelAction = new RestMLUpdateModelControllerAction(); + assertNotNull(UpdateModelAction); + } + + @Test + public void testGetName() { + String actionName = restMLUpdateModelControllerAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_model_controller_action", actionName); + } + + @Test + public void testRoutes() { + List routes = restMLUpdateModelControllerAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + } + + @Test + public void testUpdateModelControllerRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateModelControllerAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelControllerRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelControllerAction.INSTANCE), argumentCaptor.capture(), any()); + MLModelController updateModelControllerInput = argumentCaptor.getValue().getUpdateModelControllerInput(); + assertEquals("testModelId", updateModelControllerInput.getModelId()); + } + + @Test + public void testUpdateModelControllerRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Update model controller request has empty body"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateModelControllerAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelControllerRequestWithNullModelId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain model_id"); + RestRequest request = getRestRequestWithNullModelId(); + restMLUpdateModelControllerAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelControllerRequestWithNullField() throws Exception { + exceptionRule.expect(ParsingException.class); + exceptionRule.expectMessage("expecting token of type [START_OBJECT] but found [VALUE_NULL]"); + RestRequest request = getRestRequestWithNullField(); + restMLUpdateModelControllerAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + Map params = Map.of("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.PUT; + Map params = Map.of("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullModelId() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullField() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":null}}"; + Map params = new HashMap<>(); + params.put("model_id", "testModelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +}