Skip to content

Commit

Permalink
fixed group membership removal
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Dec 24, 2024
1 parent 3cb11a7 commit 7c5cbe4
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 129 deletions.
9 changes: 4 additions & 5 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,9 @@ def __init__(

self.app = app

# This block of code probably need to be made more concise
if "persist_refresh_token" in config["OPENID_CONNECT"].get(self.idp_name, {}):
self.persist_refresh_token = config["OPENID_CONNECT"][self.idp_name][
"persist_refresh_token"
]
self.persist_refresh_token = (
config["OPENID_CONNECT"].get(self.idp_name, {}).get("persist_refresh_token")
)

if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get(
self.idp_name, {}
Expand Down Expand Up @@ -163,6 +161,7 @@ def get(self):
# default to now + REFRESH_TOKEN_EXPIRES_IN
if expires is None:
expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"]
logger.info(self, f"Refresh token not in JWT, using default: {expires}")

# Store refresh token in db
should_persist_token = (
Expand Down
6 changes: 5 additions & 1 deletion fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ENCRYPTION_KEY: ''
# //////////////////////////////////////////////////////////////////////////////////////
# flask's debug setting
# WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only)
DEBUG: true
DEBUG: false
# if true, will automatically login a user with username "test"
# WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only)
MOCK_AUTH: false
Expand Down Expand Up @@ -127,6 +127,10 @@ OPENID_CONNECT:
# or removed from relevant groups in the local system to ensure their group memberships
# remain up-to-date. If this flag is disabled, no group synchronization occurs
is_authz_groups_sync_enabled: true
# Key used to retrieve group information from the token
group_claim_field: "groups"
# IdP group membership expiration (seconds).
group_membership_expiration_duration: 604800
authz_groups_sync:
# This defines the prefix used to identify authorization groups.
group_prefix: "some_prefix"
Expand Down
4 changes: 4 additions & 0 deletions fence/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def post_process(self):
)

for idp_id, idp in self._configs.get("OPENID_CONNECT", {}).items():
if not isinstance(idp, dict):
raise TypeError(
"Expected 'OPENID_CONNECT' configuration to be a dictionary."
)
mfa_info = idp.get("multifactor_auth_claim_info")
if mfa_info and mfa_info["claim"] not in ["amr", "acr"]:
logger.warning(
Expand Down
51 changes: 37 additions & 14 deletions fence/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,44 @@


def get_error_response(error: Exception):
"""
Generates a response for the given error with detailed logs and appropriate status codes.
Args:
error (Exception): The error that occurred.
Returns:
Tuple (str, int): Rendered error HTML and HTTP status code.
"""
details, status_code = get_error_details_and_status(error)
support_email = config.get("SUPPORT_EMAIL_FOR_ERRORS")
app_name = config.get("APP_NAME", "Gen3 Data Commons")

message = details.get("message")

error_id = _get_error_identifier()
logger.error(
"{} HTTP error occured. ID: {}\nDetails: {}".format(
status_code, error_id, str(details)
"{} HTTP error occurred. ID: {}\nDetails: {}\nTraceback: {}".format(
status_code, error_id, details, traceback.format_exc()
)
)

# TODO: Issue: Error messages are obfuscated, the line below needs be
# uncommented when troubleshooting errors.
# Breaks tests if not commented out / removed. We need a fix for this.
# raise error
# Decide whether to re-raise errors or handle gracefully based on the debug flag
debug_mode = config.get("DEBUG", False)

if debug_mode:
# Re-raise the error in debug mode for troubleshooting
raise error

# don't include internal details in the public error message
# to do this, only include error messages for known http status codes
# that are less that 500
# Prepare user-facing message
message = details.get("message")
valid_http_status_codes = [
int(code) for code in list(http_responses.keys()) if int(code) < 500
]

try:
status_code = int(status_code)
if status_code not in valid_http_status_codes:
message = None
except (ValueError, TypeError):
# this handles case where status_code is NOT a valid integer (e.g. HTTP status code)
message = None
status_code = 500

Expand All @@ -65,6 +73,15 @@ def get_error_response(error: Exception):


def get_error_details_and_status(error):
"""
Extracts details and HTTP status code from the given error.
Args:
error (Exception): The error to process.
Returns:
Tuple (dict, int): Error details as a dictionary and HTTP status code.
"""
message = error.message if hasattr(error, "message") else str(error)
if isinstance(error, APIError):
if hasattr(error, "json") and error.json:
Expand All @@ -76,11 +93,11 @@ def get_error_details_and_status(error):
error_response = {"message": error.description}, error.status_code
elif isinstance(error, HTTPException):
error_response = (
{"message": getattr(error, "description")},
{"message": getattr(error, "description", str(error))},
error.get_response().status_code,
)
else:
logger.exception("Catch exception")
logger.exception("Unexpected exception occurred")
error_code = 500
if hasattr(error, "code"):
error_code = error.code
Expand All @@ -92,4 +109,10 @@ def get_error_details_and_status(error):


def _get_error_identifier():
"""
Generates a unique identifier for tracking the error.
Returns:
UUID: A unique identifier for the error.
"""
return uuid.uuid4()
9 changes: 1 addition & 8 deletions fence/job/access_token_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = get_logger(__name__, log_level="debug")


class AccessTokenUpdater(object):
class TokenAndAuthUpdater(object):
def __init__(
self,
chunk_size=None,
Expand Down Expand Up @@ -51,14 +51,8 @@ def __init__(
self.oidc_clients_requiring_token_refresh = {}

# keep this as a special case, because RAS will not set group information configuration.
# Initialize visa clients:
oidc = config.get("OPENID_CONNECT", {})

if not isinstance(oidc, dict):
raise TypeError(
"Expected 'OPENID_CONNECT' configuration to be a dictionary."
)

if "ras" not in oidc:
self.logger.error("RAS client not configured")
else:
Expand Down Expand Up @@ -96,7 +90,6 @@ async def update_tokens(self, db_session):
"""
start_time = time.time()
# Change this line to reflect we are refreshing tokens, not just visas
self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .")
self.logger.info("Total concurrency size: {}".format(self.concurrency))
self.logger.info("Total thread pool size: {}".format(self.thread_pool_size))
Expand Down
144 changes: 57 additions & 87 deletions fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from email.policy import default

from authlib.integrations.requests_client import OAuth2Session
from boto3 import client
from cached_property import cached_property
Expand Down Expand Up @@ -94,38 +96,6 @@ def get_jwt_keys(self, jwks_uri):
return None
return resp.json()["keys"]

def get_raw_token_claims(self, token_id):
"""Extracts unvalidated claims from a JWT (JSON Web Token).
This function decodes a JWT and extracts claims without verifying
the token's signature or audience. It is intended for cases where
access to the raw, unvalidated token claims is sufficient.
Args:
token_id (str): The JWT token from which to extract claims.
Returns:
dict: A dictionary of token claims if decoding is successful.
Raises:
JWTError: If there is an error decoding the token without validation.
Notes:
This function does not perform any validation of the token. It should
only be used in contexts where validation is not critical or is handled
elsewhere in the application.
"""
try:
# Decode without verification
unvalidated_claims = jwt.decode(
token_id, options={"verify_signature": False}
)
self.logger.info("Raw token claims extracted successfully.")
return unvalidated_claims
except JWTError as e:
self.logger.error(f"Error extracting claims: {e}")
raise JWTError("Unable to decode the token without validation.")

def decode_and_validate_token(self, token_id, keys, audience, verify_aud=True):
"""Decodes and validates a JWT (JSON Web Token) using provided keys and audience.
Expand Down Expand Up @@ -279,7 +249,8 @@ def get_auth_info(self, code):

if self.read_authz_groups_from_tokens:
try:
groups = claims.get("groups")
group_claim_field = self.settings.get("group_claim_field", "groups")
groups = claims.get(group_claim_field)
group_prefix = self.settings.get("authz_groups_sync", {}).get(
"group_prefix", ""
)
Expand Down Expand Up @@ -315,15 +286,15 @@ def get_access_token(self, user, token_endpoint, db_session=None):
"""
Get access_token using a refresh_token and store new refresh in upstream_refresh_token table.
"""
# this function is not correct. use self.session.fetch_access_token,
# validate the token for audience and then return the validated token.
# Still store the refresh token. it will be needed for periodic re-fetching of information.
refresh_token = None
expires = None
# get refresh_token and expiration from db

# Get the refresh_token and expiration from the database
for row in sorted(user.upstream_refresh_tokens, key=lambda row: row.expires):
refresh_token = row.refresh_token
expires = row.expires

# Check if the token is expired
if time.time() > expires:
# reset to check for next token
refresh_token = None
Expand All @@ -336,21 +307,29 @@ def get_access_token(self, user, token_endpoint, db_session=None):
if not refresh_token:
raise AuthError("User doesn't have a valid, non-expired refresh token")

token_response = self.session.refresh_token(
url=token_endpoint,
proxies=self.get_proxies(),
refresh_token=refresh_token,
)
refresh_token = token_response["refresh_token"]
try:
token_response = self.session.refresh_token(
url=token_endpoint,
proxies=self.get_proxies(),
refresh_token=refresh_token,
)

self.store_refresh_token(
user,
refresh_token=refresh_token,
expires=expires,
db_session=db_session,
)
refresh_token = token_response["refresh_token"]
# Fetching the expires at from token_response.
# Defaulting to 1 hour if not available.
expires_at = token_response.get("expires_at", time.time() + 3600)

self.store_refresh_token(
user,
refresh_token=refresh_token,
expires=expires_at,
db_session=db_session,
)

return token_response
return token_response
except Exception as e:
self.logger.exception(f"Error refreshing token for user {user.id}: {e}")
raise AuthError("Failed to refresh access token.")

def has_mfa_claim(self, decoded_id_token):
"""
Expand Down Expand Up @@ -405,8 +384,24 @@ def store_refresh_token(self, user, refresh_token, expires, db_session=None):
db_session.commit()

def get_groups_from_token(self, decoded_id_token, group_prefix=""):
"""Retrieve and format groups from the decoded token."""
authz_groups_from_idp = decoded_id_token.get("groups", [])
"""
Retrieve and format groups from the decoded token based on a configurable field name.
Args:
decoded_id_token (dict): The decoded token containing claims.
group_prefix (str): The prefix to strip from group names.
Returns:
list: A list of formatted group names.
Variables:
group_claim_field (str): The field name in the token that contains the group information.
authz_groups_from_idp (list): The list of groups retrieved from the token, potentially empty.
"""
# Retrieve the configured field name for groups, defaulting to 'groups'
group_claim_field = self.settings.get("group_claim_field", "groups")
authz_groups_from_idp = decoded_id_token.get(group_claim_field, [])

if authz_groups_from_idp:
authz_groups_from_idp = [
group.removeprefix(group_prefix).lstrip("/")
Expand Down Expand Up @@ -455,9 +450,6 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs)
"""
db_session = db_session or current_app.scoped_session()

# Initialize the failure flag for group removal
removal_failed = False

expires_at = None

try:
Expand Down Expand Up @@ -505,48 +497,26 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs)

idp_group_names = set(authz_groups_from_idp)

# Expiration for group membership. Default 7 days
group_membership_duration = self.settings.get(
"group_membership_expiration_duration", 3600 * 24 * 7
)
group_membership_expires_at = datetime.datetime.now(
tz=datetime.timezone.utc
) + datetime.timedelta(seconds=group_membership_duration)

# Add user to all matching groups from IDP
for arborist_group in arborist_groups:
if arborist_group["name"] in idp_group_names:
self.logger.info(
f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {exp}"
f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {group_membership_expires_at}"
)
self.arborist.add_user_to_group(
username=user.username,
group_name=arborist_group["name"],
expires_at=exp,
expires_at=group_membership_expires_at,
)

# Remove user from groups in Arborist that they are not part of in IDP
for arborist_group in arborist_groups:
if arborist_group["name"] not in idp_group_names:
if user.username in arborist_group.get("users", []):
try:
self.remove_user_from_arborist_group(
user.username, arborist_group["name"]
)
except Exception as e:
self.logger.error(
f"Failed to remove {user.username} from group {arborist_group['name']}: {e}"
)
removal_failed = (
# Set the failure flag if any removal fails
True
)

else:
self.logger.warning(
f"is_authz_groups_sync_enabled feature is enabled, but did not receive groups from idp {self.idp} for user: {user.username}"
)

# Raise an exception if any group removal failed
if removal_failed:
raise Exception("One or more group removals failed.")

def remove_user_from_arborist_group(self, username, group_name):
"""
Attempt to remove a user from an Arborist group, catching any errors to allow
processing of remaining groups. Logs errors and re-raises them after all removals are attempted.
"""
self.logger.info(f"Removing {username} from group: {group_name}")
self.arborist.remove_user_from_group(username=username, group_name=group_name)
Loading

0 comments on commit 7c5cbe4

Please sign in to comment.