Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add error codes #556

Merged
merged 8 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ clean_infra:
docker compose down --remove-orphans &&\
docker system prune -a --volumes -f

stop_infra:
cd infra &&\
docker compose down --remove-orphans

sync_infra:
python scripts/gh-download.py --repo=supabase/gotrue-js --branch=master --folder=infra

Expand Down
498 changes: 272 additions & 226 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ unasync-cli = "^0.0.9"

[tool.poetry.group.dev.dependencies]
pygithub = ">=1.57,<3.0"
respx = ">=0.20.2,<0.22.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
3 changes: 3 additions & 0 deletions supabase_auth/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS
from ..helpers import handle_exception, model_dump
from ..http_clients import AsyncClient

Expand Down Expand Up @@ -97,6 +98,8 @@ async def _request(
) -> Union[T, None]:
url = f"{self._url}/{path}"
headers = {**self._headers, **(headers or {})}
if API_VERSION_HEADER_NAME not in headers:
headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name")
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json;charset=UTF-8"
if jwt:
Expand Down
3 changes: 3 additions & 0 deletions supabase_auth/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS
from ..helpers import handle_exception, model_dump
from ..http_clients import SyncClient

Expand Down Expand Up @@ -97,6 +98,8 @@ def _request(
) -> Union[T, None]:
url = f"{self._url}/{path}"
headers = {**self._headers, **(headers or {})}
if API_VERSION_HEADER_NAME not in headers:
headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name")
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json;charset=UTF-8"
if jwt:
Expand Down
9 changes: 9 additions & 0 deletions supabase_auth/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import Dict

from .version import __version__
Expand All @@ -12,3 +13,11 @@
MAX_RETRIES = 10
RETRY_INTERVAL = 2 # deciseconds
STORAGE_KEY = "supabase.auth.token"

API_VERSION_HEADER_NAME = "X-Supabase-Api-Version"
API_VERSIONS = {
"2024-01-01": {
"timestamp": datetime.timestamp(datetime.strptime("2024-01-01", "%Y-%m-%d")),
"name": "2024-01-01",
},
}
110 changes: 103 additions & 7 deletions supabase_auth/errors.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,119 @@
from __future__ import annotations

from typing import Union
from typing import List, Literal, Union

from typing_extensions import TypedDict

ErrorCode = Literal[
"unexpected_failure",
"validation_failed",
"bad_json",
"email_exists",
"phone_exists",
"bad_jwt",
"not_admin",
"no_authorization",
"user_not_found",
"session_not_found",
"flow_state_not_found",
"flow_state_expired",
"signup_disabled",
"user_banned",
"provider_email_needs_verification",
"invite_not_found",
"bad_oauth_state",
"bad_oauth_callback",
"oauth_provider_not_supported",
"unexpected_audience",
"single_identity_not_deletable",
"email_conflict_identity_not_deletable",
"identity_already_exists",
"email_provider_disabled",
"phone_provider_disabled",
"too_many_enrolled_mfa_factors",
"mfa_factor_name_conflict",
"mfa_factor_not_found",
"mfa_ip_address_mismatch",
"mfa_challenge_expired",
"mfa_verification_failed",
"mfa_verification_rejected",
"insufficient_aal",
"captcha_failed",
"saml_provider_disabled",
"manual_linking_disabled",
"sms_send_failed",
"email_not_confirmed",
"phone_not_confirmed",
"reauth_nonce_missing",
"saml_relay_state_not_found",
"saml_relay_state_expired",
"saml_idp_not_found",
"saml_assertion_no_user_id",
"saml_assertion_no_email",
"user_already_exists",
"sso_provider_not_found",
"saml_metadata_fetch_failed",
"saml_idp_already_exists",
"sso_domain_already_exists",
"saml_entity_id_mismatch",
"conflict",
"provider_disabled",
"user_sso_managed",
"reauthentication_needed",
"same_password",
"reauthentication_not_valid",
"otp_expired",
"otp_disabled",
"identity_not_found",
"weak_password",
"over_request_rate_limit",
"over_email_send_rate_limit",
"over_sms_send_rate_limit",
"bad_code_verifier",
]


class AuthError(Exception):
def __init__(self, message: str) -> None:
def __init__(self, message: str, code: ErrorCode) -> None:
Exception.__init__(self, message)
self.message = message
self.name = "AuthError"
self.code = code


class AuthApiErrorDict(TypedDict):
name: str
message: str
status: int
code: ErrorCode


class AuthApiError(AuthError):
def __init__(self, message: str, status: int) -> None:
AuthError.__init__(self, message)
def __init__(self, message: str, status: int, code: ErrorCode) -> None:
AuthError.__init__(self, message, code)
self.name = "AuthApiError"
self.status = status
self.code = code

def to_dict(self) -> AuthApiErrorDict:
return {
"name": self.name,
"message": self.message,
"status": self.status,
"code": self.code,
}


class AuthUnknownError(AuthError):
def __init__(self, message: str, original_error: Exception) -> None:
AuthError.__init__(self, message)
AuthError.__init__(self, message, None)
self.name = "AuthUnknownError"
self.original_error = original_error


class CustomAuthError(AuthError):
def __init__(self, message: str, name: str, status: int) -> None:
AuthError.__init__(self, message)
def __init__(self, message: str, name: str, status: int, code: ErrorCode) -> None:
AuthError.__init__(self, message, code)
self.name = name
self.status = status

Expand All @@ -60,6 +132,7 @@ def __init__(self) -> None:
"Auth session missing!",
"AuthSessionMissingError",
400,
None,
)


Expand All @@ -70,6 +143,7 @@ def __init__(self, message: str) -> None:
message,
"AuthInvalidCredentialsError",
400,
None,
)


Expand All @@ -93,6 +167,7 @@ def __init__(
message,
"AuthImplicitGrantRedirectError",
500,
None,
)
self.details = details

Expand All @@ -112,4 +187,25 @@ def __init__(self, message: str, status: int) -> None:
message,
"AuthRetryableError",
status,
None,
)


class AuthWeakPasswordError(CustomAuthError):
def __init__(self, message: str, status: int, reasons: List[str]) -> None:
CustomAuthError.__init__(
self,
message,
"AuthWeakPasswordError",
status,
"weak_password",
)
self.reasons = reasons

def to_dict(self) -> AuthApiErrorDict:
return {
"name": self.name,
"message": self.message,
"status": self.status,
"reasons": self.reasons,
}
83 changes: 79 additions & 4 deletions supabase_auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@

import base64
import hashlib
import re
import secrets
import string
from base64 import urlsafe_b64decode
from datetime import datetime
from json import loads
from typing import Any, Dict, Type, TypeVar, Union, cast

from httpx import HTTPStatusError
from httpx import HTTPStatusError, Response
from pydantic import BaseModel

from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError
from .constants import API_VERSION_HEADER_NAME, API_VERSIONS
from .errors import (
AuthApiError,
AuthError,
AuthRetryableError,
AuthUnknownError,
AuthWeakPasswordError,
)
from .types import (
AuthOtpResponse,
AuthResponse,
Expand Down Expand Up @@ -114,6 +123,10 @@ def get_error_message(error: Any) -> str:
return next((error[prop] for prop in props if filter(prop)), str(error))


def get_error_code(error: Any) -> str:
return error.get("error_code", None) if isinstance(error, dict) else None


def looks_like_http_status_error(exception: Exception) -> bool:
return isinstance(exception, HTTPStatusError)

Expand All @@ -128,8 +141,51 @@ def handle_exception(exception: Exception) -> AuthError:
return AuthRetryableError(
get_error_message(error), error.response.status_code
)
json = error.response.json()
return AuthApiError(get_error_message(json), error.response.status_code or 500)
data = error.response.json()

error_code = None
response_api_version = parse_response_api_version(error.response)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this parse_response_api_version should return Unix timestamp 0 and your code below will magically work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean here? If I put 0 just as with the JS library it will fall through to the second part of the if statement which is the elif. Can you elaborate more so that it's clear what could go wrong here please?


if (
response_api_version
and datetime.timestamp(response_api_version)
>= API_VERSIONS.get("2024-01-01").get("timestamp")
and isinstance(data, dict)
and data
and isinstance(data.get("code"), str)
):
error_code = data.get("code")
elif (
isinstance(data, dict) and data and isinstance(data.get("error_code"), str)
):
error_code = data.get("error_code")

if error_code is None:
if (
isinstance(data, dict)
and data
and isinstance(data.get("weak_password"), dict)
and data.get("weak_password")
and isinstance(data.get("weak_password"), list)
and len(data.get("weak_password"))
):
return AuthWeakPasswordError(
get_error_message(data),
error.response.status_code,
data.get("weak_password").get("reasons"),
)
elif error_code == "weak_password":
return AuthWeakPasswordError(
get_error_message(data),
error.response.status_code,
data.get("weak_password", {}).get("reasons", {}),
)

return AuthApiError(
get_error_message(data),
error.response.status_code or 500,
error_code,
)
except Exception as e:
return AuthUnknownError(get_error_message(error), e)

Expand Down Expand Up @@ -163,3 +219,22 @@ def generate_pkce_challenge(code_verifier):
sha256_hash = hashlib.sha256(verifier_bytes).digest()

return base64.urlsafe_b64encode(sha256_hash).rstrip(b"=").decode("utf-8")


API_VERSION_REGEX = r"^2[0-9]{3}-(0[1-9]|1[0-2])-(0[1-9]|1[0-9]|2[0-9]|3[0-1])$"


def parse_response_api_version(response: Response):
api_version = response.headers.get(API_VERSION_HEADER_NAME)

if not api_version:
return None

if re.search(API_VERSION_REGEX, api_version) is None:
return None

try:
dt = datetime.strptime(api_version, "%Y-%m-%d")
return dt
except Exception as e:
return None
Loading
Loading