Skip to content

Commit

Permalink
Add permissions field to jwt token
Browse files Browse the repository at this point in the history
  • Loading branch information
val500 committed Jan 9, 2025
1 parent bd53c9d commit 4de0f78
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
20 changes: 13 additions & 7 deletions server/src/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def check_token_priority(
if priority == 0:
return
decoded_jwt = decode_jwt_token(auth_token, secret_key)
max_priority_dict = decoded_jwt.get("max_priority", {})
permissions = decoded_jwt.get("permissions", {})
max_priority_dict = permissions.get("max_priority", {})
star_priority = max_priority_dict.get("*", 0)
queue_priority = max_priority_dict.get(queue, 0)
max_priority = max(star_priority, queue_priority)
Expand All @@ -160,7 +161,8 @@ def check_token_queue(auth_token: str, secret_key: str, queue: str):
if not database.check_queue_restricted(queue):
return
decoded_jwt = decode_jwt_token(auth_token, secret_key)
allowed_queues = decoded_jwt.get("allowed_queues", [])
permissions = decoded_jwt.get("permissions", {})
allowed_queues = permissions.get("allowed_queues", [])
if queue not in allowed_queues:
abort(
403,
Expand All @@ -183,7 +185,8 @@ def check_token_reservation_timeout(
if reservation_timeout <= max_reservation_time:
return
decoded_jwt = decode_jwt_token(auth_token, secret_key)
max_reservation_time_dict = decoded_jwt.get("max_reservation_time", {})
permissions = decoded_jwt.get("permissions", {})
max_reservation_time_dict = permissions.get("max_reservation_time", {})
queue_reservation_time = max_reservation_time_dict.get(queue, 0)
star_reservation_time = max_reservation_time_dict.get("*", 0)
max_reservation_time = max(queue_reservation_time, star_reservation_time)
Expand Down Expand Up @@ -803,8 +806,9 @@ def generate_token(allowed_resources, secret_key):
"exp": expiration_time,
"iat": datetime.now(timezone.utc), # Issued at time
"sub": "access_token",
"permissions": {},
}
token_payload.update(allowed_resources)
token_payload.get("permissions").update(allowed_resources)
token = jwt.encode(token_payload, secret_key, algorithm="HS256")
return token

Expand Down Expand Up @@ -841,9 +845,11 @@ def retrieve_token():
exp: <Expiration DateTime of Token>,
iat: <Issuance DateTime of Token>,
sub: <Subject Field of Token>,
max_priority: <Queue to Priority Level Dict>,
allowed_queues: <List of Allowed Restricted Queues>,
max_reservation_time: <Queue to Max Reservation Time Dict>,
permissions: {
max_priority: <Queue to Priority Level Dict>,
allowed_queues: <List of Allowed Restricted Queues>,
max_reservation_time: <Queue to Max Reservation Time Dict>,
}
}
"""
auth_header = request.authorization
Expand Down
12 changes: 8 additions & 4 deletions server/tests/test_v1_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def test_retrieve_token(mongo_app_with_permissions):
token,
os.environ.get("JWT_SIGNING_KEY"),
algorithms="HS256",
options={"require": ["exp", "iat", "sub", "max_priority"]},
options={"require": ["exp", "iat", "sub", "permissions"]},
)
assert decoded_token["max_priority"] == max_priority
assert decoded_token["permissions"]["max_priority"] == max_priority


def test_retrieve_token_invalid_client_id(mongo_app_with_permissions):
Expand Down Expand Up @@ -147,7 +147,9 @@ def test_priority_expired_token(mongo_app_with_permissions):
"exp": datetime.utcnow() - timedelta(seconds=2),
"iat": datetime.utcnow() - timedelta(seconds=4),
"sub": "access_token",
"max_priority": {},
"permissions": {
"max_priority": {},
},
}
token = jwt.encode(expired_token_payload, secret_key, algorithm="HS256")
job = {"job_queue": "myqueue", "job_priority": 100}
Expand All @@ -163,7 +165,9 @@ def test_missing_fields_in_token(mongo_app_with_permissions):
app, _, _, _, _ = mongo_app_with_permissions
secret_key = os.environ.get("JWT_SIGNING_KEY")
incomplete_token_payload = {
"max_priority": {},
"permissions": {
"max_priority": {},
}
}
token = jwt.encode(incomplete_token_payload, secret_key, algorithm="HS256")
job = {"job_queue": "myqueue", "job_priority": 100}
Expand Down

0 comments on commit 4de0f78

Please sign in to comment.