diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index cb779abf31..8b7cf10ef9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -98,8 +98,8 @@ public void execute(Input input, ActionListener listener) { AgentMLInput agentMLInput = (AgentMLInput) input; String agentId = agentMLInput.getAgentId(); RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset(); - if (inputDataSet.getParameters() == null) { - throw new IllegalArgumentException("wrong input"); + if (inputDataSet == null || inputDataSet.getParameters() == null) { + throw new IllegalArgumentException("Agent input data can not be empty."); } List outputs = new ArrayList<>(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index f2e692f986..ce89464e37 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -183,6 +183,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws mlAgentExecutor.execute(input, agentActionListener); } + @Test(expected = IllegalArgumentException.class) + public void test_NonInputData_ThrowsException() { + AgentMLInput agentMLInput = new AgentMLInput("test", FunctionName.AGENT, null); + mlAgentExecutor.execute(agentMLInput, agentActionListener); + } + + @Test(expected = IllegalArgumentException.class) + public void test_NonInputParas_ThrowsException() { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(null).build(); + AgentMLInput agentMLInput = new AgentMLInput("test", FunctionName.AGENT, inputDataSet); + mlAgentExecutor.execute(agentMLInput, agentActionListener); + } + @Test public void test_HappyCase_ReturnsResult() { ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build();