forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move visualization tool to ml-commons
Signed-off-by: Hailong Cui <[email protected]>
- Loading branch information
1 parent
ea7fefa
commit e25707d
Showing
12 changed files
with
889 additions
and
1 deletion.
There are no files selected for viewing
178 changes: 178 additions & 0 deletions
178
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import java.util.Arrays; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.client.Requests; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.common.Strings; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.index.query.BoolQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@ToolAnnotation(VisualizationsTool.TYPE) | ||
public class VisualizationsTool implements Tool { | ||
public static final String NAME = "FindVisualizations"; | ||
public static final String TYPE = "VisualizationTool"; | ||
public static final String VERSION = "v1.0"; | ||
|
||
public static final String SAVED_OBJECT_TYPE = "visualization"; | ||
|
||
/** | ||
* default number of visualizations returned | ||
*/ | ||
private static final int DEFAULT_SIZE = 3; | ||
private static final String DEFAULT_DESCRIPTION = | ||
"Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations"; | ||
@Setter | ||
@Getter | ||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
@Getter | ||
@Setter | ||
private String name = NAME; | ||
@Getter | ||
@Setter | ||
private String type = TYPE; | ||
@Getter | ||
private final String version = VERSION; | ||
private final Client client; | ||
@Getter | ||
private final String index; | ||
@Getter | ||
private final int size; | ||
|
||
@Builder | ||
public VisualizationsTool(Client client, String index, int size) { | ||
this.client = client; | ||
this.index = index; | ||
this.size = size; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); | ||
boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE)); | ||
boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input"))); | ||
|
||
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder); | ||
searchSourceBuilder.from(0).size(size); | ||
SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder); | ||
|
||
client.search(searchRequest, new ActionListener<>() { | ||
@Override | ||
public void onResponse(SearchResponse searchResponse) { | ||
SearchHits hits = searchResponse.getHits(); | ||
StringBuilder visBuilder = new StringBuilder(); | ||
visBuilder.append("Title,Id\n"); | ||
if (hits.getTotalHits().value > 0) { | ||
Arrays.stream(hits.getHits()).forEach(h -> { | ||
String id = trimIdPrefix(h.getId()); | ||
Map<String, String> visMap = (Map<String, String>) h.getSourceAsMap().get(SAVED_OBJECT_TYPE); | ||
String title = visMap.get("title"); | ||
visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id)); | ||
}); | ||
|
||
listener.onResponse((T) visBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) "No Visualization found"); | ||
} | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { | ||
listener.onResponse((T) "No Visualization found"); | ||
} else { | ||
listener.onFailure(e); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
String trimIdPrefix(String id) { | ||
id = Optional.ofNullable(id).orElse(""); | ||
if (id.startsWith(SAVED_OBJECT_TYPE)) { | ||
String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE); | ||
return id.substring(prefix.length()); | ||
} | ||
return id; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); | ||
} | ||
|
||
public static class Factory implements Tool.Factory<VisualizationsTool> { | ||
private Client client; | ||
|
||
private static Factory INSTANCE; | ||
|
||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (VisualizationsTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(Client client) { | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
public VisualizationsTool create(Map<String, Object> params) { | ||
String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); | ||
String sizeStr = params.get("size") == null ? "3" : (String) params.get("size"); | ||
int size; | ||
try { | ||
size = Integer.parseInt(sizeStr); | ||
} catch (NumberFormatException ignored) { | ||
size = DEFAULT_SIZE; | ||
} | ||
return VisualizationsTool.builder().client(client).index(index).size(size).build(); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
|
||
@Override | ||
public String getDefaultType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getDefaultVersion() { | ||
return null; | ||
} | ||
} | ||
} |
161 changes: 161 additions & 0 deletions
161
ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/VisualizationsToolTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.Collections; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.ArgumentMatchers; | ||
import org.mockito.Mock; | ||
import org.mockito.Mockito; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.json.JsonXContent; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.DeprecationHandler; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
|
||
public class VisualizationsToolTests { | ||
@Mock | ||
private Client client; | ||
|
||
private String searchResponse = "{}"; | ||
private String searchResponseNotFound = "{}"; | ||
|
||
@Before | ||
public void setup() throws IOException { | ||
MockitoAnnotations.openMocks(this); | ||
VisualizationsTool.Factory.getInstance().init(client); | ||
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) { | ||
if (searchResponseIns != null) { | ||
searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); | ||
} | ||
} | ||
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) { | ||
if (searchResponseIns != null) { | ||
searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
public void testToolIndexName() { | ||
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool1.getIndex(), ".kibana"); | ||
|
||
VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index")); | ||
assertEquals(tool2.getIndex(), "test-index"); | ||
} | ||
|
||
@Test | ||
public void testNumberOfVisualizationReturned() { | ||
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool1.getSize(), 3); | ||
|
||
VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1")); | ||
assertEquals(tool2.getSize(), 1); | ||
|
||
VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString")); | ||
assertEquals(tool3.getSize(), 3); | ||
} | ||
|
||
@Test | ||
public void testTrimPrefix() { | ||
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(tool.trimIdPrefix(null), ""); | ||
assertEquals(tool.trimIdPrefix("abc"), "abc"); | ||
assertEquals(tool.trimIdPrefix("visualization:abc"), "abc"); | ||
} | ||
|
||
@Test | ||
public void testParameterValidation() { | ||
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
Assert.assertFalse(tool.validate(Collections.emptyMap())); | ||
Assert.assertFalse(tool.validate(Map.of("input", ""))); | ||
Assert.assertTrue(tool.validate(Map.of("input", "question"))); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithVisualizationFound() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
SearchResponse response = SearchResponse | ||
.fromXContent( | ||
JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse) | ||
); | ||
searchResponseListener.getValue().onResponse(response); | ||
|
||
future.join(); | ||
assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get()); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithNoVisualizationFound() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
SearchResponse response = SearchResponse | ||
.fromXContent( | ||
JsonXContent.jsonXContent | ||
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound) | ||
); | ||
searchResponseListener.getValue().onResponse(response); | ||
|
||
future.join(); | ||
assertEquals("No Visualization found", future.get()); | ||
} | ||
|
||
@Test | ||
public void testRunToolWithIndexNotExists() throws Exception { | ||
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally); | ||
|
||
ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); | ||
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); | ||
|
||
Map<String, String> params = Map.of("input", "Sales by gender"); | ||
|
||
tool.run(params, listener); | ||
|
||
IndexNotFoundException notFoundException = new IndexNotFoundException("test-index"); | ||
searchResponseListener.getValue().onFailure(notFoundException); | ||
|
||
future.join(); | ||
assertEquals("No Visualization found", future.get()); | ||
} | ||
} |
58 changes: 58 additions & 0 deletions
58
ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/visualization.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
{ | ||
"took": 4, | ||
"timed_out": false, | ||
"_shards": { | ||
"total": 1, | ||
"successful": 1, | ||
"skipped": 0, | ||
"failed": 0 | ||
}, | ||
"hits": { | ||
"total": { | ||
"value": 1, | ||
"relation": "eq" | ||
}, | ||
"max_score": 0.2847877, | ||
"hits": [ | ||
{ | ||
"_index": ".kibana_1", | ||
"_id": "visualization:aeb212e0-4c84-11e8-b3d7-01146121b73d", | ||
"_score": 0.2847877, | ||
"_source": { | ||
"visualization": { | ||
"title": "[Ecommerce]Sales by gender", | ||
"visState": "", | ||
"uiStateJSON": "{}", | ||
"description": "", | ||
"version": 1, | ||
"kibanaSavedObjectMeta": { | ||
"searchSourceJSON": "{}" | ||
} | ||
}, | ||
"type": "visualization", | ||
"references": [ | ||
{ | ||
"name": "control_0_index_pattern", | ||
"type": "index-pattern", | ||
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" | ||
}, | ||
{ | ||
"name": "control_1_index_pattern", | ||
"type": "index-pattern", | ||
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" | ||
}, | ||
{ | ||
"name": "control_2_index_pattern", | ||
"type": "index-pattern", | ||
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" | ||
} | ||
], | ||
"migrationVersion": { | ||
"visualization": "7.10.0" | ||
}, | ||
"updated_at": "2023-11-10T02:50:24.881Z" | ||
} | ||
} | ||
] | ||
} | ||
} |
Oops, something went wrong.