Skip to content

Commit

Permalink
[Backport 2.11] exclude remote models in circuit breaker checks and f…
Browse files Browse the repository at this point in the history
…ix memory CB bugs (opensearch-project#2713)

* exclude remote models in circuit breaker checks and fix memory CB bugs

Signed-off-by: Xun Zhang <[email protected]>

* use static max heap threshold 100

Signed-off-by: Xun Zhang <[email protected]>

* fix issues after backport in 2.11

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Jul 25, 2024
1 parent 0c27efc commit 69b0ca2
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void test_validateIp_validIp_noException() throws UnknownHostException {
@Test
public void test_validateIp_invalidIp_throwException() throws UnknownHostException {
expectedException.expect(UnknownHostException.class);
MLHttpClientFactory.validateIp("www.zaniu.com");
MLHttpClientFactory.validateIp("www.zanniu.com");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.action.prediction;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand All @@ -14,9 +15,12 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
Expand Down Expand Up @@ -108,7 +112,27 @@ public void onResponse(MLModel mlModel) {
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
wrappedListener.onFailure(e);
if (e instanceof OpenSearchStatusException) {
wrappedListener
.onFailure(
new OpenSearchStatusException(
e.getMessage(),
RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus())
)
);
} else if (e instanceof MLResourceNotFoundException) {
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
} else if (e instanceof CircuitBreakingException) {
wrappedListener.onFailure(e);
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to Validate Access for ModelId " + modelId,
RestStatus.FORBIDDEN
)
);
}
}));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class MemoryCircuitBreaker extends ThresholdCircuitBreaker<Short> {
// TODO: make this value configurable as cluster setting
private static final String ML_MEMORY_CB = "Memory Circuit Breaker";
public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85;
public static final short JVM_HEAP_MAX_THRESHOLD = 100; // when threshold is 100, this CB check is ignored
private final JvmService jvmService;
private volatile Integer jvmHeapMemThreshold = 85;

public MemoryCircuitBreaker(JvmService jvmService) {
super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD);
Expand All @@ -34,8 +34,9 @@ public MemoryCircuitBreaker(short threshold, JvmService jvmService) {
public MemoryCircuitBreaker(Settings settings, ClusterService clusterService, JvmService jvmService) {
super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD);
this.jvmService = jvmService;
this.jvmHeapMemThreshold = ML_COMMONS_JVM_HEAP_MEM_THRESHOLD.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> jvmHeapMemThreshold = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> super.setThreshold(it.shortValue()));
}

@Override
Expand All @@ -45,6 +46,6 @@ public String getName() {

@Override
public boolean isOpen() {
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
return getThreshold() < JVM_HEAP_MAX_THRESHOLD && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

package org.opensearch.ml.breaker;

import lombok.Data;

/**
* An abstract class for all breakers with threshold.
* @param <T> data type of threshold
*/
@Data
public abstract class ThresholdCircuitBreaker<T> implements CircuitBreaker {

private T threshold;
Expand All @@ -17,10 +20,6 @@ public ThresholdCircuitBreaker(T threshold) {
this.threshold = threshold;
}

public T getThreshold() {
return threshold;
}

@Override
public abstract boolean isOpen();
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedDeque;
Expand Down Expand Up @@ -781,7 +782,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
* @param runningTaskLimit limit
*/
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
}
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void dispatchTask(
if (clusterService.localNode().getId().equals(node.getId())) {
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
request.setDispatchTask(false);
executeTask(request, listener);
checkCBAndExecute(functionName, request, listener);
} else {
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
request.setDispatchTask(false);
Expand Down
10 changes: 8 additions & 2 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
if (!request.isDispatchTask()) {
log.debug("Run ML request {} locally", request.getRequestID());
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
executeTask(request, listener);
checkCBAndExecute(functionName, request, listener);
return;
}
dispatchTask(functionName, request, transportService, listener);
Expand Down Expand Up @@ -129,4 +128,11 @@ public void dispatchTask(
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);

protected abstract void executeTask(Request request, ActionListener<Response> listener);

protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
if (functionName != FunctionName.REMOTE) {
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
}
executeTask(request, listener);
}
}
8 changes: 6 additions & 2 deletions plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;

Expand Down Expand Up @@ -60,7 +61,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
if (openCircuitBreaker != null) {
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
throw new CircuitBreakingException(
openCircuitBreaker.getName() + " is open, please check your resources!",
CircuitBreaker.Durability.TRANSIENT
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
package org.opensearch.ml.breaker;

import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.monitor.jvm.JvmService;
import org.opensearch.monitor.jvm.JvmStats;

Expand All @@ -26,6 +30,9 @@ public class MemoryCircuitBreakerTests {
@Mock
JvmStats.Mem mem;

@Mock
ClusterService clusterService;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
Expand Down Expand Up @@ -60,4 +67,39 @@ public void testIsOpen_CustomThreshold_ExceedMemoryThreshold() {
when(mem.getHeapUsedPercent()).thenReturn((short) 95);
Assert.assertTrue(breaker.isOpen());
}

@Test
public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
when(clusterService.getClusterSettings()).thenReturn(settingsService);

CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);

when(mem.getHeapUsedPercent()).thenReturn((short) 90);
Assert.assertTrue(breaker.isOpen());

Settings.Builder newSettingsBuilder = Settings.builder();
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 95);
settingsService.applySettings(newSettingsBuilder.build());
Assert.assertFalse(breaker.isOpen());
}

@Test
public void testIsOpen_DisableMemoryCB() {
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
when(clusterService.getClusterSettings()).thenReturn(settingsService);

CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);

when(mem.getHeapUsedPercent()).thenReturn((short) 90);
Assert.assertTrue(breaker.isOpen());

when(mem.getHeapUsedPercent()).thenReturn((short) 100);
Settings.Builder newSettingsBuilder = Settings.builder();
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
settingsService.applySettings(newSettingsBuilder.build());
Assert.assertFalse(breaker.isOpen());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
Expand All @@ -98,6 +100,7 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
Expand Down Expand Up @@ -318,7 +321,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
expectedEx.expect(MLException.class);
expectedEx.expect(CircuitBreakingException.class);
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
Expand Down Expand Up @@ -409,6 +412,55 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE
);
}

public void testRegisterMLRemoteModel() throws PrivilegedActionException {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
doAnswer(invocation -> {
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
indexResponseActionListener.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());
when(indexResponse.getId()).thenReturn("mockIndexId");
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
}

public void testRegisterMLRemoteModel_SkipMemoryCBOpen() {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
when(mlCircuitBreakerService.checkOpenCB())
.thenThrow(
new CircuitBreakingException(
"Memory Circuit Breaker is open, please check your resources!",
CircuitBreaker.Durability.TRANSIENT
)
);
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);

MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
doAnswer(invocation -> {
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
indexResponseActionListener.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());
when(indexResponse.getId()).thenReturn("mockIndexId");
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);

assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
}

@Ignore
public void testRegisterMLModel_DownloadModelFile() throws IOException {
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down Expand Up @@ -963,4 +1015,16 @@ private MLRegisterModelInput mockPretrainedInput() {
.functionName(FunctionName.SPARSE_ENCODING)
.build();
}

private MLRegisterModelInput mockRemoteModelInput(boolean isHidden) {
return MLRegisterModelInput
.builder()
.modelName(modelName)
.version(version)
.modelGroupId("modelGroupId")
.modelFormat(modelFormat)
.functionName(FunctionName.REMOTE)
.deployModel(true)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"davinci-002\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
Expand Down Expand Up @@ -250,6 +250,7 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
assertNotNull(responseMap);
}

@Ignore
public void testOpenAIEditsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down
Loading

0 comments on commit 69b0ca2

Please sign in to comment.