forked from SUNET/cnaas-nms
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
seperate security.py into multiple different files
- Loading branch information
Josephine
committed
Mar 1, 2024
1 parent
d774064
commit 9c300cb
Showing
6 changed files
with
242 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
class SingletonType(type): | ||
_instances = {} | ||
|
||
def __call__(cls, *args, **kwargs): | ||
if cls not in cls._instances: | ||
cls._instances[cls] = super(SingletonType, cls).__call__(*args, **kwargs) | ||
return cls._instances[cls] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import json | ||
import time | ||
from typing import Any | ||
|
||
from redis.exceptions import RedisError | ||
|
||
from cnaas_nms.db.session import redis_session | ||
from cnaas_nms.tools.log import get_logger | ||
from cnaas_nms.tools.oidc.token import Token | ||
|
||
logger = get_logger() | ||
|
||
|
||
REDIS_OAUTH_TOKEN_INFO_KEY = "oauth_userinfo" | ||
|
||
|
||
def get_token_info_from_cache(token: Token) -> Any: | ||
"""Check if the userinfo is in the cache to avoid multiple calls to the OIDC server""" | ||
try: | ||
with redis_session() as redis: | ||
cached_token_info = redis.hget(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"]) | ||
if cached_token_info: | ||
return json.loads(cached_token_info) | ||
except RedisError as e: | ||
logger.debug("Redis cache error: {}".format(str(e))) | ||
except (TypeError, KeyError) as e: | ||
logger.debug("Error while getting userinfo cache: {}".format(str(e))) | ||
|
||
|
||
def put_token_info_in_cache(token: Token, token_info) -> Any: | ||
"""Put the userinfo in the cache to avoid multiple calls to the OIDC server""" | ||
try: | ||
with redis_session() as redis: | ||
if "exp" in token.decoded_token: | ||
redis.hsetnx(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"], token_info) | ||
# expire hash at access_token expiry time or 1 hour from now (whichever is sooner) | ||
# Entire hash is expired, since redis does not support expiry on individual keys | ||
expire_at = min(int(token.decoded_token["exp"]), int(time.time()) + 3600) | ||
redis.expireat(REDIS_OAUTH_TOKEN_INFO_KEY, when=expire_at, lt=True) | ||
except RedisError as e: | ||
logger.debug("Redis cache error: {}".format(str(e))) | ||
except (TypeError, KeyError) as e: | ||
logger.debug("Error while getting userinfo cache: {}".format(str(e))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from typing import Any, Mapping, Optional | ||
|
||
import requests | ||
from jwt.exceptions import InvalidKeyError | ||
|
||
from cnaas_nms.models.singleton import SingletonType | ||
from cnaas_nms.tools.log import get_logger | ||
from cnaas_nms.tools.oidc.oidc_client_call import get_openid_configuration | ||
|
||
logger = get_logger() | ||
|
||
|
||
class JWKSStore(object, metaclass=SingletonType): | ||
keys: Mapping[str, Any] | ||
|
||
def __init__(self, keys: Optional[Mapping[str, Any]] = None): | ||
if keys: | ||
self.keys = keys | ||
else: | ||
self.keys = {} | ||
|
||
|
||
def get_keys(): | ||
"""Get the keys for the OIDC decoding""" | ||
try: | ||
session = requests.Session() | ||
openid_configuration = get_openid_configuration(session) | ||
keys_endpoint = openid_configuration["jwks_uri"] | ||
response = session.get(url=keys_endpoint) | ||
jwks_store = JWKSStore() | ||
jwks_store.keys = response.json()["keys"] | ||
except KeyError as e: | ||
raise InvalidKeyError(e) | ||
except requests.exceptions.HTTPError: | ||
raise ConnectionError("Can't retrieve keys") | ||
|
||
|
||
def get_key(kid): | ||
"""Get the key based on the kid""" | ||
jwks_store = JWKSStore() | ||
key = [k for k in jwks_store.keys if k["kid"] == kid] | ||
if len(key) == 0: | ||
logger.debug("Key not found. Get the keys.") | ||
get_keys() | ||
if len(jwks_store.keys) == 0: | ||
logger.error("Keys not downloaded") | ||
raise ConnectionError("Can't retrieve keys") | ||
try: | ||
key = [k for k in jwks_store.keys if k["kid"] == kid] | ||
except KeyError as e: | ||
logger.error("Keys in different format?") | ||
raise InvalidKeyError(e) | ||
if len(key) == 0: | ||
logger.error("Key not in keys") | ||
raise InvalidKeyError() | ||
return key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import json | ||
from typing import Any | ||
|
||
import requests | ||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError | ||
from requests.auth import HTTPBasicAuth | ||
|
||
from cnaas_nms.app_settings import auth_settings | ||
from cnaas_nms.tools.cache import get_token_info_from_cache, put_token_info_in_cache | ||
from cnaas_nms.tools.log import get_logger | ||
from cnaas_nms.tools.oidc.token import Token | ||
|
||
logger = get_logger() | ||
|
||
|
||
def get_openid_configuration(session: requests.Session) -> dict: | ||
"""Get the openid configuration""" | ||
try: | ||
request_openid_configuration = session.get(auth_settings.OIDC_CONF_WELL_KNOWN_URL) | ||
request_openid_configuration.raise_for_status() | ||
openid_configuration = request_openid_configuration.json() | ||
return openid_configuration | ||
except requests.exceptions.HTTPError: | ||
raise ConnectionError("Can't reach the OIDC URL") | ||
except requests.exceptions.ConnectionError: | ||
raise ConnectionError("OIDC metadata unavailable") | ||
except requests.exceptions.JSONDecodeError as e: | ||
raise InvalidTokenError("Invalid JSON in openid Config response: {}".format(str(e))) | ||
|
||
|
||
def get_token_info_from_userinfo(session: requests.Session, token: Token, user_info_endpoint: str) -> Any: | ||
"""Get token info from userinfo""" | ||
try: | ||
userinfo_data = {"token_type_hint": "access_token"} | ||
userinfo_headers = {"Authorization": "Bearer " + token.token_string} | ||
userinfo_resp = session.post(user_info_endpoint, data=userinfo_data, headers=userinfo_headers) | ||
userinfo_resp.raise_for_status() | ||
userinfo_resp.json() | ||
token_info = userinfo_resp.text | ||
return token_info | ||
except requests.exceptions.HTTPError as e: | ||
try: | ||
body = json.loads(e.response.content) | ||
logger.debug("OIDC userinfo endpoint request not successful: " + body["error_description"]) | ||
raise InvalidTokenError(body["error_description"]) | ||
except (json.decoder.JSONDecodeError, KeyError): | ||
logger.debug("OIDC userinfo endpoint request not successful: {}".format(str(e))) | ||
raise InvalidTokenError(str(e)) | ||
except requests.exceptions.JSONDecodeError as e: | ||
raise InvalidTokenError("Invalid JSON in userinfo response: {}".format(str(e))) | ||
|
||
|
||
def get_token_info_from_introspect(session: requests.Session, token: Token, introspection_endpoint: str) -> Any: | ||
"""Get token info from introspect""" | ||
try: | ||
introspect_data = {"token": token.token_string} | ||
introspect_auth = HTTPBasicAuth(auth_settings.OIDC_CLIENT_ID, auth_settings.OIDC_CLIENT_SECRET) | ||
introspect_resp = session.post(introspection_endpoint, data=introspect_data, auth=introspect_auth) | ||
introspect_resp.raise_for_status() | ||
introspect_json = introspect_resp.json() | ||
if "active" in introspect_json and introspect_json["active"]: | ||
token_info = introspect_resp.text | ||
return token_info | ||
else: | ||
raise ExpiredSignatureError("Token is no longer active") | ||
|
||
except requests.exceptions.HTTPError as e: | ||
try: | ||
body = json.loads(e.response.content) | ||
logger.debug("OIDC introspection endpoint request not successful: " + body["error_description"]) | ||
raise InvalidTokenError(body["error_description"]) | ||
except (json.decoder.JSONDecodeError, KeyError): | ||
logger.debug("OIDC introspection endpoint request not successful: {}".format(str(e))) | ||
raise InvalidTokenError(str(e)) | ||
except requests.exceptions.JSONDecodeError as e: | ||
raise InvalidTokenError("Invalid JSON in introspection response: {}".format(str(e))) | ||
|
||
|
||
def get_oauth_token_info(token: Token) -> Any: | ||
"""Give back the details about the token from userinfo or introspection | ||
If OIDC is disabled, we return None. | ||
For authorization code access_tokens we can use userinfo endpoint, | ||
for client_credentials we can use introspection endpoint. | ||
Returns: | ||
resp.json(): Object of the user info or introspection | ||
""" | ||
# For now unnecessary, useful when we only use one log in method | ||
if not auth_settings.OIDC_ENABLED: | ||
return None | ||
|
||
# Get the cached token info | ||
|
||
cached_token_info = get_token_info_from_cache(token) | ||
if cached_token_info: | ||
return cached_token_info | ||
|
||
# Get the openid-configuration | ||
session = requests.Session() | ||
openid_configuration = get_openid_configuration(session) | ||
|
||
# Request the userinfo | ||
try: | ||
token_info = get_token_info_from_userinfo(session, token, openid_configuration["userinfo_endpoint"]) | ||
except requests.exceptions.HTTPError: | ||
# if the userinfo doesn't work, try the introspectinfo | ||
introspect_endpoint = openid_configuration.get( | ||
"introspection_endpoint", openid_configuration["introspect_endpoint"] | ||
) | ||
get_token_info_from_introspect(session, token, introspect_endpoint) | ||
|
||
except requests.exceptions.JSONDecodeError as e: | ||
raise InvalidTokenError("Invalid JSON in userinfo response: {}".format(str(e))) | ||
|
||
# put the token info in cache | ||
put_token_info_in_cache(token, token_info) | ||
return json.loads(token_info) |
16 changes: 8 additions & 8 deletions
16
src/cnaas_nms/tools/rbac/token.py → src/cnaas_nms/tools/oidc/token.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
class Token: | ||
token_string: str = "" | ||
decoded_token = {} | ||
expires_at = "" | ||
|
||
def __init__(self, token_string: str, decoded_token: dict): | ||
self.token_string = token_string | ||
self.decoded_token = decoded_token | ||
class Token: | ||
token_string: str = "" | ||
decoded_token = {} | ||
expires_at = "" | ||
|
||
def __init__(self, token_string: str, decoded_token: dict): | ||
self.token_string = token_string | ||
self.decoded_token = decoded_token |
Oops, something went wrong.