From 4de0f786af337208bd9875b087db65c8d67b462d Mon Sep 17 00:00:00 2001 From: Varun Valada Date: Thu, 9 Jan 2025 17:48:04 -0600 Subject: [PATCH] Add permissions field to jwt token --- server/src/api/v1.py | 20 +++++++++++++------- server/tests/test_v1_authorization.py | 12 ++++++++---- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/server/src/api/v1.py b/server/src/api/v1.py index 50978efa..4fe20195 100644 --- a/server/src/api/v1.py +++ b/server/src/api/v1.py @@ -137,7 +137,8 @@ def check_token_priority( if priority == 0: return decoded_jwt = decode_jwt_token(auth_token, secret_key) - max_priority_dict = decoded_jwt.get("max_priority", {}) + permissions = decoded_jwt.get("permissions", {}) + max_priority_dict = permissions.get("max_priority", {}) star_priority = max_priority_dict.get("*", 0) queue_priority = max_priority_dict.get(queue, 0) max_priority = max(star_priority, queue_priority) @@ -160,7 +161,8 @@ def check_token_queue(auth_token: str, secret_key: str, queue: str): if not database.check_queue_restricted(queue): return decoded_jwt = decode_jwt_token(auth_token, secret_key) - allowed_queues = decoded_jwt.get("allowed_queues", []) + permissions = decoded_jwt.get("permissions", {}) + allowed_queues = permissions.get("allowed_queues", []) if queue not in allowed_queues: abort( 403, @@ -183,7 +185,8 @@ def check_token_reservation_timeout( if reservation_timeout <= max_reservation_time: return decoded_jwt = decode_jwt_token(auth_token, secret_key) - max_reservation_time_dict = decoded_jwt.get("max_reservation_time", {}) + permissions = decoded_jwt.get("permissions", {}) + max_reservation_time_dict = permissions.get("max_reservation_time", {}) queue_reservation_time = max_reservation_time_dict.get(queue, 0) star_reservation_time = max_reservation_time_dict.get("*", 0) max_reservation_time = max(queue_reservation_time, star_reservation_time) @@ -803,8 +806,9 @@ def generate_token(allowed_resources, secret_key): "exp": expiration_time, "iat": datetime.now(timezone.utc), # Issued at time "sub": "access_token", + "permissions": {}, } - token_payload.update(allowed_resources) + token_payload.get("permissions").update(allowed_resources) token = jwt.encode(token_payload, secret_key, algorithm="HS256") return token @@ -841,9 +845,11 @@ def retrieve_token(): exp: , iat: , sub: , - max_priority: , - allowed_queues: , - max_reservation_time: , + permissions: { + max_priority: , + allowed_queues: , + max_reservation_time: , + } } """ auth_header = request.authorization diff --git a/server/tests/test_v1_authorization.py b/server/tests/test_v1_authorization.py index d157bcda..bdf6b9a3 100644 --- a/server/tests/test_v1_authorization.py +++ b/server/tests/test_v1_authorization.py @@ -50,9 +50,9 @@ def test_retrieve_token(mongo_app_with_permissions): token, os.environ.get("JWT_SIGNING_KEY"), algorithms="HS256", - options={"require": ["exp", "iat", "sub", "max_priority"]}, + options={"require": ["exp", "iat", "sub", "permissions"]}, ) - assert decoded_token["max_priority"] == max_priority + assert decoded_token["permissions"]["max_priority"] == max_priority def test_retrieve_token_invalid_client_id(mongo_app_with_permissions): @@ -147,7 +147,9 @@ def test_priority_expired_token(mongo_app_with_permissions): "exp": datetime.utcnow() - timedelta(seconds=2), "iat": datetime.utcnow() - timedelta(seconds=4), "sub": "access_token", - "max_priority": {}, + "permissions": { + "max_priority": {}, + }, } token = jwt.encode(expired_token_payload, secret_key, algorithm="HS256") job = {"job_queue": "myqueue", "job_priority": 100} @@ -163,7 +165,9 @@ def test_missing_fields_in_token(mongo_app_with_permissions): app, _, _, _, _ = mongo_app_with_permissions secret_key = os.environ.get("JWT_SIGNING_KEY") incomplete_token_payload = { - "max_priority": {}, + "permissions": { + "max_priority": {}, + } } token = jwt.encode(incomplete_token_payload, secret_key, algorithm="HS256") job = {"job_queue": "myqueue", "job_priority": 100}