diff --git a/CHANGELOG.md b/CHANGELOG.md index 91e14522..683de602 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## UNRELEASED + +- Fix race condition on concurrent logout call + ## 4.1.0 - Expiry format now defaults to whatever is used Django REST framework diff --git a/knox/auth.py b/knox/auth.py index ffb5a287..5c556c72 100644 --- a/knox/auth.py +++ b/knox/auth.py @@ -7,6 +7,7 @@ def compare_digest(a, b): import binascii from django.contrib.auth import get_user_model +from django.db import DatabaseError from django.utils import timezone from django.utils.translation import ugettext_lazy as _ from rest_framework import exceptions @@ -73,7 +74,13 @@ def authenticate_credentials(self, token): raise exceptions.AuthenticationFailed(msg) if compare_digest(digest, auth_token.digest): if knox_settings.AUTO_REFRESH and auth_token.expiry: - self.renew_token(auth_token) + # It may happen that a token gets deleted while we try + # to update its expiry, catch that and consider the token + # invalid + try: + self.renew_token(auth_token) + except DatabaseError: + break return self.validate_user(auth_token) raise exceptions.AuthenticationFailed(msg) diff --git a/tests/tests.py b/tests/tests.py index ffe6f2d8..4fd94323 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -2,9 +2,11 @@ from datetime import datetime, timedelta from django.contrib.auth import get_user_model +from django.db import DatabaseError from django.test import override_settings from django.utils.six.moves import reload_module from freezegun import freeze_time +from rest_framework.exceptions import AuthenticationFailed from rest_framework.serializers import DateTimeField from rest_framework.test import APIRequestFactory, APITestCase as TestCase @@ -15,6 +17,13 @@ from knox.settings import CONSTANTS, knox_settings from knox.signals import token_expired +try: + # Python 3 + from unittest import mock +except ImportError: + # Python 2 + import mock + try: # For django >= 2.0 from django.urls import reverse @@ -396,3 +405,19 @@ def test_expiry_is_present(self): response.data['expiry'], DateTimeField().to_representation(AuthToken.objects.first().expiry) ) + + def test_authenticate_credentials_handles_expiry_update_of_gone_token(self): + """This tests a race condition of an authentication against logout + + It may happen that a token gets deleted while we are inside + authenticate_credentials with the Django ORM raising a DatabaseError + when trying to update the expiry time.""" + + instance, token = AuthToken.objects.create(user=self.user) + with override_settings(REST_KNOX=auto_refresh_knox): + reload_module(auth) + token_auth = TokenAuthentication() + with mock.patch.object(token_auth, 'renew_token') as m: + m.side_effect = DatabaseError() + with self.assertRaises(AuthenticationFailed): + token_auth.authenticate_credentials(token.encode('utf-8')) diff --git a/tox.ini b/tox.ini index f80ab08d..881e382b 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ deps = django22: Django>=2.2,<2.3 django-nose markdown<3.0 + mock isort djangorestframework freezegun