diff --git a/tests/test_auth.py b/tests/test_auth.py index 58636ee..9bac249 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,15 +2,15 @@ from unittest.mock import Mock, call, patch from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, hmac, serialization -from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes, hmac from django.conf import settings from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation from django.test import RequestFactory, TestCase, override_settings from django.utils.encoding import force_bytes, smart_str -from josepy.b64 import b64encode -from josepy.jwa import ES256 +from jwcrypto.common import base64url_encode +from jwcrypto.jwt import JWT +from jwcrypto.jwk import JWK from mozilla_django_oidc.auth import OIDCAuthenticationBackend, default_username_algo @@ -72,13 +72,10 @@ def test_allowed_unsecured_token(self): header = force_bytes(json.dumps({"alg": "none"})) payload = force_bytes(json.dumps({"foo": "bar"})) signature = "" - token = force_bytes( - "{}.{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)), signature - ) - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) + token_bytes = force_bytes(token) - extracted_payload = self.backend.get_payload_data(token, None) + extracted_payload = self.backend.get_payload_data(token_bytes, None) self.assertEqual(payload, extracted_payload) @override_settings(OIDC_ALLOW_UNSECURED_JWT=False) @@ -87,14 +84,11 @@ def test_disallowed_unsecured_token(self): header = force_bytes(json.dumps({"alg": "none"})) payload = force_bytes(json.dumps({"foo": "bar"})) signature = "" - token = force_bytes( - "{}.{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)), signature - ) - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) + token_bytes = force_bytes(token) with self.assertRaises(SuspiciousOperation): - self.backend.get_payload_data(token, None) + self.backend.get_payload_data(token_bytes, None) @override_settings(OIDC_ALLOW_UNSECURED_JWT=True) def test_allowed_unsecured_valid_token(self): @@ -105,17 +99,11 @@ def test_allowed_unsecured_valid_token(self): # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64url_encode(header), base64url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) token_bytes = force_bytes(token) key_text = smart_str(key) output = self.backend.get_payload_data(token_bytes, key_text) @@ -130,17 +118,11 @@ def test_disallowed_unsecured_valid_token(self): # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64url_encode(header), base64url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) token_bytes = force_bytes(token) key_text = smart_str(key) output = self.backend.get_payload_data(token_bytes, key_text) @@ -156,17 +138,11 @@ def test_allowed_unsecured_invalid_token(self): key = b"mysupersecuretestkey" fake_key = b"mysupersecurefaketestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64url_encode(header), base64url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) token_bytes = force_bytes(token) key_text = smart_str(fake_key) @@ -184,17 +160,11 @@ def test_disallowed_unsecured_invalid_token(self): key = b"mysupersecuretestkey" fake_key = b"mysupersecurefaketestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64url_encode(header), base64url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64url_encode(header), base64url_encode(payload), signature) token_bytes = force_bytes(token) key_text = smart_str(fake_key) @@ -966,14 +936,14 @@ def test_retrieve_matching_jwk(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1012,14 +982,14 @@ def test_retrieve_matching_jwk_same_kid(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1048,14 +1018,14 @@ def test_retrieve_mismatcing_jwk_alg(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1086,14 +1056,14 @@ def test_retrieve_mismatcing_jwk_kid(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1123,14 +1093,14 @@ def test_retrieve_jwk_optional_alg(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1154,14 +1124,14 @@ def test_retrieve_not_existing_jwk(self, mock_requests): key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) + smart_str(base64url_encode(header)), smart_str(base64url_encode(payload)) ) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), + smart_str(base64url_encode(header)), + smart_str(base64url_encode(payload)), smart_str(signature), ) @@ -1234,54 +1204,36 @@ def test_es256_alg_verification(self, mock_requests): self.backend = OIDCAuthenticationBackend() # Generate a private key to create a test token with - private_key = ec.generate_private_key(ec.SECP256R1, default_backend()) - private_key_pem = private_key.private_bytes( - serialization.Encoding.PEM, - serialization.PrivateFormat.PKCS8, - serialization.NoEncryption(), - ) + private_key = JWK.generate(kty="EC", alg="ES256") + public_key = private_key.export_public(as_dict=True) # Make the public key available through the JWKS response - public_numbers = private_key.public_key().public_numbers() get_json_mock = Mock() get_json_mock.json.return_value = { "keys": [ { - "kid": "eckid", + "kid": private_key.thumbprint(), "kty": "EC", "alg": "ES256", "use": "sig", - "x": smart_str(b64encode(public_numbers.x.to_bytes(32, "big"))), - "y": smart_str(b64encode(public_numbers.y.to_bytes(32, "big"))), + "x": smart_str(public_key["x"]), + "y": smart_str(public_key["y"]), "crv": "P-256", } ] } mock_requests.get.return_value = get_json_mock - header = force_bytes( - json.dumps( - { - "typ": "JWT", - "alg": "ES256", - "kid": "eckid", - }, - ) - ) + header = { + "typ": "JWT", + "alg": "ES256", + "kid": private_key.thumbprint(), + } data = {"name": "John Doe", "test": "test_es256_alg_verification"} - h = hmac.HMAC(private_key_pem, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(force_bytes(json.dumps(data)))), - ) - h.update(force_bytes(msg)) - - signature = b64encode(ES256.sign(private_key, force_bytes(msg))) - token = "{}.{}".format( - msg, - smart_str(signature), - ) + jwt = JWT(header, data) + jwt.make_signed_token(private_key) + token = jwt.serialize(compact=True) # Verify the token created with the private key by using the JWKS endpoint, # where the public numbers are.