Skip to content

Commit

Permalink
OIDC Authenticator get Algo info from server
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 10, 2025
1 parent cc21998 commit eab3d8d
Showing 1 changed file with 45 additions and 78 deletions.
123 changes: 45 additions & 78 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import re
import secrets
from collections.abc import Iterable
from typing import Any, cast

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwk, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
Expand Down Expand Up @@ -115,67 +117,61 @@ class OIDCAuthenticator:
type: object
additionalProperties: false
properties:
audience:
type: string
client_id:
type: string
client_secret:
type: string
token_uri:
type: string
authorization_endpoint:
well_known_uri:
type: string
public_keys:
type: array
item:
type: object
properties:
- alg:
type: string
- e
type: string
- kid
type: string
- kty
type: string
- n
type: string
- use
type: string
required:
- alg
- e
- kid
- kty
- n
- use
confirmation_message:
type: string
description: May be displayed by client after successful login.
"""

def __init__(
self,
client_id,
client_secret,
public_keys,
token_uri,
authorization_endpoint,
confirmation_message,
audience: str,
client_id: str,
client_secret: str,
well_known_uri: str,
):
self.client_id = client_id
self.client_secret = client_secret
self.confirmation_message = confirmation_message
self.public_keys = public_keys
self.token_uri = token_uri
self.authorization_endpoint = httpx.URL(authorization_endpoint)
self._audience = audience
self._client_id = client_id
self._client_secret = Secret(client_secret)
self._well_known_url = well_known_uri

@functools.cached_property
def _config_from_oidc_url(self) -> dict[str, Any]:
response: httpx.Response = httpx.get(self._well_known_url)
response.raise_for_status()
return response.json()

@functools.cached_property
def token_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("token_endpoint"))

@functools.cached_property
def jwks_uri(self) -> str:
return cast(str, self._config_from_oidc_url.get("jwks_uri"))

@functools.cached_property
def id_token_signing_alg_values_supported(self) -> list[str]:
return cast(
list[str],
self._config_from_oidc_url.get("id_token_signing_alg_values_supported"),
)

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
# the original URI path.
redirect_uri = f"{get_root_url(request)}{request.url.path}"
response = await exchange_code(
self.token_uri, code, self.client_id, self.client_secret, redirect_uri
self.token_endpoint,
code,
self._client_id,
self._client_secret.get_secret_value(),
redirect_uri,
)
response_body = response.json()
if response.is_error:
Expand All @@ -184,11 +180,13 @@ async def authenticate(self, request) -> UserSessionState:
response_body = response.json()
id_token = response_body["id_token"]
access_token = response_body["access_token"]
# Match the kid in id_token to a key in the list of public_keys.
key = find_key(id_token, self.public_keys)
keys = request.get(self.jwks_uri)
try:
verified_body = jwt.decode(
id_token, key, access_token=access_token, audience=self.client_id
access_token,
keys,
algorithms=self.id_token_signing_alg_values_supported,
audience=self._audience,
)
except JWTError:
logger.exception(
Expand All @@ -203,37 +201,6 @@ class KeyNotFoundError(Exception):
pass


def find_key(token, keys):
"""
Find a key from the configured keys based on the kid claim of the token
Parameters
----------
token : token to search for the kid from
keys: list of keys
Raises
------
KeyNotFoundError:
returned if the token does not have a kid claim
Returns
------
key: found key object
"""

unverified = jwt.get_unverified_header(token)
kid = unverified.get("kid")
if not kid:
raise KeyNotFoundError("No 'kid' in token")

for key in keys:
if key["kid"] == kid:
return jwk.construct(key)
return KeyNotFoundError(
f"Token specifies {kid} but we have {[k['kid'] for k in keys]}"
)


async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri):
"""Method that talks to an IdP to exchange a code for an access_token and/or id_token
Expand Down

0 comments on commit eab3d8d

Please sign in to comment.