From 5d12e5411a69be6c236d8fa1d4790716feb15456 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Wed, 10 Jan 2024 12:52:11 -0800 Subject: [PATCH] fix race confition in index initialization and RestUpdateConnector UT Signed-off-by: Xun Zhang --- .../ml/engine/indices/MLIndicesHandler.java | 13 ++++++++++--- .../engine/indices/MLIndicesHandlerTest.java | 18 ++++++++++++++++++ .../rest/RestMLUpdateConnectorActionTests.java | 16 ++++++++-------- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 15b6a926d6..46135e5496 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -13,6 +13,8 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.OpenSearchWrapperException; +import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; @@ -85,7 +87,6 @@ public void initMLAgentIndex(ActionListener listener) { public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); if (!clusterService.state().metadata().hasIndex(indexName)) { @@ -97,8 +98,14 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) internalListener.onResponse(false); } }, e -> { - log.error("Failed to create index " + indexName, e); - internalListener.onFailure(e); + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + log.info("Skip creating the Index:{} that is already created by another parallel request", indexName); + internalListener.onResponse(true); + } else { + log.error("Failed to create index " + indexName, e); + internalListener.onFailure(e); + } }); CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(INDEX_SETTINGS); client.admin().indices().create(request, actionListener); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java index 5ca7e2d31a..be2a5669ca 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java @@ -22,6 +22,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -191,4 +192,21 @@ public void initMLAgentIndexNoIndex() { verify(listener).onResponse(argumentCaptor.capture()); assertEquals(true, argumentCaptor.getValue()); } + + @Test + public void initMLConnectorIndex_ResourceAlreadyExistsException_RaceCondition() { + ActionListener listener = mock(ActionListener.class); + when(metadata.hasIndex(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new ResourceAlreadyExistsException("index [.plugins-ml-connector] already exists")); + return null; + }).when(indicesAdminClient).create(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + indicesHandler.initMLConnectorIndex(listener); + + verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any()); + verify(listener).onResponse(argumentCaptor.capture()); + assertEquals(true, argumentCaptor.getValue()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index c3a21bde1f..81b52818b4 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -142,14 +142,14 @@ public void testPrepareRequestFeatureDisabled() throws Exception { } private RestRequest getRestRequest() { - RestRequest.Method method = RestRequest.Method.POST; + RestRequest.Method method = RestRequest.Method.PUT; final Map updateContent = Map.of("version", "2", "description", "This is test description"); String requestContent = new Gson().toJson(updateContent).toString(); Map params = new HashMap<>(); params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withPath("/_plugins/_ml/connectors/{connector_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -157,13 +157,13 @@ private RestRequest getRestRequest() { } private RestRequest getRestRequestWithNullValue() { - RestRequest.Method method = RestRequest.Method.POST; + RestRequest.Method method = RestRequest.Method.PUT; String requestContent = "{\"version\":\"2\",\"description\":null}"; Map params = new HashMap<>(); params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withPath("/_plugins/_ml/connectors/{connector_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -171,12 +171,12 @@ private RestRequest getRestRequestWithNullValue() { } private RestRequest getRestRequestWithEmptyContent() { - RestRequest.Method method = RestRequest.Method.POST; + RestRequest.Method method = RestRequest.Method.PUT; Map params = new HashMap<>(); params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withPath("/_plugins/_ml/connectors/{connector_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -184,13 +184,13 @@ private RestRequest getRestRequestWithEmptyContent() { } private RestRequest getRestRequestWithNullConnectorId() { - RestRequest.Method method = RestRequest.Method.POST; + RestRequest.Method method = RestRequest.Method.PUT; final Map updateContent = Map.of("version", "2", "description", "This is test description"); String requestContent = new Gson().toJson(updateContent).toString(); Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withPath("/_plugins/_ml/connectors/{connector_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build();