From f6bc39d37977e961cf3cc931d3a7739f84e86823 Mon Sep 17 00:00:00 2001 From: dudi levy <4785835+dudil@users.noreply.github.com> Date: Sat, 2 Dec 2023 10:58:15 +0200 Subject: [PATCH] Fix typing annotations in auth.py exception handling in msal_scheme.py - fix and close #15, #30 --- fastapi_msal/auth.py | 12 +++++++----- fastapi_msal/security/msal_scheme.py | 8 ++++++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fastapi_msal/auth.py b/fastapi_msal/auth.py index 9ef7692..403c41d 100644 --- a/fastapi_msal/auth.py +++ b/fastapi_msal/auth.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Annotated, Optional from fastapi import APIRouter, Form, Header from starlette.requests import Request @@ -65,12 +65,14 @@ async def _login_route( redirect_uri = str(request.url_for("_get_token_route")) return await self.handler.authorize_redirect(request=request, redirec_uri=redirect_uri, state=state) - async def _get_token_route(self, request: Request, code: str, state: Optional[str]) -> RedirectResponse: + async def _get_token_route(self, request: Request, code: str, state: OptStr) -> RedirectResponse: await self.handler.authorize_access_token(request=request, code=code, state=state) return RedirectResponse(url=f"{self.return_to_path}", headers=dict(request.headers.items())) - async def _post_token_route(self, request: Request, code: str = Form(...)) -> BearerToken: - token: AuthToken = await self.handler.authorize_access_token(request=request, code=code) + async def _post_token_route( + self, request: Request, code: Annotated[str, Form()], state: Annotated[OptStr, Form()] = None + ) -> BearerToken: + token: AuthToken = await self.handler.authorize_access_token(request=request, code=code, state=state) return BearerToken(access_token=token.id_token) async def _logout_route(self, request: Request, referer: OptStr = Header(None)) -> RedirectResponse: # noqa: B008 @@ -83,7 +85,7 @@ async def get_session_token(self, request: Request) -> Optional[AuthToken]: async def check_authenticated_session(self, request: Request) -> bool: auth_token: Optional[AuthToken] = await self.get_session_token(request) if auth_token and auth_token.id_token: - token_claims = self.handler.parse_id_token(request=request, token=auth_token) + token_claims = await self.handler.parse_id_token(request=request, token=auth_token) if token_claims: return True return False diff --git a/fastapi_msal/security/msal_scheme.py b/fastapi_msal/security/msal_scheme.py index 1dd8957..3e71269 100644 --- a/fastapi_msal/security/msal_scheme.py +++ b/fastapi_msal/security/msal_scheme.py @@ -46,8 +46,12 @@ async def __call__(self, request: Request) -> IDTokenClaims: authorization: Optional[str] = request.headers.get("Authorization") scheme, token = get_authorization_scheme_param(authorization) token_claims: Optional[IDTokenClaims] = None - if authorization and scheme.lower() != "bearer": - token_claims = await self.handler.parse_id_token(request=request, token=token, validate=True) + if authorization and scheme.lower() == "bearer": + try: + token_claims = await self.handler.parse_id_token(request=request, token=token, validate=True) + except RuntimeError as e: + print(e) + raise http_exception from e else: session_token: Optional[AuthToken] = await self.handler.get_token_from_session(request=request) if session_token: