Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* add logic to detect agent before deleting



* add logic to detect agent before deleting



* add logic to detect pipelines before delete model



* check pipeline before deleting



* apply spotless



* remove useless file



* rename functions



* fix failure test



* add UT



* apply spotless



* renam



* refactor to parallel check



* concate error message



* move logic after user access check



* change agent model searcher map to set



* rename and remove useless method



* fix bug to fetch all pipelines



* apply spotless



* apply spotless



* remove and add comment



* rename and add more UTs



* use correct key



* simplify function



* change to a better class



* apply spotless



* change compareAndSet to set



* apply comment



* change name and reformat logic



* change name



* remove useless line



* change to a better method



* change name



* apply spotless



* add java doc for function



* add another interface



* apply java spotless



* change interface to with model



* apply spot less



* add settings



* apply spot less



* add test for cluster setting



* apply spotless



* recover useless change



* change default value of cluster setting



* rename setting and add comment



* apply spot



* remove logic for hidden model



* reorder code



* reorder code



* reorder code



* apply spot



* add UT



* add more UT



* remove search for hidden agent



* fix logic and apply spot



* add exist for UT



* change dsl to query index



* change query logic



* remove useless ut



* rebert



* apply spot



* rechange code



* apply spot



* remove useless should



* apply spot



* fix final dsl logic and ut



---------


(cherry picked from commit 570edaf)

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Jan 24, 2025
1 parent ed36d6e commit 880b674
Show file tree
Hide file tree
Showing 10 changed files with 825 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class CommonValue {
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

// Index mapping paths
public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml-model-group.json";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -33,7 +33,7 @@
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool implements WithModelTool {
public static final String TYPE = "MLModelTool";
public static final String RESPONSE_FIELD = "response_field";
public static final String MODEL_ID_FIELD = "model_id";
Expand Down Expand Up @@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
return true;
}

public static class Factory implements Tool.Factory<MLModelTool> {
public static class Factory implements WithModelTool.Factory<MLModelTool> {
private Client client;

private static Factory INSTANCE;
Expand Down Expand Up @@ -172,5 +172,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(MODEL_ID_FIELD);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.ml.engine.utils;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.search.builder.SearchSourceBuilder;

public class AgentModelsSearcher {
private final Set<String> relatedModelIdSet;

public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
relatedModelIdSet = new HashSet<>();
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
Tool.Factory toolFactory = entry.getValue();
if (toolFactory instanceof WithModelTool.Factory) {
WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory;
relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
}
}
}

/**
* Construct a should query to search all agent which containing candidate model Id
@param candidateModelId the candidate model Id
@return a should search request towards agent index.
*/
public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) {
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
// Two conditions here
// 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and
// 2. Any model field contains candidate ID
BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery();

BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery();
// not exist hidden
hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
// exist but equal to false
BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery();
existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false));
existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD));
hiddenFieldQuery.should(existHiddenFieldQuery);

//
BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery();
for (String keyField : relatedModelIdSet) {
modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
}

searchAgentQuery.must(hiddenFieldQuery);
searchAgentQuery.must(modelIdQuery);
searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery));
return searchRequest;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;
import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -218,5 +219,6 @@ public void testTool() {
assertTrue(tool.validate(otherParams));
assertFalse(tool.validate(emptyParams));
assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;

public class AgentModelSearcherTests {

@Test
public void testConstructor_CollectsModelIds() {
// Arrange
WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class);
when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2"));

WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class);
when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey"));

// This tool factory does not implement WithModelTool.Factory
Tool.Factory regularToolFactory = mock(Tool.Factory.class);

Map<String, Tool.Factory> toolFactories = new HashMap<>();
toolFactories.put("withModelTool1", withModelToolFactory1);
toolFactories.put("withModelTool2", withModelToolFactory2);
toolFactories.put("regularTool", regularToolFactory);

// Act
AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories);

// (Optional) We can't directly access relatedModelIdSet,
// but we can test the behavior indirectly using the search call:
SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId");

// Assert
// Verify the searchRequest uses all keys from the WithModelTool factories
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query();
// We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses
assertEquals(2, boolQueryBuilder.must().size());
for (QueryBuilder query : boolQueryBuilder.must()) {
BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query;
assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3);
if (subBoolQueryBuilder.should().size() == 3) {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof TermsQueryBuilder);
TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery;
// Each TermsQueryBuilder should contain candidateModelId
assertTrue(termsQuery.values().contains("candidateId"));
});
} else {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof BoolQueryBuilder);
BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery;
assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1);
if (boolQuery.must().size() == 2) {
boolQuery.must().forEach(existSubQuery -> {
assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder);
if (existSubQuery instanceof TermsQueryBuilder) {
TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery;
assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
assertTrue(termsQuery.values().contains(false));
} else {
ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery;
assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
}
});
} else {
QueryBuilder mustNotQuery = boolQuery.mustNot().get(0);
assertTrue(mustNotQuery instanceof ExistsQueryBuilder);
assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName());
}
});
}
}

}
}
Loading

0 comments on commit 880b674

Please sign in to comment.