From 25cf1222a56dbc263129c2fd97d74240fcdc6095 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 22 Jan 2024 20:19:51 -0600 Subject: [PATCH] changes to hidden model code to use OPENDISTRO_SECURITY_USER instad of ssl principal (#1897) (#1900) * changes to hidden model code to use OPENDISTRO_SECURITY_USER instad of ssl principal Signed-off-by: Bhavana Ramaram (cherry picked from commit de59efc80d2a4b53917edfd941ba14bf99956639) --- .../opensearch/ml/utils/RestActionUtils.java | 35 ++++++++++++++----- .../ml/rest/RestMLRemoteInferenceIT.java | 1 - .../ml/utils/RestActionUtilsTests.java | 10 ++++-- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 36db5ac317..2c54e87ce2 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.MLModel.MODEL_CONTENT_FIELD; import static org.opensearch.ml.common.MLModel.OLD_MODEL_CONTENT_FIELD; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -30,7 +33,6 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -44,6 +46,8 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; @@ -71,9 +75,12 @@ public class RestActionUtils { public static final String PARAMETER_TOOL_NAME = "tool_name"; public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; - public static final String OPENDISTRO_SECURITY_SSL_PRINCIPAL = OPENDISTRO_SECURITY_CONFIG_PREFIX + "ssl_principal"; + + public static final String OPENDISTRO_SECURITY_USER = OPENDISTRO_SECURITY_CONFIG_PREFIX + "user"; static final Set adminDn = new HashSet<>(); + static final Set adminUsernames = new HashSet(); + static final ObjectMapper objectMapper = new ObjectMapper(); public static String getAlgorithm(RestRequest request) { String algorithm = request.param(PARAMETER_ALGORITHM); @@ -212,7 +219,7 @@ public static Optional getStringParam(RestRequest request, String paramN */ public static User getUserContext(Client client) { String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - logger.debug("Filtering result by " + userStr); + logger.debug("Current user is " + userStr); return User.parse(userStr); } @@ -226,13 +233,25 @@ public static boolean isSuperAdminUser(ClusterService clusterService, Client cli logger.debug("{} is registered as an admin dn", dn); adminDn.add(new LdapName(dn)); } catch (final InvalidNameException e) { - logger.error("Unable to parse admin dn {}", dn, e); + logger.debug("Unable to parse admin dn {}", dn, e); + adminUsernames.add(dn); } } - ThreadContext threadContext = client.threadPool().getThreadContext(); - final String sslPrincipal = threadContext.getTransient(OPENDISTRO_SECURITY_SSL_PRINCIPAL); - return isAdminDN(sslPrincipal); + Object userObject = client.threadPool().getThreadContext().getTransient(OPENDISTRO_SECURITY_USER); + if (userObject == null) + return false; + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + String userContext = objectMapper.writeValueAsString(userObject); + final JsonNode node = objectMapper.readTree(userContext); + final String userName = node.get("name").asText(); + + return isAdminDN(userName); + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } } private static boolean isAdminDN(String dn) { @@ -241,7 +260,7 @@ private static boolean isAdminDN(String dn) { try { return isAdminDN(new LdapName(dn)); } catch (InvalidNameException e) { - return false; + return adminUsernames.contains(dn); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index ba69c59432..0c59f3d2bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -94,7 +94,6 @@ public void testDeleteConnector() throws IOException { assertEquals("deleted", (String) responseMap.get("result")); } - public void testSearchConnectors_beforeConnectorCreation() throws IOException { String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; Response response = TestHelper diff --git a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java index 22947b6407..bf2714c618 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java @@ -297,16 +297,20 @@ public void testIsSuperAdminUser() { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(clusterService.getSettings()) - .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); + .thenReturn( + Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "CN=kirk,OU=client,O=client,L=test, C=de").build() + ); when(client.threadPool()).thenReturn(mock(ThreadPool.class)); when(client.threadPool().getThreadContext()).thenReturn(threadContext); - threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=admin"); + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_USER, Map.of("name", "CN=kirk,OU=client,O=client,L=test,C=de")); boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); Assert.assertTrue(isAdmin); } + // Need to add a test case to cover non Ldap user + @Test public void testIsSuperAdminUser_NotAdmin() { ClusterService clusterService = mock(ClusterService.class); @@ -317,7 +321,7 @@ public void testIsSuperAdminUser_NotAdmin() { .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); when(client.threadPool()).thenReturn(mock(ThreadPool.class)); when(client.threadPool().getThreadContext()).thenReturn(threadContext); - threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=notadmin"); + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_USER, Map.of("name", "nonAdmin")); boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); Assert.assertFalse(isAdmin);