Skip to content

Commit

Permalink
Add SPI subproject and interfaces for Tools
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Oct 31, 2023
1 parent 96c8d4d commit 04f5a20
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ ml-algorithms/build/
plugin/build/
.DS_Store
*/bin/
.classpath
.project
.settings
2 changes: 2 additions & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ rootProject.name = 'opensearch-ml'

include 'client'
project(":client").name = rootProject.name + "-client"
include 'spi'
project(":spi").name = rootProject.name + "-spi"
include 'common'
project(":common").name = rootProject.name + "-common"
include 'plugin'
Expand Down
143 changes: 143 additions & 0 deletions spi/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import com.github.jengelman.gradle.plugins.shadow.ShadowBasePlugin
import org.opensearch.gradle.test.RestIntegTestTask

plugins {
id 'com.github.johnrengelman.shadow'
id 'jacoco'
id 'maven-publish'
id 'signing'
}

apply plugin: 'opensearch.java'
apply plugin: 'opensearch.testclusters'
apply plugin: 'opensearch.java-rest-test'

repositories {
mavenLocal()
mavenCentral()
maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" }
}

ext {
projectSubstitutions = [:]
licenseFile = rootProject.file('LICENSE.txt')
noticeFile = rootProject.file('NOTICE')
}

jacoco {
toolVersion = '0.8.7'
reportsDirectory = file("$buildDir/JacocoReport")
}

jacocoTestReport {
reports {
xml.required = false
csv.required = false
html.destination file("${buildDir}/jacoco/")
}
}
check.dependsOn jacocoTestReport

def slf4j_version_of_cronutils = "1.7.36"
dependencies {
compileOnly "org.opensearch:opensearch:${opensearch_version}"

testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}"
}

configurations.all {
if (it.state != Configuration.State.UNRESOLVED) return
resolutionStrategy {
force "org.slf4j:slf4j-api:${slf4j_version_of_cronutils}"
}
}

shadowJar {
archiveClassifier = null
}

test {
doFirst {
test.classpath -= project.files(project.tasks.named('shadowJar'))
test.classpath -= project.configurations.getByName(ShadowBasePlugin.CONFIGURATION_NAME)
test.classpath += project.extensions.getByType(SourceSetContainer).getByName(SourceSet.MAIN_SOURCE_SET_NAME).runtimeClasspath
}
systemProperty 'tests.security.manager', 'false'
}

task integTest(type: RestIntegTestTask) {
description 'Run integ test with opensearch test framework'
group 'verification'
systemProperty 'tests.security.manager', 'false'
dependsOn test
}
check.dependsOn integTest

testClusters.javaRestTest {
testDistribution = 'INTEG_TEST'
}

task sourcesJar(type: Jar) {
archiveClassifier.set 'sources'
from sourceSets.main.allJava
}

task javadocJar(type: Jar) {
archiveClassifier.set 'javadoc'
from javadoc.destinationDir
dependsOn javadoc
}

publishing {
repositories {
maven {
name = 'staging'
url = "${rootProject.buildDir}/local-staging-repo"
}
maven {
name = "Snapshots" // optional target repository name
url = "https://aws.oss.sonatype.org/content/repositories/snapshots"
credentials {
username "$System.env.SONATYPE_USERNAME"
password "$System.env.SONATYPE_PASSWORD"
}
}
}
publications {
shadow(MavenPublication) { publication ->
project.shadow.component(publication)
artifact sourcesJar
artifact javadocJar

pom {
name = "OpenSearch ML Commons SPI"
packaging = "jar"
url = "https://github.com/opensearch-project/ml-commons"
description = "OpenSearch ML spi"
scm {
connection = "scm:[email protected]:opensearch-project/ml-commons.git"
developerConnection = "scm:[email protected]:opensearch-project/ml-commons.git"
url = "[email protected]:opensearch-project/ml-commons.git"
}
licenses {
license {
name = "The Apache License, Version 2.0"
url = "http://www.apache.org/licenses/LICENSE-2.0.txt"
}
}
developers {
developer {
name = "OpenSearch"
url = "https://github.com/opensearch-project/ml-commons"
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.spi;

import org.opensearch.ml.common.spi.tools.Tool;

import java.util.List;

/**
* ml-commons extension interface.
*/
public interface MLCommonsExtension {

/**
* Get tools.
* @return
*/
List<Tool> getTools();

/**
* Get tool factories.
* @return
*/
List<Tool.Factory<? extends Tool>> getToolFactories();
}
21 changes: 21 additions & 0 deletions spi/src/main/java/org/opensearch/ml/common/spi/tools/Parser.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.spi.tools;

/**
* General parser interface.
* @param <S> The input type
* @param <T> The return type
*/
public interface Parser<S, T> {

/**
* Parse input.
* @param input
* @return output
*/
T parse(S input);
}
92 changes: 92 additions & 0 deletions spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.spi.tools;

import org.opensearch.core.action.ActionListener;
import java.util.Map;

/**
* General tool interface.
*/
public interface Tool {

/**
* Run tool and return response.
* @param parameters input parameters
* @return the tool's output
* @param <T> The output type
*/
default <T> T run(Map<String, String> parameters) {return null;};

default <T> void run(Map<String, String> parameters, ActionListener<T> listener) {};

/**
* Set input parser.
* @param parser
*/
default void setInputParser(Parser<?, ?> parser){};

/**
* Set output parser.
* @param parser
*/
default void setOutputParser(Parser<?, ?> parser){};

/**
* Get tool name.
* @return
*/
String getName();

/**
* Get tool alias.
* @return
*/
String getAlias();

/**
* Set tool alias.
* @param alias
*/
void setAlias(String alias);

/**
* Get tool description.
* @return
*/
String getDescription();

/**
* Set tool description.
* @param description
*/
void setDescription(String description);

/**
* Validate if the input is good.
* @param parameters input parameters
* @return
*/
boolean validate(Map<String, String> parameters);

/**
* Check if should end the whole CoT immediately.
* For example, if some critical error detected like high memory pressure,
* the tool may end the whole CoT process by returning true.
* @param input
* @param toolParameters
* @return true as a signal to CoT to end the chain, false to continue CoT
*/
default boolean end(String input, Map<String, String> toolParameters){return false;}

/**
* Tool factory which can create instance of {@link Tool}.
* @param <T> The subclass this factory produces
*/
interface Factory<T extends Tool> {
T create(Map<String, Object> params);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.spi.tools;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface ToolAnnotation {
String value();
}

0 comments on commit 04f5a20

Please sign in to comment.