Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to main] update connector API #1651

Merged
merged 3 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateConnectorAction extends ActionType<UpdateResponse> {
public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/update";

private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.connectorId == null) {
exception = addValidationError("ML connector id can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLUpdateConnectorRequest) {
return (MLUpdateConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUpdateConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.when;

public class MLUpdateConnectorRequestTests {
private String connectorId;
private Map<String, Object> updateContent;
private MLUpdateConnectorRequest mlUpdateConnectorRequest;

@Mock
XContentParser parser;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
this.connectorId = "test-connector_id";
this.updateContent = Map.of("description", "new description");
mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
}

@Test
public void writeTo_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlUpdateConnectorRequest.writeTo(bytesStreamOutput);
MLUpdateConnectorRequest parsedUpdateRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(connectorId, parsedUpdateRequest.getConnectorId());
assertEquals(updateContent, parsedUpdateRequest.getUpdateContent());
}

@Test
public void validate_Success() {
assertNull(mlUpdateConnectorRequest.validate());
}

@Test
public void validate_Exception_NullConnectorId() {
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build();
Exception exception = updateConnectorRequest.validate();

assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage());
}

@Test
public void parse_success() throws IOException {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, Object> updatefields = Map.of("version", "new version", "description", "new description");
when(parser.map()).thenReturn(updatefields);

MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId);
assertEquals(updateConnectorRequest.getConnectorId(), connectorId);
assertEquals(updateConnectorRequest.getUpdateContent(), updatefields);
}

@Test
public void fromActionRequest_Success() {
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
assertSame(MLUpdateConnectorRequest.fromActionRequest(mlUpdateConnectorRequest), mlUpdateConnectorRequest);
}

@Test
public void fromActionRequest_Success_fromActionRequest() {
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlUpdateConnectorRequest.writeTo(out);
}
};
MLUpdateConnectorRequest request = MLUpdateConnectorRequest.fromActionRequest(actionRequest);
assertNotSame(request, mlUpdateConnectorRequest);
assertEquals(mlUpdateConnectorRequest.getConnectorId(), request.getConnectorId());
assertEquals(mlUpdateConnectorRequest.getUpdateContent(), request.getUpdateContent());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLUpdateConnectorRequest.fromActionRequest(actionRequest);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class UpdateConnectorTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelManager mlModelManager;

@Inject
public UpdateConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelManager mlModelManager
) {
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new);
this.client = client;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelManager = mlModelManager;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request);
String connectorId = mlUpdateConnectorAction.getConnectorId();
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.doc(mlUpdateConnectorAction.getUpdateContent());
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> {
if (Boolean.TRUE.equals(hasPermission)) {
updateUndeployedConnector(connectorId, updateRequest, listener, context);
} else {
listener
.onFailure(
new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId)
);
}
}, exception -> {
log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception);
listener.onFailure(exception);
}));
} catch (Exception e) {
log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e);
listener.onFailure(e);
}
}

private void updateUndeployedConnector(
String connectorId,
UpdateRequest updateRequest,
ActionListener<UpdateResponse> listener,
ThreadContext.StoredContext context
) {
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds()));
sourceBuilder.query(boolQueryBuilder);
searchRequest.source(sourceBuilder);

client.search(searchRequest, ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
} else {
log.error(searchHits.length + " models are still using this connector, please undeploy the models first!");
listener
.onFailure(
new MLValidationException(
searchHits.length + " models are still using this connector, please undeploy the models first!"
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
return;
}
log.error("Failed to update ML connector: " + connectorId, e);
listener.onFailure(e);

}));
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String connectorId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Failed to update the connector with ID: {}", connectorId);
actionListener.onResponse(updateResponse);
return;
}
log.info("Successfully updated the connector with ID: {}", connectorId);
actionListener.onResponse(updateResponse);
}, exception -> {
log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception);
actionListener.onFailure(exception);
}), context::restore);
}
}
Loading