From 22b1e0b82e05d4d4849f830e13a5dc4fda4a35c4 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Fri, 10 Jan 2025 15:35:54 +0000 Subject: [PATCH] OIDC Authenticator get Algo info from server --- tiled/authenticators.py | 128 ++++++++++++++++------------------------ 1 file changed, 50 insertions(+), 78 deletions(-) diff --git a/tiled/authenticators.py b/tiled/authenticators.py index e8fff3134..759ea91e5 100644 --- a/tiled/authenticators.py +++ b/tiled/authenticators.py @@ -5,10 +5,12 @@ import re import secrets from collections.abc import Iterable +from typing import Any, cast import httpx from fastapi import APIRouter, Request from jose import JWTError, jwk, jwt +from pydantic import Secret from starlette.responses import RedirectResponse from .server.authentication import Mode @@ -115,67 +117,65 @@ class OIDCAuthenticator: type: object additionalProperties: false properties: + audience: + type: string client_id: type: string client_secret: type: string - token_uri: - type: string - authorization_endpoint: - type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use - confirmation_message: + well_known_uri: type: string - description: May be displayed by client after successful login. """ def __init__( self, - client_id, - client_secret, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, + audience: str, + client_id: str, + client_secret: str, + well_known_uri: str, ): - self.client_id = client_id - self.client_secret = client_secret - self.confirmation_message = confirmation_message - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = httpx.URL(authorization_endpoint) + self._audience = audience + self._client_id = client_id + self._client_secret = Secret(client_secret) + self._well_known_url = well_known_uri + + @functools.cached_property + def _config_from_oidc_url(self) -> dict[str, Any]: + response: httpx.Response = httpx.get(self._well_known_url) + response.raise_for_status() + return response.json() + + @functools.cached_property + def id_token_signing_alg_values_supported(self) -> list[str]: + return cast( + list[str], + self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), + ) - async def authenticate(self, request) -> UserSessionState: + @functools.cached_property + def issuer(self) -> str: + return cast(str, self._config_from_oidc_url.get("issuer")) + + @functools.cached_property + def jwks_uri(self) -> str: + return cast(str, self._config_from_oidc_url.get("jwks_uri")) + + @functools.cached_property + def token_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("token_endpoint")) + + async def authenticate(self, request: Request) -> UserSessionState: code = request.query_params["code"] # A proxy in the middle may make the request into something like # 'http://localhost:8000/...' so we fix the first part but keep # the original URI path. redirect_uri = f"{get_root_url(request)}{request.url.path}" response = await exchange_code( - self.token_uri, code, self.client_id, self.client_secret, redirect_uri + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, ) response_body = response.json() if response.is_error: @@ -184,11 +184,14 @@ async def authenticate(self, request) -> UserSessionState: response_body = response.json() id_token = response_body["id_token"] access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) + keys = httpx.get(self.jwks_uri) try: verified_body = jwt.decode( - id_token, key, access_token=access_token, audience=self.client_id + token=id_token, + key=keys, + algorithms=self.id_token_signing_alg_values_supported, + audience=self._audience, + access_token=access_token, ) except JWTError: logger.exception( @@ -203,37 +206,6 @@ class KeyNotFoundError(Exception): pass -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token - - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ - - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") - - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError( - f"Token specifies {kid} but we have {[k['kid'] for k in keys]}" - ) - async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): """Method that talks to an IdP to exchange a code for an access_token and/or id_token