Skip to content

Commit

Permalink
add more user based permission check in Memory (#1927)
Browse files Browse the repository at this point in the history
* add more user based permission check in Memory

Signed-off-by: Xun Zhang <[email protected]>

* add UT for acess denied cases

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Jan 26, 2024
1 parent 3d4bb02 commit cdd63b4
Show file tree
Hide file tree
Showing 11 changed files with 558 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,20 @@ public void createInteraction(

/**
* Update a conversation
* @param conversationId the conversation id to update
* @param updateContent update content for the conversations index
* @param listener receives the update response
*/
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Update an interaction
* @param interactionId the interaction id to update
* @param updateContent update content for the interaction index
* @param listener receives the update response
*/
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
import java.time.Instant;
import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -28,29 +30,50 @@
@Log4j2
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
} else {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation for conversation id" + conversationId, e);
listener.onFailure(e);
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation " + conversationId, e);
listener.onFailure(e);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@

package org.opensearch.ml.memory.action.conversation;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -25,26 +29,47 @@
@Log4j2
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateInteractionTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
String interactionId = updateInteractionRequest.getInteractionId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
updateRequest.doc(updateInteractionRequest.getUpdateContent());
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
} catch (Exception e) {
log.error("Failed to update Interaction for interaction id " + interactionId, e);
log.error("Failed to update Interaction " + interactionId, e);
listener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,35 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
}

/**
* Update conversations in the index
* Update conversation in the index
* @param conversationId the conversation id that needs update
* @param updateRequest original update request
* @param listener receives the update response for the wrapped query
*/
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener
.onFailure(
new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME)
);
return;
}

this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
innerUpdateConversation(updateRequest, listener);
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
}

private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
client.update(updateRequest, internalListener);
Expand Down
Loading

0 comments on commit cdd63b4

Please sign in to comment.