diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java index 25a9cdae..2d54d649 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -74,7 +74,7 @@ public PlainActionFuture execute( ); String agentId = (String) inputs.get(AGENT_ID); - mlClient.deleteAgent(agentId, new ActionListener<>() { + mlClient.deleteAgent(agentId, tenantId, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { deleteAgentFuture.onResponse( diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 2d331d7e..1303c67b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -86,7 +86,7 @@ public PlainActionFuture execute( String modelId = (String) inputs.get(MODEL_ID); - mlClient.deploy(modelId, new ActionListener<>() { + mlClient.deploy(modelId, tenantId, new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 0f5c2b50..00bf856d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -182,7 +182,8 @@ public void onFailure(Exception ex) { .parameters(parametersMap) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) - .appType(appType); + .appType(appType) + .tenantId(tenantId); MLAgent mlAgent = builder.build(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java index ab2e0a36..a4764576 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -78,7 +78,7 @@ public PlainActionFuture execute( String modelId = inputs.get(MODEL_ID).toString(); - mlClient.undeploy(new String[] { modelId }, null, new ActionListener<>() { + mlClient.undeploy(new String[] { modelId }, null, tenantId, new ActionListener<>() { @Override public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { List failures = mlUndeployModelsResponse.getResponse().failures(); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java index c2dbd04c..31f63d8d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -54,12 +55,12 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte doAnswer(invocation -> { String agentIdArg = invocation.getArgument(0); - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); DeleteResponse output = new DeleteResponse(shardId, agentIdArg, 1, 1, 1, true); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + }).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), @@ -69,7 +70,7 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte Collections.emptyMap(), null ); - verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); assertTrue(future.isDone()); assertEquals(agentId, future.get().getContent().get(AGENT_ID)); @@ -81,10 +82,10 @@ public void testDeleteAgentNotFound() throws IOException, ExecutionException, In DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new OpenSearchStatusException("No agent found with that id", RestStatus.NOT_FOUND)); return null; - }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + }).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), @@ -94,7 +95,7 @@ public void testDeleteAgentNotFound() throws IOException, ExecutionException, In Collections.emptyMap(), null ); - verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); assertTrue(future.isDone()); assertEquals(agentId, future.get().getContent().get(AGENT_ID)); @@ -122,10 +123,10 @@ public void testDeleteAgentFailure() throws IOException { DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new FlowFrameworkException("Failed to delete agent", RestStatus.INTERNAL_SERVER_ERROR)); return null; - }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + }).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), @@ -136,7 +137,7 @@ public void testDeleteAgentFailure() throws IOException { null ); - verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any()); assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 822d95ce..4cb3f8fa 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -47,6 +47,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -118,11 +119,11 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).deploy(eq(modelId), nullable(String.class), actionListenerCaptor.capture()); // Stub getTask for success case doAnswer(invocation -> { @@ -150,7 +151,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I future.actionGet(); - verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any()); + verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); @@ -162,10 +163,10 @@ public void testDeployModelFailure() { ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR)); return null; - }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).deploy(eq("modelId"), nullable(String.class), actionListenerCaptor.capture()); PlainActionFuture future = deployModel.execute( inputData.getNodeId(), @@ -176,7 +177,7 @@ public void testDeployModelFailure() { null ); - verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).deploy(eq("modelId"), nullable(String.class), actionListenerCaptor.capture()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); @@ -194,11 +195,11 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).deploy(eq(modelId), nullable(String.class), actionListenerCaptor.capture()); // Stub getTask for success case doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java index 203a9bbb..d8d96c48 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java @@ -32,6 +32,7 @@ import static org.opensearch.flowframework.common.CommonValue.SUCCESS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -57,7 +58,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup doAnswer(invocation -> { ClusterName clusterName = new ClusterName("clusterName"); - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(3); MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( clusterName, Collections.emptyList(), @@ -66,7 +67,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any()); PlainActionFuture future = UndeployModelStep.execute( inputData.getNodeId(), @@ -76,7 +77,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup Collections.emptyMap(), null ); - verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any()); assertTrue(future.isDone()); assertTrue((boolean) future.get().getContent().get(SUCCESS)); @@ -105,7 +106,7 @@ public void testUndeployModelFailure() throws IOException { doAnswer(invocation -> { ClusterName clusterName = new ClusterName("clusterName"); - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(3); MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( clusterName, Collections.emptyList(), @@ -116,7 +117,7 @@ public void testUndeployModelFailure() throws IOException { actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR)); return null; - }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any()); PlainActionFuture future = UndeployModelStep.execute( inputData.getNodeId(), @@ -127,7 +128,7 @@ public void testUndeployModelFailure() throws IOException { null ); - verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any()); assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());