From 126ed3a4220d5898d68910c1a26553578a4fc0ea Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 30 Jan 2024 02:32:16 +0800 Subject: [PATCH] [BUG FIX] Fix updating plugins.ml_commons.jvm_heap_memory_threshold takes no effect (#1943) * fix plugins.ml_commons.jvm_heap_memory_threshold settings ineffective Signed-off-by: zhichao-aws * add integ test Signed-off-by: zhichao-aws * add license header Signed-off-by: zhichao-aws * fix the bug by override getThreshold Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws --- .../ml/breaker/MemoryCircuitBreaker.java | 5 ++ .../ml/breaker/MemoryCircuitBreakerTests.java | 24 ++++++ .../ml/rest/RestMLMemoryCircuitBreakerIT.java | 80 +++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 8b467c2452..5e045ae539 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -43,6 +43,11 @@ public String getName() { return ML_MEMORY_CB; } + @Override + public Short getThreshold() { + return this.jvmHeapMemThreshold.shortValue(); + } + @Override public boolean isOpen() { return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold(); diff --git a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java index a749b19264..cdd1f6fc22 100644 --- a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java @@ -6,12 +6,16 @@ package org.opensearch.ml.breaker; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; @@ -26,6 +30,9 @@ public class MemoryCircuitBreakerTests { @Mock JvmStats.Mem mem; + @Mock + ClusterService clusterService; + @Before public void setup() { MockitoAnnotations.openMocks(this); @@ -60,4 +67,21 @@ public void testIsOpen_CustomThreshold_ExceedMemoryThreshold() { when(mem.getHeapUsedPercent()).thenReturn((short) 95); Assert.assertTrue(breaker.isOpen()); } + + @Test + public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() { + ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD); + when(clusterService.getClusterSettings()).thenReturn(settingsService); + + CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService); + + when(mem.getHeapUsedPercent()).thenReturn((short) 90); + Assert.assertTrue(breaker.isOpen()); + + Settings.Builder newSettingsBuilder = Settings.builder(); + newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 95); + settingsService.applySettings(newSettingsBuilder.build()); + Assert.assertFalse(breaker.isOpen()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java new file mode 100644 index 0000000000..fb42b9d073 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.ml.breaker.MemoryCircuitBreaker; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMLMemoryCircuitBreakerIT extends MLCommonsRestTestCase { + @After + public void tearDown() throws Exception { + super.tearDown(); + // restore the threshold to default value + Response response1 = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":" + + MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD + + "}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response1.getStatusLine().getStatusCode()); + } + + public void testRunWithMemoryCircuitBreaker() throws IOException { + // set a low threshold + Response response1 = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":1}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response1.getStatusLine().getStatusCode()); + + // expect task fail due to memory limit + Exception exception = assertThrows(ResponseException.class, () -> ingestModelData()); + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString("Memory Circuit Breaker is open, please check your resources!"), + containsString("m_l_limit_exceeded_exception") + ) + ); + + // set a higher threshold + Response response2 = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":100}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response2.getStatusLine().getStatusCode()); + + // expect task success + ingestModelData(); + } +}