Skip to content

Commit

Permalink
Clean up authentication helper functions
Browse files Browse the repository at this point in the history
* Remove unused function and its test
* Refactor existing functions
  • Loading branch information
hmpf authored Oct 9, 2024
1 parent 140d45a commit 537afc4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
3 changes: 3 additions & 0 deletions changelog.d/+oauth2.removed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
As part of refactoring some authentication utility functions the function
`get_psa_authentication_names()` has been removed as it wasn't used anywhere in
Argus proper.
35 changes: 24 additions & 11 deletions src/argus/auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
from django.conf import settings
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.backends import ModelBackend, RemoteUserBackend
from django.utils.module_loading import import_string

from rest_framework.reverse import reverse
from social_core.backends.base import BaseAuth

from social_core.backends.oauth import BaseOAuth2


_all__ = [
"get_authentication_backend_classes",
"has_model_backend",
"has_remote_user_backend",
"get_psa_authentication_backends",
"get_authentication_backend_name_and_type",
]


def get_authentication_backend_classes():
backend_dotted_paths = getattr(settings, "AUTHENTICATION_BACKENDS")
backends = [import_string(path) for path in backend_dotted_paths]
return backends


def get_psa_authentication_names(backends=None):
def has_model_backend(backends):
return ModelBackend in backends


def has_remote_user_backend(backends):
return RemoteUserBackend in backends


def get_psa_authentication_backends(backends=None):
backends = backends if backends else get_authentication_backend_classes()
psa_backends = set()
for backend in backends:
if issubclass(backend, BaseAuth):
psa_backends.add(backend.name)
return sorted(psa_backends)
return [backend for backend in backends if issubclass(backend, BaseOAuth2)]


def get_authentication_backend_name_and_type(request):
# Needed for SPA /login-methods/ API endpoint
backends = get_authentication_backend_classes()
data = []
if ModelBackend in backends:
if has_model_backend(backends):
data.append(
{
"type": "username_password",
Expand All @@ -40,8 +54,7 @@ def get_authentication_backend_name_and_type(request):
"url": reverse("social:begin", kwargs={"backend": backend.name}, request=request),
"name": backend.name,
}
for backend in backends
if issubclass(backend, BaseOAuth2)
for backend in get_psa_authentication_backends(backends)
)

return data
7 changes: 1 addition & 6 deletions tests/auth/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.contrib.auth.backends import ModelBackend
from django.test import TestCase

from argus.auth.utils import get_authentication_backend_name_and_type, get_psa_authentication_names
from argus.auth.utils import get_authentication_backend_name_and_type
from argus.dataporten.social import DataportenFeideOAuth2


Expand Down Expand Up @@ -33,8 +33,3 @@ def test_get_authentication_backend_name_and_type_returns_feide_login(
"name": "dataporten_feide",
}
]

@patch("argus.auth.utils.get_authentication_backend_classes")
def test_get_psa_authentication_names_returns_feide_name(self, mock_get_authentication_backend_classes):
mock_get_authentication_backend_classes.return_value = [DataportenFeideOAuth2]
assert get_psa_authentication_names() == ["dataporten_feide"]

0 comments on commit 537afc4

Please sign in to comment.