diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 45f80f5113..60fed48365 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -131,6 +131,7 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.utils.FileUtils; import org.opensearch.ml.profile.MLModelProfile; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; @@ -178,6 +179,7 @@ public class MLModelManager { private final MLTaskManager mlTaskManager; private final MLEngine mlEngine; private final DiscoveryNodeHelper nodeHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; @@ -208,7 +210,8 @@ public MLModelManager( MLTaskManager mlTaskManager, MLModelCacheHelper modelCacheHelper, MLEngine mlEngine, - DiscoveryNodeHelper nodeHelper + DiscoveryNodeHelper nodeHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; this.sdkClient = sdkClient; @@ -224,6 +227,7 @@ public MLModelManager( this.mlTaskManager = mlTaskManager; this.mlEngine = mlEngine; this.nodeHelper = nodeHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.maxModelPerNode = ML_COMMONS_MAX_MODELS_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 1e15c8a4e5..843c670779 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -511,7 +511,8 @@ public Collection createComponents( mlTaskManager, modelCacheHelper, mlEngine, - nodeHelper + nodeHelper, + mlFeatureEnabledSetting ); mlInputDatasetHandler = new MLInputDatasetHandler(client); modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 835440040b..d6e549cc63 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -121,6 +121,7 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.sdkclient.SdkClientFactory; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; @@ -208,7 +209,9 @@ public class MLModelManagerTests extends OpenSearchTestCase { @Mock private MLTask pretrainedMLTask; - + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws URISyntaxException, IOException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; @@ -301,7 +304,8 @@ public void setup() throws URISyntaxException, IOException { mlTaskManager, modelCacheHelper, mlEngine, - nodeHelper + nodeHelper, + mlFeatureEnabledSetting ) );