Skip to content

Commit

Permalink
Added amazon rekognition as a trust endpoint (opensearch-project#3419)
Browse files Browse the repository at this point in the history
* feat: add rekognition trust endpoint

Signed-off-by: Pavan Yekbote <[email protected]>

* refactor: consistent regex with previous connectors and looser checks on region

Signed-off-by: Pavan Yekbote <[email protected]>

* test: add test case to validate connector creation successful

Signed-off-by: Pavan Yekbote <[email protected]>

* chore: spotless apply

Signed-off-by: Pavan Yekbote <[email protected]>

* chore: linter fixes

Signed-off-by: Pavan Yekbote <[email protected]>

---------

Signed-off-by: Pavan Yekbote <[email protected]>
  • Loading branch information
pyek-bot committed Jan 28, 2025
1 parent a852152 commit ae451c9
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final String REKOGNITION_TRUST_ENDPOINT_REGEX = "^https://rekognition(-fips)?\\..*[a-z0-9-]\\.amazonaws\\.com$";

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
Expand All @@ -169,7 +171,8 @@ private MLCommonsSettings() {}
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://textract\\..*[a-z0-9-]\\.amazonaws\\.com$",
"^https://comprehend\\..*[a-z0-9-]\\.amazonaws\\.com$"
"^https://comprehend\\..*[a-z0-9-]\\.amazonaws\\.com$",
REKOGNITION_TRUST_ENDPOINT_REGEX
),
Function.identity(),
Setting.Property.NodeScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.REKOGNITION_TRUST_ENDPOINT_REGEX;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

Expand Down Expand Up @@ -118,7 +119,13 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase {
private ArgumentCaptor<PutDataObjectRequest> putDataObjectRequestArgumentCaptor;

private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList
.of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$");
.of(
"^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$",
REKOGNITION_TRUST_ENDPOINT_REGEX,
"^https://api\\.deepseek\\.com/.*$"
);

@Before
public void setup() {
Expand Down Expand Up @@ -539,4 +546,117 @@ public void test_execute_URL_notMatchingExpression_exception() {
argumentCaptor.getValue().getMessage()
);
}

public void test_connector_creation_success_deepseek() {
TransportCreateConnectorAction action = new TransportCreateConnectorAction(
transportService,
actionFilters,
mlIndicesHandler,
client,
sdkClient,
mlEngine,
connectorAccessControlHelper,
settings,
clusterService,
mlModelManager,
mlFeatureEnabledSetting
);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class));
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(client).index(any(IndexRequest.class), isA(ActionListener.class));
List<ConnectorAction> actions = new ArrayList<>();
actions
.add(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.deepseek.com/v1/chat/completions")
.build()
);
Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput
.builder()
.name(randomAlphaOfLength(5))
.description(randomAlphaOfLength(10))
.version("1")
.protocol(ConnectorProtocols.HTTP)
.credential(credential)
.actions(actions)
.build();
MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);
action.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
}

public void test_connector_creation_success_rekognition() {
TransportCreateConnectorAction action = new TransportCreateConnectorAction(
transportService,
actionFilters,
mlIndicesHandler,
client,
sdkClient,
mlEngine,
connectorAccessControlHelper,
settings,
clusterService,
mlModelManager,
mlFeatureEnabledSetting
);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(client).index(any(IndexRequest.class), isA(ActionListener.class));

List<ConnectorAction> actions = new ArrayList<>();
actions
.add(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://rekognition.test-region-1.amazonaws.com")
.build()
);
actions
.add(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://rekognition-fips.test-region-1.amazonaws.com")
.build()
);

Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput
.builder()
.name(randomAlphaOfLength(5))
.description(randomAlphaOfLength(10))
.version("1")
.protocol(ConnectorProtocols.HTTP)
.credential(credential)
.actions(actions)
.build();

MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);

action.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
}
}

0 comments on commit ae451c9

Please sign in to comment.