Skip to content

Commit

Permalink
Move visualization tool to ml-commons
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Apr 26, 2024
1 parent ea7fefa commit e25707d
Show file tree
Hide file tree
Showing 12 changed files with 889 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
@ToolAnnotation(VisualizationsTool.TYPE)
public class VisualizationsTool implements Tool {
public static final String NAME = "FindVisualizations";
public static final String TYPE = "VisualizationTool";
public static final String VERSION = "v1.0";

public static final String SAVED_OBJECT_TYPE = "visualization";

/**
* default number of visualizations returned
*/
private static final int DEFAULT_SIZE = 3;
private static final String DEFAULT_DESCRIPTION =
"Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations";
@Setter
@Getter
private String description = DEFAULT_DESCRIPTION;

@Getter
@Setter
private String name = NAME;
@Getter
@Setter
private String type = TYPE;
@Getter
private final String version = VERSION;
private final Client client;
@Getter
private final String index;
@Getter
private final int size;

@Builder
public VisualizationsTool(Client client, String index, int size) {
this.client = client;
this.index = index;
this.size = size;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE));
boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input")));

SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder);
searchSourceBuilder.from(0).size(size);
SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder);

client.search(searchRequest, new ActionListener<>() {
@Override
public void onResponse(SearchResponse searchResponse) {
SearchHits hits = searchResponse.getHits();
StringBuilder visBuilder = new StringBuilder();
visBuilder.append("Title,Id\n");
if (hits.getTotalHits().value > 0) {
Arrays.stream(hits.getHits()).forEach(h -> {
String id = trimIdPrefix(h.getId());
Map<String, String> visMap = (Map<String, String>) h.getSourceAsMap().get(SAVED_OBJECT_TYPE);
String title = visMap.get("title");
visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id));
});

listener.onResponse((T) visBuilder.toString());
} else {
listener.onResponse((T) "No Visualization found");
}
}

@Override
public void onFailure(Exception e) {
if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) {
listener.onResponse((T) "No Visualization found");
} else {
listener.onFailure(e);
}
}
});
}

String trimIdPrefix(String id) {
id = Optional.ofNullable(id).orElse("");
if (id.startsWith(SAVED_OBJECT_TYPE)) {
String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE);
return id.substring(prefix.length());
}
return id;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input"));
}

public static class Factory implements Tool.Factory<VisualizationsTool> {
private Client client;

private static Factory INSTANCE;

public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (VisualizationsTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client) {
this.client = client;
}

@Override
public VisualizationsTool create(Map<String, Object> params) {
String index = params.get("index") == null ? ".kibana" : (String) params.get("index");
String sizeStr = params.get("size") == null ? "3" : (String) params.get("size");
int size;
try {
size = Integer.parseInt(sizeStr);
} catch (NumberFormatException ignored) {
size = DEFAULT_SIZE;
}
return VisualizationsTool.builder().client(client).index(index).size(size).build();
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.spi.tools.Tool;

public class VisualizationsToolTests {
@Mock
private Client client;

private String searchResponse = "{}";
private String searchResponseNotFound = "{}";

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
VisualizationsTool.Factory.getInstance().init(client);
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) {
if (searchResponseIns != null) {
searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
}
}
try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) {
if (searchResponseIns != null) {
searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8);
}
}
}

@Test
public void testToolIndexName() {
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool1.getIndex(), ".kibana");

VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index"));
assertEquals(tool2.getIndex(), "test-index");
}

@Test
public void testNumberOfVisualizationReturned() {
VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool1.getSize(), 3);

VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1"));
assertEquals(tool2.getSize(), 1);

VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString"));
assertEquals(tool3.getSize(), 3);
}

@Test
public void testTrimPrefix() {
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(tool.trimIdPrefix(null), "");
assertEquals(tool.trimIdPrefix("abc"), "abc");
assertEquals(tool.trimIdPrefix("visualization:abc"), "abc");
}

@Test
public void testParameterValidation() {
VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
Assert.assertFalse(tool.validate(Collections.emptyMap()));
Assert.assertFalse(tool.validate(Map.of("input", "")));
Assert.assertTrue(tool.validate(Map.of("input", "question")));
}

@Test
public void testRunToolWithVisualizationFound() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

SearchResponse response = SearchResponse
.fromXContent(
JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse)
);
searchResponseListener.getValue().onResponse(response);

future.join();
assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get());
}

@Test
public void testRunToolWithNoVisualizationFound() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

SearchResponse response = SearchResponse
.fromXContent(
JsonXContent.jsonXContent
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound)
);
searchResponseListener.getValue().onResponse(response);

future.join();
assertEquals("No Visualization found", future.get());
}

@Test
public void testRunToolWithIndexNotExists() throws Exception {
Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap());
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

ArgumentCaptor<ActionListener<SearchResponse>> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class);
Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture());

Map<String, String> params = Map.of("input", "Sales by gender");

tool.run(params, listener);

IndexNotFoundException notFoundException = new IndexNotFoundException("test-index");
searchResponseListener.getValue().onFailure(notFoundException);

future.join();
assertEquals("No Visualization found", future.get());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"took": 4,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 1,
"relation": "eq"
},
"max_score": 0.2847877,
"hits": [
{
"_index": ".kibana_1",
"_id": "visualization:aeb212e0-4c84-11e8-b3d7-01146121b73d",
"_score": 0.2847877,
"_source": {
"visualization": {
"title": "[Ecommerce]Sales by gender",
"visState": "",
"uiStateJSON": "{}",
"description": "",
"version": 1,
"kibanaSavedObjectMeta": {
"searchSourceJSON": "{}"
}
},
"type": "visualization",
"references": [
{
"name": "control_0_index_pattern",
"type": "index-pattern",
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d"
},
{
"name": "control_1_index_pattern",
"type": "index-pattern",
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d"
},
{
"name": "control_2_index_pattern",
"type": "index-pattern",
"id": "d3d7af60-4c81-11e8-b3d7-01146121b73d"
}
],
"migrationVersion": {
"visualization": "7.10.0"
},
"updated_at": "2023-11-10T02:50:24.881Z"
}
}
]
}
}
Loading

0 comments on commit e25707d

Please sign in to comment.