Skip to content

Commit

Permalink
seperate security.py into multiple different files
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephine committed Mar 1, 2024
1 parent d774064 commit 9c300cb
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 220 deletions.
7 changes: 7 additions & 0 deletions src/cnaas_nms/models/singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class SingletonType(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(SingletonType, cls).__call__(*args, **kwargs)
return cls._instances[cls]
43 changes: 43 additions & 0 deletions src/cnaas_nms/tools/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
import time
from typing import Any

from redis.exceptions import RedisError

from cnaas_nms.db.session import redis_session
from cnaas_nms.tools.log import get_logger
from cnaas_nms.tools.oidc.token import Token

logger = get_logger()


REDIS_OAUTH_TOKEN_INFO_KEY = "oauth_userinfo"


def get_token_info_from_cache(token: Token) -> Any:
"""Check if the userinfo is in the cache to avoid multiple calls to the OIDC server"""
try:
with redis_session() as redis:
cached_token_info = redis.hget(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"])
if cached_token_info:
return json.loads(cached_token_info)
except RedisError as e:
logger.debug("Redis cache error: {}".format(str(e)))
except (TypeError, KeyError) as e:
logger.debug("Error while getting userinfo cache: {}".format(str(e)))


def put_token_info_in_cache(token: Token, token_info) -> Any:
"""Put the userinfo in the cache to avoid multiple calls to the OIDC server"""
try:
with redis_session() as redis:
if "exp" in token.decoded_token:
redis.hsetnx(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"], token_info)
# expire hash at access_token expiry time or 1 hour from now (whichever is sooner)
# Entire hash is expired, since redis does not support expiry on individual keys
expire_at = min(int(token.decoded_token["exp"]), int(time.time()) + 3600)
redis.expireat(REDIS_OAUTH_TOKEN_INFO_KEY, when=expire_at, lt=True)
except RedisError as e:
logger.debug("Redis cache error: {}".format(str(e)))
except (TypeError, KeyError) as e:
logger.debug("Error while getting userinfo cache: {}".format(str(e)))
56 changes: 56 additions & 0 deletions src/cnaas_nms/tools/oidc/key_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Any, Mapping, Optional

import requests
from jwt.exceptions import InvalidKeyError

from cnaas_nms.models.singleton import SingletonType
from cnaas_nms.tools.log import get_logger
from cnaas_nms.tools.oidc.oidc_client_call import get_openid_configuration

logger = get_logger()


class JWKSStore(object, metaclass=SingletonType):
keys: Mapping[str, Any]

def __init__(self, keys: Optional[Mapping[str, Any]] = None):
if keys:
self.keys = keys
else:
self.keys = {}


def get_keys():
"""Get the keys for the OIDC decoding"""
try:
session = requests.Session()
openid_configuration = get_openid_configuration(session)
keys_endpoint = openid_configuration["jwks_uri"]
response = session.get(url=keys_endpoint)
jwks_store = JWKSStore()
jwks_store.keys = response.json()["keys"]
except KeyError as e:
raise InvalidKeyError(e)
except requests.exceptions.HTTPError:
raise ConnectionError("Can't retrieve keys")


def get_key(kid):
"""Get the key based on the kid"""
jwks_store = JWKSStore()
key = [k for k in jwks_store.keys if k["kid"] == kid]
if len(key) == 0:
logger.debug("Key not found. Get the keys.")
get_keys()
if len(jwks_store.keys) == 0:
logger.error("Keys not downloaded")
raise ConnectionError("Can't retrieve keys")
try:
key = [k for k in jwks_store.keys if k["kid"] == kid]
except KeyError as e:
logger.error("Keys in different format?")
raise InvalidKeyError(e)
if len(key) == 0:
logger.error("Key not in keys")
raise InvalidKeyError()
return key
120 changes: 120 additions & 0 deletions src/cnaas_nms/tools/oidc/oidc_client_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json
from typing import Any

import requests
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from requests.auth import HTTPBasicAuth

from cnaas_nms.app_settings import auth_settings
from cnaas_nms.tools.cache import get_token_info_from_cache, put_token_info_in_cache
from cnaas_nms.tools.log import get_logger
from cnaas_nms.tools.oidc.token import Token

logger = get_logger()


def get_openid_configuration(session: requests.Session) -> dict:
"""Get the openid configuration"""
try:
request_openid_configuration = session.get(auth_settings.OIDC_CONF_WELL_KNOWN_URL)
request_openid_configuration.raise_for_status()
openid_configuration = request_openid_configuration.json()
return openid_configuration
except requests.exceptions.HTTPError:
raise ConnectionError("Can't reach the OIDC URL")
except requests.exceptions.ConnectionError:
raise ConnectionError("OIDC metadata unavailable")
except requests.exceptions.JSONDecodeError as e:
raise InvalidTokenError("Invalid JSON in openid Config response: {}".format(str(e)))


def get_token_info_from_userinfo(session: requests.Session, token: Token, user_info_endpoint: str) -> Any:
"""Get token info from userinfo"""
try:
userinfo_data = {"token_type_hint": "access_token"}
userinfo_headers = {"Authorization": "Bearer " + token.token_string}
userinfo_resp = session.post(user_info_endpoint, data=userinfo_data, headers=userinfo_headers)
userinfo_resp.raise_for_status()
userinfo_resp.json()
token_info = userinfo_resp.text
return token_info
except requests.exceptions.HTTPError as e:
try:
body = json.loads(e.response.content)
logger.debug("OIDC userinfo endpoint request not successful: " + body["error_description"])
raise InvalidTokenError(body["error_description"])
except (json.decoder.JSONDecodeError, KeyError):
logger.debug("OIDC userinfo endpoint request not successful: {}".format(str(e)))
raise InvalidTokenError(str(e))
except requests.exceptions.JSONDecodeError as e:
raise InvalidTokenError("Invalid JSON in userinfo response: {}".format(str(e)))


def get_token_info_from_introspect(session: requests.Session, token: Token, introspection_endpoint: str) -> Any:
"""Get token info from introspect"""
try:
introspect_data = {"token": token.token_string}
introspect_auth = HTTPBasicAuth(auth_settings.OIDC_CLIENT_ID, auth_settings.OIDC_CLIENT_SECRET)
introspect_resp = session.post(introspection_endpoint, data=introspect_data, auth=introspect_auth)
introspect_resp.raise_for_status()
introspect_json = introspect_resp.json()
if "active" in introspect_json and introspect_json["active"]:
token_info = introspect_resp.text
return token_info
else:
raise ExpiredSignatureError("Token is no longer active")

except requests.exceptions.HTTPError as e:
try:
body = json.loads(e.response.content)
logger.debug("OIDC introspection endpoint request not successful: " + body["error_description"])
raise InvalidTokenError(body["error_description"])
except (json.decoder.JSONDecodeError, KeyError):
logger.debug("OIDC introspection endpoint request not successful: {}".format(str(e)))
raise InvalidTokenError(str(e))
except requests.exceptions.JSONDecodeError as e:
raise InvalidTokenError("Invalid JSON in introspection response: {}".format(str(e)))


def get_oauth_token_info(token: Token) -> Any:
"""Give back the details about the token from userinfo or introspection
If OIDC is disabled, we return None.
For authorization code access_tokens we can use userinfo endpoint,
for client_credentials we can use introspection endpoint.
Returns:
resp.json(): Object of the user info or introspection
"""
# For now unnecessary, useful when we only use one log in method
if not auth_settings.OIDC_ENABLED:
return None

# Get the cached token info

cached_token_info = get_token_info_from_cache(token)
if cached_token_info:
return cached_token_info

# Get the openid-configuration
session = requests.Session()
openid_configuration = get_openid_configuration(session)

# Request the userinfo
try:
token_info = get_token_info_from_userinfo(session, token, openid_configuration["userinfo_endpoint"])
except requests.exceptions.HTTPError:
# if the userinfo doesn't work, try the introspectinfo
introspect_endpoint = openid_configuration.get(
"introspection_endpoint", openid_configuration["introspect_endpoint"]
)
get_token_info_from_introspect(session, token, introspect_endpoint)

except requests.exceptions.JSONDecodeError as e:
raise InvalidTokenError("Invalid JSON in userinfo response: {}".format(str(e)))

# put the token info in cache
put_token_info_in_cache(token, token_info)
return json.loads(token_info)
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
class Token:
token_string: str = ""
decoded_token = {}
expires_at = ""

def __init__(self, token_string: str, decoded_token: dict):
self.token_string = token_string
self.decoded_token = decoded_token
class Token:
token_string: str = ""
decoded_token = {}
expires_at = ""

def __init__(self, token_string: str, decoded_token: dict):
self.token_string = token_string
self.decoded_token = decoded_token
Loading

0 comments on commit 9c300cb

Please sign in to comment.