Skip to content

Commit

Permalink
Fix test and clean up code
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Ho <[email protected]>
  • Loading branch information
derek-ho committed Jan 9, 2025
1 parent 5b48fb9 commit e13c055
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableList;
Expand All @@ -32,6 +33,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -47,6 +49,7 @@
import org.opensearch.security.dlic.rest.api.RestApiAdminPrivilegesEvaluator;
import org.opensearch.security.dlic.rest.api.RestApiPrivilegesEvaluator;
import org.opensearch.security.dlic.rest.api.SecurityApiDependencies;
import org.opensearch.security.dlic.rest.support.Utils;
import org.opensearch.security.identity.SecurityTokenManager;
import org.opensearch.security.privileges.PrivilegesEvaluator;
import org.opensearch.security.securityconf.DynamicConfigFactory;
Expand Down Expand Up @@ -138,12 +141,18 @@ protected RestChannelConsumer prepareRequest(final RestRequest request, final No
}

RestChannelConsumer doPrepareRequest(RestRequest request, NodeClient client) {
return switch (request.method()) {
case POST -> handlePost(request, client);
case DELETE -> handleDelete(request, client);
case GET -> handleGet(request, client);
default -> throw new IllegalArgumentException(request.method() + " not supported");
};
final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(client.threadPool().getThreadContext());
try (final ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
client.threadPool()
.getThreadContext()
.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, originalUserAndRemoteAddress.getLeft());
return switch (request.method()) {
case POST -> handlePost(request, client);
case DELETE -> handleDelete(request, client);
case GET -> handleGet(request, client);
default -> throw new IllegalArgumentException(request.method() + " not supported");
};
}
}

private RestChannelConsumer handleGet(RestRequest request, NodeClient client) {
Expand Down Expand Up @@ -177,7 +186,6 @@ private RestChannelConsumer handleGet(RestRequest request, NodeClient client) {

private RestChannelConsumer handlePost(RestRequest request, NodeClient client) {
return channel -> {
final XContentBuilder builder = channel.newBuilder();
try {
final Map<String, Object> requestBody = request.contentOrSourceParamParser().map();
validateRequestParameters(requestBody);
Expand Down Expand Up @@ -305,7 +313,6 @@ void validateIndexPermissionsList(List<Map<String, Object>> indexPermsList) {

private RestChannelConsumer handleDelete(RestRequest request, NodeClient client) {
return channel -> {
final XContentBuilder builder = channel.newBuilder();
try {
final Map<String, Object> requestBody = request.contentOrSourceParamParser().map();

Expand Down Expand Up @@ -449,4 +456,14 @@ protected void authorizeSecurityAccess(RestRequest request) throws IOException {
throw new SecurityException("User does not have required security API access");
}
}

private <T> T withSecurityContext(NodeClient client, Supplier<T> operation) {
final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(client.threadPool().getThreadContext());
try (final ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
client.threadPool()
.getThreadContext()
.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, originalUserAndRemoteAddress.getLeft());
return operation.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableMap;
import org.apache.logging.log4j.LogManager;
Expand All @@ -27,7 +26,6 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
Expand All @@ -42,7 +40,6 @@
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.security.dlic.rest.support.Utils;
import org.opensearch.security.support.ConfigConstants;

import static org.opensearch.security.action.apitokens.ApiToken.NAME_FIELD;
Expand All @@ -59,116 +56,79 @@ public ApiTokenIndexHandler(Client client, ClusterService clusterService) {
}

public void indexTokenMetadata(ApiToken token) {
withSecurityContext(() -> {
try {

XContentBuilder builder = XContentFactory.jsonBuilder();
String jsonString = token.toXContent(builder, ToXContent.EMPTY_PARAMS).toString();

IndexRequest request = new IndexRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).source(jsonString, XContentType.JSON);

ActionListener<IndexResponse> irListener = ActionListener.wrap(idxResponse -> {
LOGGER.info("Created {} entry.", ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
}, (failResponse) -> {
LOGGER.error(failResponse.getMessage());
LOGGER.info("Failed to create {} entry.", ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
});
client.index(request, irListener);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
try {

XContentBuilder builder = XContentFactory.jsonBuilder();
String jsonString = token.toXContent(builder, ToXContent.EMPTY_PARAMS).toString();

IndexRequest request = new IndexRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).source(jsonString, XContentType.JSON);

ActionListener<IndexResponse> irListener = ActionListener.wrap(idxResponse -> {
LOGGER.info("Created {} entry.", ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
}, (failResponse) -> {
LOGGER.error(failResponse.getMessage());
LOGGER.info("Failed to create {} entry.", ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
});
client.index(request, irListener);
} catch (IOException e) {
throw new RuntimeException(e);
}

}

public void deleteToken(String name) throws ApiTokenException {
withSecurityContext(() -> {
DeleteByQueryRequest request = new DeleteByQueryRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).setQuery(
QueryBuilders.matchQuery(NAME_FIELD, name)
).setRefresh(true);
DeleteByQueryRequest request = new DeleteByQueryRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).setQuery(
QueryBuilders.matchQuery(NAME_FIELD, name)
).setRefresh(true);

BulkByScrollResponse response = client.execute(DeleteByQueryAction.INSTANCE, request).actionGet();
BulkByScrollResponse response = client.execute(DeleteByQueryAction.INSTANCE, request).actionGet();

long deletedDocs = response.getDeleted();
long deletedDocs = response.getDeleted();

if (deletedDocs == 0) {
throw new ApiTokenException("No token found with name " + name);
}
});
if (deletedDocs == 0) {
throw new ApiTokenException("No token found with name " + name);
}
}

public Map<String, ApiToken> getTokenMetadatas() {
return withSecurityContext(() -> {
try {
SearchRequest searchRequest = new SearchRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
searchRequest.source(new SearchSourceBuilder());

SearchResponse response = client.search(searchRequest).actionGet();

Map<String, ApiToken> tokens = new HashMap<>();
for (SearchHit hit : response.getHits().getHits()) {
try (
XContentParser parser = XContentType.JSON.xContent()
.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
hit.getSourceRef().streamInput()
)
) {

ApiToken token = ApiToken.fromXContent(parser);
tokens.put(token.getName(), token);
}
try {
SearchRequest searchRequest = new SearchRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
searchRequest.source(new SearchSourceBuilder());

SearchResponse response = client.search(searchRequest).actionGet();

Map<String, ApiToken> tokens = new HashMap<>();
for (SearchHit hit : response.getHits().getHits()) {
try (
XContentParser parser = XContentType.JSON.xContent()
.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
hit.getSourceRef().streamInput()
)
) {

ApiToken token = ApiToken.fromXContent(parser);
tokens.put(token.getName(), token);
}
return tokens;
} catch (IOException e) {
throw new RuntimeException(e);
}
});
return tokens;
} catch (IOException e) {
throw new RuntimeException(e);
}
}

public Boolean apiTokenIndexExists() {
return clusterService.state().metadata().hasConcreteIndex(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
}

private <T> T withSecurityContext(Supplier<T> operation) {
final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(client.threadPool().getThreadContext());
try (final ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
client.threadPool()
.getThreadContext()
.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, originalUserAndRemoteAddress.getLeft());
return operation.get();
}
}

private void withSecurityContext(Runnable operation) {
final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(client.threadPool().getThreadContext());
try (final ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
client.threadPool()
.getThreadContext()
.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, originalUserAndRemoteAddress.getLeft());
operation.run();
}
}

public void createApiTokenIndexIfAbsent() {
if (!apiTokenIndexExists()) {
final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(client.threadPool().getThreadContext());
try (final ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
client.threadPool()
.getThreadContext()
.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, originalUserAndRemoteAddress.getLeft());
final Map<String, Object> indexSettings = ImmutableMap.of(
"index.number_of_shards",
1,
"index.auto_expand_replicas",
"0-all"
);
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).settings(
indexSettings
);
client.admin().indices().create(createIndexRequest);
}
final Map<String, Object> indexSettings = ImmutableMap.of("index.number_of_shards", 1, "index.auto_expand_replicas", "0-all");
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(ConfigConstants.OPENSEARCH_API_TOKENS_INDEX).settings(
indexSettings
);
client.admin().indices().create(createIndexRequest);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ default String build() {
.put(Endpoint.ROLES, action -> buildEndpointPermission(Endpoint.ROLES))
.put(Endpoint.ROLESMAPPING, action -> buildEndpointPermission(Endpoint.ROLESMAPPING))
.put(Endpoint.TENANTS, action -> buildEndpointPermission(Endpoint.TENANTS))
.put(Endpoint.SSL, action -> buildEndpointActionPermission(Endpoint.SSL, action))
.put(Endpoint.APITOKENS, action -> buildEndpointPermission(Endpoint.APITOKENS))
.put(Endpoint.SSL, action -> buildEndpointActionPermission(Endpoint.SSL, action))
.build();

private final ThreadContext threadContext;
Expand Down

0 comments on commit e13c055

Please sign in to comment.