forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check before delete (opensearch-project#3209) (opensearch-project#3431)
* add logic to detect agent before deleting * add logic to detect agent before deleting * add logic to detect pipelines before delete model * check pipeline before deleting * apply spotless * remove useless file * rename functions * fix failure test * add UT * apply spotless * renam * refactor to parallel check * concate error message * move logic after user access check * change agent model searcher map to set * rename and remove useless method * fix bug to fetch all pipelines * apply spotless * apply spotless * remove and add comment * rename and add more UTs * use correct key * simplify function * change to a better class * apply spotless * change compareAndSet to set * apply comment * change name and reformat logic * change name * remove useless line * change to a better method * change name * apply spotless * add java doc for function * add another interface * apply java spotless * change interface to with model * apply spot less * add settings * apply spot less * add test for cluster setting * apply spotless * recover useless change * change default value of cluster setting * rename setting and add comment * apply spot * remove logic for hidden model * reorder code * reorder code * reorder code * apply spot * add UT * add more UT * remove search for hidden agent * fix logic and apply spot * add exist for UT * change dsl to query index * change query logic * remove useless ut * rebert * apply spot * rechange code * apply spot * remove useless should * apply spot * fix final dsl logic and ut --------- (cherry picked from commit 570edaf) Signed-off-by: xinyual <[email protected]>
- Loading branch information
Showing
10 changed files
with
825 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package org.opensearch.ml.engine.utils; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX; | ||
|
||
import java.util.HashSet; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.index.query.BoolQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.WithModelTool; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
public class AgentModelsSearcher { | ||
private final Set<String> relatedModelIdSet; | ||
|
||
public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) { | ||
relatedModelIdSet = new HashSet<>(); | ||
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) { | ||
Tool.Factory toolFactory = entry.getValue(); | ||
if (toolFactory instanceof WithModelTool.Factory) { | ||
WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory; | ||
relatedModelIdSet.addAll(withModelTool.getAllModelKeys()); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Construct a should query to search all agent which containing candidate model Id | ||
@param candidateModelId the candidate model Id | ||
@return a should search request towards agent index. | ||
*/ | ||
public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) { | ||
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX); | ||
// Two conditions here | ||
// 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and | ||
// 2. Any model field contains candidate ID | ||
BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery(); | ||
|
||
BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery(); | ||
// not exist hidden | ||
hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD))); | ||
// exist but equal to false | ||
BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery(); | ||
existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false)); | ||
existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)); | ||
hiddenFieldQuery.should(existHiddenFieldQuery); | ||
|
||
// | ||
BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery(); | ||
for (String keyField : relatedModelIdSet) { | ||
modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId)); | ||
} | ||
|
||
searchAgentQuery.must(hiddenFieldQuery); | ||
searchAgentQuery.must(modelIdQuery); | ||
searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery)); | ||
return searchRequest; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/AgentModelSearcherTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.utils; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertTrue; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.when; | ||
|
||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import org.junit.Test; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.index.query.BoolQueryBuilder; | ||
import org.opensearch.index.query.ExistsQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.TermsQueryBuilder; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.WithModelTool; | ||
|
||
public class AgentModelSearcherTests { | ||
|
||
@Test | ||
public void testConstructor_CollectsModelIds() { | ||
// Arrange | ||
WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class); | ||
when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2")); | ||
|
||
WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class); | ||
when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey")); | ||
|
||
// This tool factory does not implement WithModelTool.Factory | ||
Tool.Factory regularToolFactory = mock(Tool.Factory.class); | ||
|
||
Map<String, Tool.Factory> toolFactories = new HashMap<>(); | ||
toolFactories.put("withModelTool1", withModelToolFactory1); | ||
toolFactories.put("withModelTool2", withModelToolFactory2); | ||
toolFactories.put("regularTool", regularToolFactory); | ||
|
||
// Act | ||
AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories); | ||
|
||
// (Optional) We can't directly access relatedModelIdSet, | ||
// but we can test the behavior indirectly using the search call: | ||
SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId"); | ||
|
||
// Assert | ||
// Verify the searchRequest uses all keys from the WithModelTool factories | ||
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query(); | ||
// We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses | ||
assertEquals(2, boolQueryBuilder.must().size()); | ||
for (QueryBuilder query : boolQueryBuilder.must()) { | ||
BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query; | ||
assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3); | ||
if (subBoolQueryBuilder.should().size() == 3) { | ||
boolQueryBuilder.should().forEach(subQuery -> { | ||
assertTrue(subQuery instanceof TermsQueryBuilder); | ||
TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery; | ||
// Each TermsQueryBuilder should contain candidateModelId | ||
assertTrue(termsQuery.values().contains("candidateId")); | ||
}); | ||
} else { | ||
boolQueryBuilder.should().forEach(subQuery -> { | ||
assertTrue(subQuery instanceof BoolQueryBuilder); | ||
BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery; | ||
assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1); | ||
if (boolQuery.must().size() == 2) { | ||
boolQuery.must().forEach(existSubQuery -> { | ||
assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder); | ||
if (existSubQuery instanceof TermsQueryBuilder) { | ||
TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery; | ||
assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD)); | ||
assertTrue(termsQuery.values().contains(false)); | ||
} else { | ||
ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery; | ||
assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD)); | ||
} | ||
}); | ||
} else { | ||
QueryBuilder mustNotQuery = boolQuery.mustNot().get(0); | ||
assertTrue(mustNotQuery instanceof ExistsQueryBuilder); | ||
assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName()); | ||
} | ||
}); | ||
} | ||
} | ||
|
||
} | ||
} |
Oops, something went wrong.