Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NIFI-13016 Add groups mapping from OIDC token claim for Registry #9566

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

public class StandardManagedAuthorizer implements ManagedAuthorizer {

Expand Down Expand Up @@ -95,19 +98,29 @@ public AuthorizationResult authorize(AuthorizationRequest request) throws Author

final UserAndGroups userAndGroups = userGroupProvider.getUserAndGroups(request.getIdentity());

final User user = userAndGroups.getUser();
if (user == null) {
return AuthorizationResult.denied(String.format("Unknown user with identity '%s'.", request.getIdentity()));
}
// combine groups from incoming request with groups from UserAndGroups because the request may contain groups from
// an external identity provider and the membership may not be maintained within any of the UserGroupProviders
final Set<Group> userGroups = new HashSet<>();
userGroups.addAll(userAndGroups.getGroups() == null ? Collections.emptySet() : userAndGroups.getGroups());
userGroups.addAll(getGroups(request.getGroups()));

final Set<Group> userGroups = userAndGroups.getGroups();
if (policy.getUsers().contains(user.getIdentifier()) || containsGroup(userGroups, policy)) {
if (containsUser(userAndGroups.getUser(), policy) || containsGroup(userGroups, policy)) {
return AuthorizationResult.approved();
}

return AuthorizationResult.denied(request.getExplanationSupplier().get());
}

private Set<Group> getGroups(final Set<String> groupNames) {
if (groupNames == null || groupNames.isEmpty()) {
return Collections.emptySet();
}

return userGroupProvider.getGroups().stream()
.filter(group -> groupNames.contains(group.getName()))
.collect(Collectors.toSet());
}

/**
* Determines if the policy contains one of the user's groups.
*
Expand All @@ -129,6 +142,20 @@ private boolean containsGroup(final Set<Group> userGroups, final AccessPolicy po
return false;
}

/**
* Determines if the policy contains the user's identifier.
*
* @param user the user
* @param policy the policy
* @return true if the user is non-null and the user's identifies is contained in the policy's users
*/
private boolean containsUser(final User user, final AccessPolicy policy) {
if (user == null || policy.getUsers().isEmpty()) {
return false;
}
return policy.getUsers().contains(user.getIdentifier());
}

@Override
public String getFingerprint() throws AuthorizationAccessException {
XMLStreamWriter writer = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public class NiFiRegistryProperties extends ApplicationProperties {
public static final String SECURITY_USER_OIDC_PREFERRED_JWSALGORITHM = "nifi.registry.security.user.oidc.preferred.jwsalgorithm";
public static final String SECURITY_USER_OIDC_ADDITIONAL_SCOPES = "nifi.registry.security.user.oidc.additional.scopes";
public static final String SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER = "nifi.registry.security.user.oidc.claim.identifying.user";
public static final String SECURITY_USER_OIDC_CLAIM_GROUPS = "nifi.registry.security.user.oidc.claim.groups";

// Revision Management Properties
public static final String REVISIONS_ENABLED = "nifi.registry.revisions.enabled";
Expand Down Expand Up @@ -481,6 +482,16 @@ public List<String> getOidcAdditionalScopes() {
public String getOidcClaimIdentifyingUser() {
return getProperty(SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER, "email").trim();
}
/**
* Returns the claim to be used to extract user groups from the OIDC payload.
* Claim must be requested by adding the scope for it.
* Default is 'groups'.
*
* @return The claim to be used to extract user groups.
*/
public String getOidcClaimGroups() {
return getProperty(SECURITY_USER_OIDC_CLAIM_GROUPS, "groups").trim();
}

/**
* Returns the network interface list to use for HTTPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.nifi.registry.security.authentication;

import java.io.Serializable;
import java.util.Collections;
import java.util.Set;

/**
* Authentication response for a user login attempt.
Expand All @@ -27,6 +29,7 @@ public class AuthenticationResponse implements Serializable {
private final String username;
private final long expiration;
private final String issuer;
private final Set<String> groups;

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
Expand All @@ -37,10 +40,24 @@ public class AuthenticationResponse implements Serializable {
* @param issuer The issuer of the token
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer) {
this(identity, username, expiration, issuer, Collections.emptySet());
}

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
*
* @param identity The user identity
* @param username The username
* @param expiration The expiration in milliseconds
* @param issuer The issuer of the token
* @param groups The user groups
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer, final Set<String> groups) {
this.identity = identity;
this.username = username;
this.expiration = expiration;
this.issuer = issuer;
this.groups = groups;
}

public String getIdentity() {
Expand All @@ -64,6 +81,10 @@ public long getExpiration() {
return expiration;
}

public Set<String> getGroups() {
return groups;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class IdentityAuthenticationProvider implements AuthenticationProvider {

Expand Down Expand Up @@ -94,7 +95,7 @@ protected AuthenticationSuccessToken buildAuthenticatedToken(
return new AuthenticationSuccessToken(new NiFiUserDetails(
new StandardNiFiUser.Builder()
.identity(mappedIdentity)
.groups(getUserGroups(mappedIdentity))
.groups(getUserGroups(mappedIdentity, response))
.clientAddress(requestToken.getClientAddress())
.build()));
}
Expand All @@ -112,6 +113,12 @@ protected Set<String> getUserGroups(final String identity) {
return getUserGroups(authorizer, identity);
}

protected Set<String> getUserGroups(final String identity, AuthenticationResponse response) {
return Stream
.concat(getUserGroups(authorizer, identity).stream(), response.getGroups().stream())
.collect(Collectors.toSet());
}

private static Set<String> getUserGroups(final Authorizer authorizer, final String userIdentity) {
if (authorizer instanceof ManagedAuthorizer) {
final ManagedAuthorizer managedAuthorizer = (ManagedAuthorizer) authorizer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.nifi.registry.web.security.authentication.jwt;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import org.apache.nifi.registry.properties.NiFiRegistryProperties;
import org.apache.nifi.registry.security.authentication.AuthenticationRequest;
Expand All @@ -34,6 +36,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
Expand Down Expand Up @@ -61,16 +64,19 @@ public AuthenticationResponse authenticate(AuthenticationRequest authenticationR
}

final Object credentials = authenticationRequest.getCredentials();
String jwtAuthToken = credentials != null && credentials instanceof String ? (String) credentials : null;

if (credentials == null) {
logger.info("JWT not found in authenticationRequest credentials, returning null.");
return null;
}

try {
final String jwtPrincipal = jwtService.getUserIdentityFromToken(jwtAuthToken);
return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer);
String jwtAuthToken = credentials instanceof String ? (String) credentials : null;
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved
final Jws<Claims> jws = jwtService.parseAndValidateToken(jwtAuthToken);

final String jwtPrincipal = jwtService.getUserIdentityFromToken(jws);
final Set<String> groups = jwtService.getUserGroupsFromToken(jws);

return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer, groups);
} catch (JwtException e) {
throw new InvalidAuthenticationException(e.getMessage(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Service
Expand All @@ -48,6 +52,7 @@ public class JwtService {
private static final MacAlgorithm SIGNATURE_ALGORITHM = Jwts.SIG.HS256;
private static final String KEY_ID_CLAIM = "kid";
private static final String USERNAME_CLAIM = "preferred_username";
private static final String GROUPS_CLAIM = "groups";

private final KeyService keyService;

Expand All @@ -56,7 +61,7 @@ public JwtService(final KeyService keyService) {
this.keyService = keyService;
}

public String getUserIdentityFromToken(final String base64EncodedToken) throws JwtException {
public Jws<Claims> parseAndValidateToken(final String base64EncodedToken) throws JwtException {
// The library representations of the JWT should be kept internal to this service.
try {
final Jws<Claims> jws = parseTokenFromBase64EncodedString(base64EncodedToken);
Expand All @@ -74,14 +79,24 @@ public String getUserIdentityFromToken(final String base64EncodedToken) throws J
if (StringUtils.isEmpty(jws.getPayload().getIssuer())) {
throw new JwtException("No issuer available in token");
}
return jws.getPayload().getSubject();

return jws;
} catch (JwtException e) {
final String errorMessage = "There was an error validating the JWT";
logger.error(errorMessage, e);
throw e;
throw new JwtException("There was an error validating the JWT", e);
}
}

public String getUserIdentityFromToken(final Jws<Claims> jws) throws JwtException {
return jws.getPayload().getSubject();
}

public Set<String> getUserGroupsFromToken(final Jws<Claims> jws) throws JwtException {
@SuppressWarnings("unchecked")
ArrayList<String> groupsString = jws.getPayload().get(GROUPS_CLAIM, ArrayList.class);
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved

return new HashSet<>(groupsString);
}

private Jws<Claims> parseTokenFromBase64EncodedString(final String base64EncodedToken) throws JwtException {
try {
return Jwts.parser().setSigningKeyResolver(new SigningKeyResolverAdapter() {
Expand Down Expand Up @@ -125,11 +140,15 @@ public String generateSignedToken(final AuthenticationResponse authenticationRes
authenticationResponse.getUsername(),
authenticationResponse.getIssuer(),
authenticationResponse.getIssuer(),
authenticationResponse.getExpiration());
authenticationResponse.getExpiration(),
null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis) throws JwtException {
return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis, Collection<String> groups) throws JwtException {
if (identity == null || StringUtils.isEmpty(identity)) {
String errorMessage = "Cannot generate a JWT for a token with an empty identity";
errorMessage = issuer != null ? errorMessage + " issued by " + issuer + "." : ".";
Expand All @@ -155,6 +174,7 @@ public String generateSignedToken(String identity, String preferredUsername, Str
.audience().add(audience).and()
.claim(USERNAME_CLAIM, preferredUsername)
.claim(KEY_ID_CLAIM, key.getId())
.claim(GROUPS_CLAIM, groups)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This groups claim value can be null in some cases, such as authentication with LDAP. Instead of adding a groups claim with a null value, either an empty list should be set, or the claim should be omitted. As this is the JWT that NiFi Registry generates, it seems better to set an empty list when the input groups parameter is null.

.issuedAt(now.getTime())
.expiration(expiration.getTime())
.signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
String identityClaim = properties.getOidcClaimIdentifyingUser();
String identity = claimsSet.getStringClaim(identityClaim);

// Attempt to extract groups from the configured claim; default is 'groups'
String groupsClaim = properties.getOidcClaimGroups();
List<String> groups = claimsSet.getStringListClaim(groupsClaim);
hazmat345 marked this conversation as resolved.
Show resolved Hide resolved

// If default identity not available, attempt secondary identity extraction
if (StringUtils.isBlank(identity)) {
// Provide clear message to admin that desired claim is missing and present available claims
Expand All @@ -425,7 +429,7 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
final String issuer = claimsSet.getIssuer().getValue();

// convert into a nifi jwt for retrieval later
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn);
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, groups);
}

private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException {
Expand Down
Loading