diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e868121b85d..3fcfe6a63ceb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,15 @@ and this project adheres Fto [Semantic Versioning](http://semver.org/spec/v2.0.0 - All definitions in CCF's public headers are now under the `ccf::` namespace. Any application code which references any of these types directly (notably `StartupConfig`, `http_status`, `LoggerLevel`), they will now need to be prefixed with the `ccf::` namespace. - `cchost` now requires `--config`. +### Changed + +- JWT authentication now supports raw public keys along with certificates (#6601). + - Public key information ('n' and 'e', or 'x', 'y' and 'crv' fields) now have a priority if defined in JWK set, 'x5c' remains as a backup option. + - Has same side-effects as #5809 does please see the changelog entry for that change for more details. In short: + - stale JWKs may be used for JWT validation on older nodes during the upgrade. + - old tables are not cleaned up, #6222 is tracking those. +- A deprecated `GET /gov/jwt_keys/all` has been altered because of #6601, as soon as JWT certificates are no longer stored in CCF. A new "public_key" field has been added, "cert" is now left empty. + ## [6.0.0-dev7] [6.0.0-dev7]: https://github.com/microsoft/CCF/releases/tag/6.0.0-dev7 diff --git a/doc/schemas/gov_openapi.json b/doc/schemas/gov_openapi.json index 90ab5ed30d61..290b503c64b2 100644 --- a/doc/schemas/gov_openapi.json +++ b/doc/schemas/gov_openapi.json @@ -799,6 +799,24 @@ "type": "string" }, "OpenIDJWKMetadata": { + "properties": { + "constraint": { + "$ref": "#/components/schemas/string" + }, + "issuer": { + "$ref": "#/components/schemas/string" + }, + "public_key": { + "$ref": "#/components/schemas/base64string" + } + }, + "required": [ + "issuer", + "public_key" + ], + "type": "object" + }, + "OpenIDJWKMetadataLegacy": { "properties": { "cert": { "$ref": "#/components/schemas/base64string" @@ -811,11 +829,17 @@ } }, "required": [ - "cert", - "issuer" + "issuer", + "cert" ], "type": "object" }, + "OpenIDJWKMetadataLegacy_array": { + "items": { + "$ref": "#/components/schemas/OpenIDJWKMetadataLegacy" + }, + "type": "array" + }, "OpenIDJWKMetadata_array": { "items": { "$ref": "#/components/schemas/OpenIDJWKMetadata" @@ -1228,6 +1252,12 @@ }, "type": "object" }, + "string_to_OpenIDJWKMetadataLegacy_array": { + "additionalProperties": { + "$ref": "#/components/schemas/OpenIDJWKMetadataLegacy_array" + }, + "type": "object" + }, "string_to_OpenIDJWKMetadata_array": { "additionalProperties": { "$ref": "#/components/schemas/OpenIDJWKMetadata_array" @@ -1752,6 +1782,31 @@ "get": { "deprecated": true, "operationId": "GetGovKvJwtPublicSigningKeysMetadata", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/string_to_OpenIDJWKMetadataLegacy_array" + } + } + }, + "description": "Default response description" + }, + "default": { + "$ref": "#/components/responses/default" + } + }, + "summary": "This route is auto-generated from the KV schema.", + "x-ccf-forwarding": { + "$ref": "#/components/x-ccf-forwarding/sometimes" + } + } + }, + "/gov/kv/jwt/public_signing_keys_metadata_v2": { + "get": { + "deprecated": true, + "operationId": "GetGovKvJwtPublicSigningKeysMetadataV2", "responses": { "200": { "content": { diff --git a/include/ccf/crypto/ecdsa.h b/include/ccf/crypto/ecdsa.h index e61b3ad24b74..0d6161117e03 100644 --- a/include/ccf/crypto/ecdsa.h +++ b/include/ccf/crypto/ecdsa.h @@ -4,6 +4,7 @@ #include "ccf/crypto/curve.h" +#include #include namespace ccf::crypto @@ -28,7 +29,7 @@ namespace ccf::crypto * @param signature The signature in IEEE P1363 encoding */ std::vector ecdsa_sig_p1363_to_der( - const std::vector& signature); + std::span signature); std::vector ecdsa_sig_der_to_p1363( const std::vector& signature, CurveID curveId); diff --git a/include/ccf/crypto/jwk.h b/include/ccf/crypto/jwk.h index 1b4886cb1a22..ae0f1f5b9ab0 100644 --- a/include/ccf/crypto/jwk.h +++ b/include/ccf/crypto/jwk.h @@ -27,13 +27,12 @@ namespace ccf::crypto JsonWebKeyType kty; std::optional kid = std::nullopt; std::optional> x5c = std::nullopt; - std::optional issuer = std::nullopt; bool operator==(const JsonWebKey&) const = default; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKey); DECLARE_JSON_REQUIRED_FIELDS(JsonWebKey, kty); - DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c, issuer); + DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c); enum class JsonWebKeyECCurve { @@ -47,6 +46,25 @@ namespace ccf::crypto {JsonWebKeyECCurve::P384, "P-384"}, {JsonWebKeyECCurve::P521, "P-521"}}); + struct JsonWebKeyData + { + JsonWebKeyType kty; + std::optional kid = std::nullopt; + std::optional> x5c = std::nullopt; + std::optional n = std::nullopt; + std::optional e = std::nullopt; + std::optional x = std::nullopt; + std::optional y = std::nullopt; + std::optional crv = std::nullopt; + std::optional issuer = std::nullopt; + + bool operator==(const JsonWebKeyData&) const = default; + }; + DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKeyData); + DECLARE_JSON_REQUIRED_FIELDS(JsonWebKeyData, kty); + DECLARE_JSON_OPTIONAL_FIELDS( + JsonWebKeyData, kid, x5c, n, e, x, y, crv, issuer); + static JsonWebKeyECCurve curve_id_to_jwk_curve(CurveID curve_id) { switch (curve_id) diff --git a/include/ccf/crypto/rsa_public_key.h b/include/ccf/crypto/rsa_public_key.h index cd62eba0e7f4..1fcd81dc6d43 100644 --- a/include/ccf/crypto/rsa_public_key.h +++ b/include/ccf/crypto/rsa_public_key.h @@ -84,6 +84,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_legth = 0) = 0; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) = 0; + struct Components { std::vector n; diff --git a/include/ccf/endpoints/authentication/jwt_auth.h b/include/ccf/endpoints/authentication/jwt_auth.h index 3a44ee55a73a..70d4f3f7c813 100644 --- a/include/ccf/endpoints/authentication/jwt_auth.h +++ b/include/ccf/endpoints/authentication/jwt_auth.h @@ -17,7 +17,7 @@ namespace ccf nlohmann::json payload; }; - struct VerifiersCache; + struct PublicKeysCache; bool validate_issuer( const std::string& iss, @@ -28,7 +28,7 @@ namespace ccf { protected: static const OpenAPISecuritySchema security_schema; - std::unique_ptr verifiers; + std::unique_ptr keys_cache; public: static constexpr auto SECURITY_SCHEME_NAME = "jwt"; diff --git a/include/ccf/service/tables/jwt.h b/include/ccf/service/tables/jwt.h index 23ebe5268499..8b21448bf58a 100644 --- a/include/ccf/service/tables/jwt.h +++ b/include/ccf/service/tables/jwt.h @@ -37,27 +37,42 @@ namespace ccf using JwtIssuer = std::string; using JwtKeyId = std::string; using Cert = std::vector; + using PublicKey = std::vector; struct OpenIDJWKMetadata { - Cert cert; + PublicKey public_key; JwtIssuer issuer; std::optional constraint; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(OpenIDJWKMetadata); - DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, cert, issuer); + DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, issuer, public_key); DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadata, constraint); - using JwtIssuers = ServiceMap; - using JwtPublicSigningKeys = + using JwtPublicSigningKeysMetadata = ServiceMap>; + struct OpenIDJWKMetadataLegacy + { + Cert cert; + JwtIssuer issuer; + std::optional constraint; + }; + DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(OpenIDJWKMetadataLegacy); + DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadataLegacy, issuer, cert); + DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadataLegacy, constraint); + + using JwtPublicSigningKeysMetadataLegacy = + ServiceMap>; + + using JwtIssuers = ServiceMap; + namespace Tables { static constexpr auto JWT_ISSUERS = "public:ccf.gov.jwt.issuers"; static constexpr auto JWT_PUBLIC_SIGNING_KEYS_METADATA = - "public:ccf.gov.jwt.public_signing_keys_metadata"; + "public:ccf.gov.jwt.public_signing_keys_metadata_v2"; namespace Legacy { @@ -65,6 +80,8 @@ namespace ccf "public:ccf.gov.jwt.public_signing_key"; static constexpr auto JWT_PUBLIC_SIGNING_KEY_ISSUER = "public:ccf.gov.jwt.public_signing_key_issuer"; + static constexpr auto JWT_PUBLIC_SIGNING_KEYS_METADATA = + "public:ccf.gov.jwt.public_signing_keys_metadata"; using JwtPublicSigningKeys = ccf::kv::RawCopySerialisedMap; @@ -75,7 +92,7 @@ namespace ccf struct JsonWebKeySet { - std::vector keys; + std::vector keys; bool operator!=(const JsonWebKeySet& rhs) const { diff --git a/samples/constitutions/default/actions.js b/samples/constitutions/default/actions.js index 654ecc065326..98f0ed2f7c0d 100644 --- a/samples/constitutions/default/actions.js +++ b/samples/constitutions/default/actions.js @@ -130,15 +130,28 @@ function checkJwks(value, field) { for (const [i, jwk] of value.keys.entries()) { checkType(jwk.kid, "string", `${field}.keys[${i}].kid`); checkType(jwk.kty, "string", `${field}.keys[${i}].kty`); - checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); - checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); - for (const [j, b64der] of jwk.x5c.entries()) { - checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); - const pem = - "-----BEGIN CERTIFICATE-----\n" + - b64der + - "\n-----END CERTIFICATE-----"; - checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + if (jwk.x5c) { + checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); + checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); + for (const [j, b64der] of jwk.x5c.entries()) { + checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); + const pem = + "-----BEGIN CERTIFICATE-----\n" + + b64der + + "\n-----END CERTIFICATE-----"; + checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + } + } else if (jwk.n && jwk.e) { + checkType(jwk.n, "string", `${field}.keys[${i}].n`); + checkType(jwk.e, "string", `${field}.keys[${i}].e`); + } else if (jwk.x && jwk.y) { + checkType(jwk.x, "string", `${field}.keys[${i}].x`); + checkType(jwk.y, "string", `${field}.keys[${i}].y`); + checkType(jwk.crv, "string", `${field}.keys[${i}].crv`); + } else { + throw new Error( + "JWK must contain either x5c, or n/e for RSA key type, or x/y/crv for EC key type", + ); } } } diff --git a/src/crypto/ecdsa.cpp b/src/crypto/ecdsa.cpp index 7ad640631b45..44e6fb5d9967 100644 --- a/src/crypto/ecdsa.cpp +++ b/src/crypto/ecdsa.cpp @@ -45,7 +45,7 @@ namespace ccf::crypto } std::vector ecdsa_sig_p1363_to_der( - const std::vector& signature) + std::span signature) { auto half_size = signature.size() / 2; return ecdsa_sig_from_r_s( diff --git a/src/crypto/openssl/rsa_public_key.cpp b/src/crypto/openssl/rsa_public_key.cpp index b8fb2f61be58..86d44d6418b9 100644 --- a/src/crypto/openssl/rsa_public_key.cpp +++ b/src/crypto/openssl/rsa_public_key.cpp @@ -54,6 +54,17 @@ namespace ccf::crypto auto msg = OpenSSL::error_string(ec); throw std::runtime_error(fmt::format("OpenSSL error: {}", msg)); } + +// As it's a common pattern to rely on successful key wrapper construction as a +// confirmation of a concrete key type, this must fail for non-RSA keys. +#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 + if (!key || EVP_PKEY_get_base_id(key) != EVP_PKEY_RSA) +#else + if (!key || !EVP_PKEY_get0_RSA(key)) +#endif + { + throw std::logic_error("invalid RSA key"); + } } std::pair get_modulus_and_exponent( @@ -208,6 +219,22 @@ namespace ccf::crypto pctx, signature, signature_size, hash.data(), hash.size()) == 1; } + bool RSAPublicKey_OpenSSL::verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type) + { + auto hash = OpenSSLHashProvider().Hash(contents, contents_size, md_type); + Unique_EVP_PKEY_CTX pctx(key); + CHECK1(EVP_PKEY_verify_init(pctx)); + CHECK1(EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PADDING)); + CHECK1(EVP_PKEY_CTX_set_signature_md(pctx, get_md_type(md_type))); + return EVP_PKEY_verify( + pctx, signature, signature_size, hash.data(), hash.size()) == 1; + } + std::vector RSAPublicKey_OpenSSL::bn_bytes(const BIGNUM* bn) { std::vector r(BN_num_bytes(bn)); diff --git a/src/crypto/openssl/rsa_public_key.h b/src/crypto/openssl/rsa_public_key.h index 061ba053ad80..abe43fcf758a 100644 --- a/src/crypto/openssl/rsa_public_key.h +++ b/src/crypto/openssl/rsa_public_key.h @@ -55,6 +55,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_length = 0) override; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) override; + virtual Components components() const override; static std::vector bn_bytes(const BIGNUM* bn); diff --git a/src/endpoints/authentication/jwt_auth.cpp b/src/endpoints/authentication/jwt_auth.cpp index 05ceb862ff2d..91875aa95525 100644 --- a/src/endpoints/authentication/jwt_auth.cpp +++ b/src/endpoints/authentication/jwt_auth.cpp @@ -3,6 +3,9 @@ #include "ccf/endpoints/authentication/jwt_auth.h" +#include "ccf/crypto/ecdsa.h" +#include "ccf/crypto/public_key.h" +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/ds/nonstd.h" #include "ccf/pal/locking.h" #include "ccf/rpc_context.h" @@ -82,34 +85,77 @@ namespace ccf return tenant_id && tid && *tid == *tenant_id; } - struct VerifiersCache + struct PublicKeysCache { - static constexpr size_t DEFAULT_MAX_VERIFIERS = 10; + static constexpr size_t DEFAULT_MAX_KEYS = 10; using DER = std::vector; - ccf::pal::Mutex verifiers_lock; - LRU verifiers; + using KeyVariant = + std::variant; + ccf::pal::Mutex keys_lock; + LRU keys; - VerifiersCache(size_t max_verifiers = DEFAULT_MAX_VERIFIERS) : - verifiers(max_verifiers) - {} + PublicKeysCache(size_t max_keys = DEFAULT_MAX_KEYS) : keys(max_keys) {} - ccf::crypto::VerifierPtr get_verifier(const DER& der) + bool verify( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + const DER& der) { - std::lock_guard guard(verifiers_lock); + std::lock_guard guard(keys_lock); - auto it = verifiers.find(der); - if (it == verifiers.end()) + auto it = keys.find(der); + if (it == keys.end()) { - it = verifiers.insert(der, ccf::crypto::make_unique_verifier(der)); + try + { + it = keys.insert(der, ccf::crypto::make_rsa_public_key(der)); + } + catch (const std::exception&) + { + it = keys.insert(der, ccf::crypto::make_public_key(der)); + } } - return it->second; + const auto& key = it->second; + if (std::holds_alternative(key)) + { + LOG_DEBUG_FMT("Verify der: {} as RSA key", der); + + // Obsolete PKCS1 padding is chosen for JWT, as explained in details in + // https://github.com/microsoft/CCF/issues/6601#issuecomment-2512059875. + return std::get(key)->verify_pkcs1( + contents, + contents_size, + signature, + signature_size, + ccf::crypto::MDType::SHA256); + } + else if (std::holds_alternative(key)) + { + LOG_DEBUG_FMT("Verify der: {} as EC key", der); + + const auto sig_der = + ccf::crypto::ecdsa_sig_p1363_to_der({signature, signature_size}); + return std::get(key)->verify( + contents, + contents_size, + sig_der.data(), + sig_der.size(), + ccf::crypto::MDType::SHA256); + } + else + { + LOG_DEBUG_FMT("Key not found for der: {}", der); + return false; + } } }; JwtAuthnPolicy::JwtAuthnPolicy() : - verifiers(std::make_unique()) + keys_cache(std::make_unique()) {} JwtAuthnPolicy::~JwtAuthnPolicy() = default; @@ -129,11 +175,42 @@ namespace ccf } auto& token = token_opt.value(); - auto keys = tx.ro( + auto keys = tx.ro( ccf::Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); const auto key_id = token.header_typed.kid; auto token_keys = keys->get(key_id); + // For metadata KID->(cert,issuer,constraint). + // + // Note, that Legacy keys are stored as certs, new approach is raw keys, so + // conversion from cert to raw key is needed. + if (!token_keys) + { + auto fallback_certs = tx.ro( + ccf::Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto fallback_data = fallback_certs->get(key_id); + if (fallback_data) + { + auto new_keys = std::vector(); + for (const auto& metadata : *fallback_data) + { + auto verifier = ccf::crypto::make_unique_verifier(metadata.cert); + new_keys.push_back(OpenIDJWKMetadata{ + .public_key = verifier->public_key_der(), + .issuer = metadata.issuer, + .constraint = metadata.constraint}); + } + if (!new_keys.empty()) + { + token_keys = new_keys; + } + } + } + + // For metadata as two separate tables, KID->JwtIssuer and KID->Cert. + // + // Note, that Legacy keys are stored as certs, new approach is raw keys, so + // conversion from certs to keys is needed. if (!token_keys) { auto fallback_keys = tx.ro( @@ -141,11 +218,12 @@ namespace ccf auto fallback_issuers = tx.ro( ccf::Tables::Legacy::JWT_PUBLIC_SIGNING_KEY_ISSUER); - auto fallback_key = fallback_keys->get(key_id); - if (fallback_key) + auto fallback_cert = fallback_keys->get(key_id); + if (fallback_cert) { + auto verifier = ccf::crypto::make_unique_verifier(*fallback_cert); token_keys = std::vector{OpenIDJWKMetadata{ - .cert = *fallback_key, + .public_key = verifier->public_key_der(), .issuer = *fallback_issuers->get(key_id), .constraint = std::nullopt}}; } @@ -160,8 +238,12 @@ namespace ccf for (const auto& metadata : *token_keys) { - auto verifier = verifiers->get_verifier(metadata.cert); - if (!::http::JwtVerifier::validate_token_signature(token, verifier)) + if (!keys_cache->verify( + (uint8_t*)token.signed_content.data(), + token.signed_content.size(), + token.signature.data(), + token.signature.size(), + metadata.public_key)) { error_reason = "Signature verification failed"; continue; @@ -171,7 +253,7 @@ namespace ccf const size_t time_now = std::chrono::duration_cast( ccf::get_enclave_time()) .count(); - if (time_now < token.payload_typed.nbf) + if (token.payload_typed.nbf && time_now < *token.payload_typed.nbf) { error_reason = fmt::format( "Current time {} is before token's Not Before (nbf) claim {}", diff --git a/src/http/http_jwt.h b/src/http/http_jwt.h index aecf1c71074e..09d688400cb7 100644 --- a/src/http/http_jwt.h +++ b/src/http/http_jwt.h @@ -16,9 +16,13 @@ namespace http { enum class JwtCryptoAlgorithm { - RS256 + RS256, + ES256, }; - DECLARE_JSON_ENUM(JwtCryptoAlgorithm, {{JwtCryptoAlgorithm::RS256, "RS256"}}); + DECLARE_JSON_ENUM( + JwtCryptoAlgorithm, + {{JwtCryptoAlgorithm::RS256, "RS256"}, + {JwtCryptoAlgorithm::ES256, "ES256"}}); struct JwtHeader { @@ -30,14 +34,14 @@ namespace http struct JwtPayload { - size_t nbf; size_t exp; std::string iss; + std::optional nbf; std::optional tid; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JwtPayload) - DECLARE_JSON_REQUIRED_FIELDS(JwtPayload, nbf, exp, iss); - DECLARE_JSON_OPTIONAL_FIELDS(JwtPayload, tid) + DECLARE_JSON_REQUIRED_FIELDS(JwtPayload, exp, iss); + DECLARE_JSON_OPTIONAL_FIELDS(JwtPayload, nbf, tid); class JwtVerifier { diff --git a/src/js/extensions/ccf/crypto.cpp b/src/js/extensions/ccf/crypto.cpp index 25f2b2ba6d85..9fa54a47ede9 100644 --- a/src/js/extensions/ccf/crypto.cpp +++ b/src/js/extensions/ccf/crypto.cpp @@ -1023,10 +1023,10 @@ namespace ccf::js::extensions } std::vector sig(signature, signature + signature_size); - if (algo_name == "ECDSA") { - sig = ccf::crypto::ecdsa_sig_p1363_to_der(sig); + sig = + ccf::crypto::ecdsa_sig_p1363_to_der({signature, signature_size}); } auto is_cert = key.starts_with("-----BEGIN CERTIFICATE"); diff --git a/src/node/gov/handlers/service_state.h b/src/node/gov/handlers/service_state.h index eabd6d65ea3a..6992fabb3f1c 100644 --- a/src/node/gov/handlers/service_state.h +++ b/src/node/gov/handlers/service_state.h @@ -600,7 +600,7 @@ namespace ccf::gov::endpoints auto keys = nlohmann::json::object(); auto jwt_keys_handle = - ctx.tx.template ro( + ctx.tx.template ro( ccf::Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); jwt_keys_handle->foreach( @@ -612,11 +612,10 @@ namespace ccf::gov::endpoints { auto info = nlohmann::json::object(); - // cert is stored as DER - convert to PEM for API - const auto cert_pem = - ccf::crypto::cert_der_to_pem(metadata.cert); - info["certificate"] = cert_pem.str(); - + info["publicKey"] = + ccf::crypto::make_rsa_public_key(metadata.public_key) + ->public_key_pem() + .str(); info["issuer"] = metadata.issuer; info["constraint"] = metadata.constraint; diff --git a/src/node/rpc/jwt_management.h b/src/node/rpc/jwt_management.h index af7c011ac8ef..b50b0d031743 100644 --- a/src/node/rpc/jwt_management.h +++ b/src/node/rpc/jwt_management.h @@ -2,6 +2,7 @@ // Licensed under the Apache 2.0 License. #pragma once +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/crypto/verifier.h" #include "ccf/ds/hex.h" #include "ccf/service/tables/jwt.h" @@ -12,13 +13,120 @@ #include #include +namespace +{ + std::vector try_parse_raw_rsa(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.e || jwk.e->empty() || !jwk.n || jwk.n->empty()) + { + return {}; + } + + std::vector der; + ccf::crypto::JsonWebKeyRSAPublic data; + data.kty = ccf::crypto::JsonWebKeyType::RSA; + data.kid = jwk.kid.value(); + data.n = jwk.n.value(); + data.e = jwk.e.value(); + try + { + const auto pubkey = ccf::crypto::make_rsa_public_key(data); + return pubkey->public_key_der(); + } + catch (const std::invalid_argument& exc) + { + throw std::logic_error( + fmt::format("Failed to construct RSA public key: {}", exc.what())); + } + } + + std::vector try_parse_raw_ec(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.x || jwk.x->empty() || !jwk.y || jwk.y->empty() || !jwk.crv) + { + return {}; + } + + ccf::crypto::JsonWebKeyECPublic data; + data.kty = ccf::crypto::JsonWebKeyType::EC; + data.kid = jwk.kid.value(); + data.crv = jwk.crv.value(); + data.x = jwk.x.value(); + data.y = jwk.y.value(); + try + { + const auto pubkey = ccf::crypto::make_public_key(data); + return pubkey->public_key_der(); + } + catch (const std::invalid_argument& exc) + { + throw std::logic_error( + fmt::format("Failed to construct EC public key: {}", exc.what())); + } + } + + std::vector try_parse_x5c(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.x5c || jwk.x5c->empty()) + { + return {}; + } + + const auto& kid = jwk.kid.value(); + auto& der_base64 = jwk.x5c.value()[0]; + ccf::Cert der; + try + { + der = ccf::crypto::raw_from_b64(der_base64); + } + catch (const std::invalid_argument& e) + { + throw std::logic_error( + fmt::format("Could not parse x5c of key id {}: {}", kid, e.what())); + } + try + { + auto verifier = ccf::crypto::make_unique_verifier(der); + return verifier->public_key_der(); + } + catch (std::invalid_argument& exc) + { + throw std::logic_error(fmt::format( + "JWKS kid {} has an invalid X.509 certificate: {}", kid, exc.what())); + } + } + + std::vector try_parse_jwk(const ccf::crypto::JsonWebKeyData& jwk) + { + const auto& kid = jwk.kid.value(); + auto key = try_parse_raw_rsa(jwk); + if (!key.empty()) + { + return key; + } + key = try_parse_raw_ec(jwk); + if (!key.empty()) + { + return key; + } + key = try_parse_x5c(jwk); + if (!key.empty()) + { + return key; + } + + throw std::logic_error( + fmt::format("JWKS kid {} has neither RSA/EC public key or x5c", kid)); + } +} + namespace ccf { static void legacy_remove_jwt_public_signing_keys( ccf::kv::Tx& tx, std::string issuer) { - auto keys = - tx.rw(Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS); + auto keys = tx.rw( + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS); auto key_issuer = tx.rw( Tables::Legacy::JWT_PUBLIC_SIGNING_KEY_ISSUER); @@ -31,14 +139,38 @@ namespace ccf } return true; }); + + auto metadata = tx.rw( + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA); + metadata->foreach([&issuer, &metadata](const auto& k, const auto& v) { + std::vector updated; + for (const auto& key : v) + { + if (key.issuer != issuer) + { + updated.push_back(key); + } + } + + if (updated.empty()) + { + metadata->remove(k); + } + else if (updated.size() < v.size()) + { + metadata->put(k, updated); + } + + return true; + }); } static bool check_issuer_constraint( const std::string& issuer, const std::string& constraint) { // Only accept key constraints for the same (sub)domain. This is to avoid - // setting keys from issuer A which will be used to validate iss claims for - // issuer B, so this doesn't make sense (at least for now). + // setting keys from issuer A which will be used to validate iss claims + // for issuer B, so this doesn't make sense (at least for now). const auto issuer_domain = ::http::parse_url_full(issuer).host; const auto constraint_domain = ::http::parse_url_full(constraint).host; @@ -48,13 +180,13 @@ namespace ccf return false; } - // Either constraint's domain == issuer's domain or it is a subdomain, e.g.: - // limited.facebook.com + // Either constraint's domain == issuer's domain or it is a subdomain, + // e.g.: limited.facebook.com // .facebook.com // // It may make sense to support vice-versa too, but we haven't found any - // instances of that so far, so leaveing it only-way only for facebook-like - // cases. + // instances of that so far, so leaving it only-way only for + // facebook-like cases. if (issuer_domain != constraint_domain) { const auto pattern = "." + constraint_domain; @@ -68,12 +200,12 @@ namespace ccf ccf::kv::Tx& tx, std::string issuer) { // Unlike resetting JWT keys for a particular issuer, removing keys can be - // safely done on both table revisions, as soon as the application shouldn't - // use them anyway after being ask about that explicitly. + // safely done on both table revisions, as soon as the application + // shouldn't use them anyway after being ask about that explicitly. legacy_remove_jwt_public_signing_keys(tx, issuer); - auto keys = - tx.rw(Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto keys = tx.rw( + Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); keys->foreach([&issuer, &keys](const auto& k, const auto& v) { auto it = find_if(v.begin(), v.end(), [&](const auto& metadata) { @@ -105,82 +237,53 @@ namespace ccf const JwtIssuerMetadata& issuer_metadata, const JsonWebKeySet& jwks) { - auto keys = - tx.rw(Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto keys = tx.rw( + Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); // add keys if (jwks.keys.empty()) { LOG_FAIL_FMT("{}: JWKS has no keys", log_prefix); return false; } - std::map> new_keys; + std::map new_keys; std::map issuer_constraints; - for (auto& jwk : jwks.keys) - { - if (!jwk.kid.has_value()) - { - LOG_FAIL_FMT("No kid for JWT signing key"); - return false; - } - - if (!jwk.x5c.has_value() && jwk.x5c->empty()) - { - LOG_FAIL_FMT("{}: JWKS is invalid (empty x5c)", log_prefix); - return false; - } - - auto& der_base64 = jwk.x5c.value()[0]; - ccf::Cert der; - auto const& kid = jwk.kid.value(); - try - { - der = ccf::crypto::raw_from_b64(der_base64); - } - catch (const std::invalid_argument& e) - { - LOG_FAIL_FMT( - "{}: Could not parse x5c of key id {}: {}", - log_prefix, - kid, - e.what()); - return false; - } - try - { - ccf::crypto::make_unique_verifier( - (std::vector)der); // throws on error - } - catch (std::invalid_argument& exc) + try + { + for (auto& jwk : jwks.keys) { - LOG_FAIL_FMT( - "{}: JWKS kid {} has an invalid X.509 certificate: {}", - log_prefix, - kid, - exc.what()); - return false; - } + if (!jwk.kid.has_value()) + { + throw std::logic_error("Missing kid for JWT signing key"); + } - LOG_INFO_FMT("{}: Storing JWT signing key with kid {}", log_prefix, kid); - new_keys.emplace(kid, der); + const auto& kid = jwk.kid.value(); + auto key_der = try_parse_jwk(jwk); - if (jwk.issuer) - { - if (!check_issuer_constraint(issuer, *jwk.issuer)) + if (jwk.issuer) { - LOG_FAIL_FMT( - "{}: JWKS kid {} with issuer constraint {} fails validation " - "against issuer {}", - log_prefix, - kid, - *jwk.issuer, - issuer); - return false; + if (!check_issuer_constraint(issuer, *jwk.issuer)) + { + throw std::logic_error(fmt::format( + "JWKS kid {} with issuer constraint {} fails validation " + "against " + "issuer {}", + kid, + *jwk.issuer, + issuer)); + } + + issuer_constraints.emplace(kid, *jwk.issuer); } - issuer_constraints.emplace(kid, *jwk.issuer); + new_keys.emplace(kid, key_der); } } + catch (const std::exception& exc) + { + LOG_FAIL_FMT("{}: {}", log_prefix, exc.what()); + return false; + } if (new_keys.empty()) { @@ -203,7 +306,10 @@ namespace ccf for (auto& [kid, der] : new_keys) { - OpenIDJWKMetadata value{der, issuer, std::nullopt}; + OpenIDJWKMetadata value{ + .public_key = der, .issuer = issuer, .constraint = std::nullopt}; + value.public_key = der; + const auto it = issuer_constraints.find(kid); if (it != issuer_constraints.end()) { @@ -218,7 +324,7 @@ namespace ccf keys_for_kid->begin(), keys_for_kid->end(), [&value](const auto& metadata) { - return metadata.cert == value.cert && + return metadata.public_key == value.public_key && metadata.issuer == value.issuer && metadata.constraint == value.constraint; }) != keys_for_kid->end()) diff --git a/src/node/rpc/member_frontend.h b/src/node/rpc/member_frontend.h index 05086a79eb71..c4f4a5da3a70 100644 --- a/src/node/rpc/member_frontend.h +++ b/src/node/rpc/member_frontend.h @@ -71,6 +71,7 @@ namespace ccf { JwtIssuer issuer; ccf::crypto::Pem cert; + std::string public_key; }; DECLARE_JSON_TYPE(KeyIdInfo) DECLARE_JSON_REQUIRED_FIELDS(KeyIdInfo, issuer, cert) @@ -1108,7 +1109,9 @@ namespace ccf for (const auto& metadata : v) { info.push_back(KeyIdInfo{ - metadata.issuer, ccf::crypto::cert_der_to_pem(metadata.cert)}); + .issuer = metadata.issuer, + .cert = ccf::crypto::Pem(), + .public_key = ccf::crypto::b64_from_raw(metadata.public_key)}); } kmap.emplace(k, std::move(info)); return true; diff --git a/src/service/network_tables.h b/src/service/network_tables.h index 888621e34e07..9a1df3a7348e 100644 --- a/src/service/network_tables.h +++ b/src/service/network_tables.h @@ -154,8 +154,11 @@ namespace ccf // const CACertBundlePEMs ca_cert_bundles = {Tables::CA_CERT_BUNDLE_PEMS}; const JwtIssuers jwt_issuers = {Tables::JWT_ISSUERS}; - const JwtPublicSigningKeys jwt_public_signing_keys_metadata = { + const JwtPublicSigningKeysMetadata jwt_public_signing_keys_metadata = { Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA}; + const JwtPublicSigningKeysMetadataLegacy + legacy_jwt_public_signing_keys_metadata = { + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA}; const Tables::Legacy::JwtPublicSigningKeys legacy_jwt_public_signing_keys = {Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS}; const Tables::Legacy::JwtPublicSigningKeyIssuer @@ -168,6 +171,7 @@ namespace ccf ca_cert_bundles, jwt_issuers, jwt_public_signing_keys_metadata, + legacy_jwt_public_signing_keys_metadata, legacy_jwt_public_signing_keys, legacy_jwt_public_signing_key_issuer); } diff --git a/tests/infra/crypto.py b/tests/infra/crypto.py index 23fbc8039fe3..2947ff2567a9 100644 --- a/tests/infra/crypto.py +++ b/tests/infra/crypto.py @@ -307,10 +307,8 @@ def pub_key_pem_to_der(pem: str) -> bytes: return cert.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo) -def create_jwt(body_claims: dict, key_priv_pem: str, key_id: str) -> str: - return jwt.encode( - body_claims, key_priv_pem, algorithm="RS256", headers={"kid": key_id} - ) +def create_jwt(body_claims: dict, key_priv_pem: str, key_id: str, alg="RS256") -> str: + return jwt.encode(body_claims, key_priv_pem, algorithm=alg, headers={"kid": key_id}) def cert_pem_to_der(pem: str) -> bytes: diff --git a/tests/infra/jwt_issuer.py b/tests/infra/jwt_issuer.py index 1882e57b02f5..93abddbd7020 100644 --- a/tests/infra/jwt_issuer.py +++ b/tests/infra/jwt_issuer.py @@ -11,8 +11,23 @@ import json import time import uuid + from infra.log_capture import flush_info from loguru import logger as LOG +from enum import Enum +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec + + +class JwtAlg(Enum): + RS256 = "RS256" # RSA using SHA-256 + ES256 = "ES256" # ECDSA using P-256 and SHA-256 + + +class JwtAuthType(Enum): + CERT = 1 + KEY = 2 def make_bearer_header(jwt): @@ -107,17 +122,50 @@ def __exit__(self, exc_type, exc_value, traceback): self.stop() +def get_jwt_issuers(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["issuers"] + + +def get_jwt_keys(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["keys"] + + +def to_b64(number: int): + as_bytes = number.to_bytes((number.bit_length() + 7) // 8, "big") + return base64.b64encode(as_bytes).decode("ascii") + + class JwtIssuer: TEST_JWT_ISSUER_NAME = "https://example.issuer" TEST_CA_BUNDLE_NAME = "test_ca_bundle_name" - def _generate_cert(self, cn=None): - key_priv, key_pub = infra.crypto.generate_rsa_keypair(2048) + def _generate_auth_data(self, cn=None): + if self._alg == JwtAlg.RS256: + key_priv, key_pub = infra.crypto.generate_rsa_keypair(2048) + elif self._alg == JwtAlg.ES256: + key_priv, key_pub = infra.crypto.generate_ec_keypair(ec.SECP256R1) + else: + raise ValueError(f"Unsupported algorithm: {self._alg}") + cert = infra.crypto.generate_cert(key_priv, cn=cn) return (key_priv, key_pub), cert def __init__( - self, name=TEST_JWT_ISSUER_NAME, cert=None, refresh_interval=3, cn=None + self, + name=TEST_JWT_ISSUER_NAME, + cert=None, + refresh_interval=3, + cn=None, + auth_type=JwtAuthType.CERT, + alg=JwtAlg.RS256, ): self.name = name self.default_kid = f"{uuid.uuid4()}" @@ -126,7 +174,9 @@ def __init__( # Auto-refresh ON if issuer name starts with "https://" self.auto_refresh = self.name.startswith("https://") stripped_host = self.name[len("https://") :] if self.auto_refresh else None - (self.tls_priv, _), self.tls_cert = self._generate_cert( + self._auth_type = auth_type + self._alg = alg + (self.tls_priv, _), self.tls_cert = self._generate_auth_data( cn or stripped_host or name ) if not cert: @@ -134,6 +184,11 @@ def __init__( else: self.cert_pem = cert + @property + def public_key(self): + cert = load_pem_x509_certificate(self.cert_pem.encode(), default_backend()) + return cert.public_key() + @property def issuer_url(self): name = f"{self.name}" @@ -141,25 +196,53 @@ def issuer_url(self): name += f":{self.server.bind_port}" return name - def refresh_keys(self, kid=None): + def refresh_keys(self, kid=None, send_update=True): if not kid: self.default_kid = f"{uuid.uuid4()}" kid_ = kid or self.default_kid - (self.key_priv_pem, self.key_pub_pem), self.cert_pem = self._generate_cert() - if self.server: + (self.key_priv_pem, self.key_pub_pem), self.cert_pem = ( + self._generate_auth_data() + ) + if self.server and send_update: self.server.set_jwks(self.create_jwks(kid_)) - def _create_jwks(self, kid, test_invalid_is_key=False): - der_b64 = base64.b64encode( - infra.crypto.cert_pem_to_der(self.cert_pem) - if not test_invalid_is_key - else infra.crypto.pub_key_pem_to_der(self.key_pub_pem) - ).decode("ascii") + def _create_jwks_with_cert(self, kid): + der_b64 = base64.b64encode(infra.crypto.cert_pem_to_der(self.cert_pem)).decode( + "ascii" + ) return {"kty": "RSA", "kid": kid, "x5c": [der_b64], "issuer": self.name[::]} - def create_jwks(self, kid=None, test_invalid_is_key=False): + def _create_jwks_with_raw_key(self, kid): + pubkey = self.public_key + if self._alg == JwtAlg.RS256: + n = to_b64(pubkey.public_numbers().n) + e = to_b64(pubkey.public_numbers().e) + return {"kty": "RSA", "kid": kid, "n": n, "e": e, "issuer": self.name[::]} + elif self._alg == JwtAlg.ES256: + x = to_b64(pubkey.public_numbers().x) + y = to_b64(pubkey.public_numbers().y) + return { + "kty": "EC", + "kid": kid, + "x": x, + "y": y, + "crv": "P-256", + "issuer": self.name, + } + else: + raise ValueError(f"Unsupported algorithm: {self._alg}") + + def _create_jwks(self, kid): + if self._auth_type == JwtAuthType.KEY: + return self._create_jwks_with_raw_key(kid) + elif self._auth_type == JwtAuthType.CERT: + return self._create_jwks_with_cert(kid) + else: + raise ValueError(f"Unsupported auth type: {self._auth_type}") + + def create_jwks(self, kid=None): kid_ = kid or self.default_kid - return {"keys": [self._create_jwks(kid_, test_invalid_is_key)]} + return {"keys": [self._create_jwks(kid_)]} def create_jwks_for_kids(self, kids): jwks = {} @@ -217,7 +300,8 @@ def issue_jwt(self, kid=None, claims=None): claims["exp"] = now + 3600 if "iss" not in claims: claims["iss"] = self.name - return infra.crypto.create_jwt(claims, self.key_priv_pem, kid_) + + return infra.crypto.create_jwt(claims, self.key_priv_pem, kid_, self._alg.value) def wait_for_refresh(self, network, args, kid=None): timeout = self.refresh_interval * 3 @@ -237,10 +321,16 @@ def wait_for_refresh(self, network, args, kid=None): LOG.warning(body) keys = body["keys"] if kid_ in keys: - stored_cert = keys[kid_][0]["certificate"] - if self.cert_pem == stored_cert: - flush_info(logs) - return + if "publicKey" in keys[kid_][0]: + stored_key = keys[kid_][0]["publicKey"] + if self.key_pub_pem == stored_key: + flush_info(logs) + return + else: + stored_cert = keys[kid_][0]["certificate"] + if self.cert_pem == stored_cert: + flush_info(logs) + return time.sleep(0.1) else: with primary.client( diff --git a/tests/js-custom-authorization/custom_authorization.py b/tests/js-custom-authorization/custom_authorization.py index 2dbd949e4ec0..ea4d03924637 100644 --- a/tests/js-custom-authorization/custom_authorization.py +++ b/tests/js-custom-authorization/custom_authorization.py @@ -13,7 +13,7 @@ import base64 import json import time -import infra.jwt_issuer +from infra.jwt_issuer import JwtAlg, JwtAuthType, JwtIssuer, make_bearer_header import datetime import re import uuid @@ -111,7 +111,7 @@ def try_auth(primary, issuer, kid, iss, tid): LOG.info(f"Creating JWT with kid={kid} iss={iss} tenant={tid}") return c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(kid, claims={"iss": iss, "tid": tid}) ), ) @@ -344,7 +344,7 @@ def create_keypair(local_id, valid_from, validity_days): def test_jwt_auth(network, args): primary, _ = network.find_nodes() - issuer = infra.jwt_issuer.JwtIssuer("https://example.issuer") + issuer = JwtIssuer("https://example.issuer") jwt_kid = "my_key_id" @@ -354,26 +354,26 @@ def test_jwt_auth(network, args): LOG.info("Calling jwt endpoint after storing keys") with primary.client("user0") as c: - r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header("garbage")) + r = c.get("/app/jwt", headers=make_bearer_header("garbage")) assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code assert "Malformed JWT" in parse_error_message(r), r jwt_mismatching_key_priv_pem, _ = infra.crypto.generate_rsa_keypair(2048) jwt = infra.crypto.create_jwt({}, jwt_mismatching_key_priv_pem, jwt_kid) - r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header(jwt)) + r = c.get("/app/jwt", headers=make_bearer_header(jwt)) assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code assert "JWT payload is missing required field" in parse_error_message(r), r r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header(issuer.issue_jwt(jwt_kid)), + headers=make_bearer_header(issuer.issue_jwt(jwt_kid)), ) assert r.status_code == HTTPStatus.OK, r.status_code LOG.info("Calling JWT with too-late nbf") r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(jwt_kid, claims={"nbf": time.time() + 60}) ), ) @@ -383,7 +383,7 @@ def test_jwt_auth(network, args): LOG.info("Calling JWT with too-early exp") r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(jwt_kid, claims={"exp": time.time() - 60}) ), ) @@ -394,6 +394,37 @@ def test_jwt_auth(network, args): return network +@reqs.description("JWT authentication as by OpenID spec with raw public key") +def test_jwt_auth_raw_key(network, args): + primary, _ = network.find_nodes() + + for alg in [JwtAlg.RS256, JwtAlg.ES256]: + issuer = JwtIssuer("noautorefresh://issuer", alg=alg, auth_type=JwtAuthType.KEY) + jwt_kid = "my_key_id" + issuer.register(network, kid=jwt_kid) + + LOG.info("Calling jwt endpoint after storing keys") + with primary.client("user0") as c: + token = issuer.issue_jwt(jwt_kid) + r = c.get( + "/app/jwt", + headers=make_bearer_header(token), + ) + assert r.status_code == HTTPStatus.OK, r.status_code + + # Change client's key only, new token shouldn't pass validation. + issuer.refresh_keys(kid=jwt_kid, send_update=False) + token = issuer.issue_jwt(jwt_kid) + r = c.get( + "/app/jwt", + headers=make_bearer_header(token), + ) + assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code + + network.consortium.remove_jwt_issuer(primary, issuer.name) + return network + + @reqs.description("JWT authentication as by MSFT Entra (single tenant)") def test_jwt_auth_msft_single_tenant(network, args): """For a specific tenant, only tokens with this issuer+tenant can auth.""" @@ -405,7 +436,7 @@ def test_jwt_auth_msft_single_tenant(network, args): "https://login.microsoftonline.com/9188050d-6c67-4c5b-b112-36a304b66da/v2.0" ) - issuer = infra.jwt_issuer.JwtIssuer(name="https://login.microsoftonline.com") + issuer = JwtIssuer(name="https://login.microsoftonline.com") jwt_kid = "my_key_id" set_issuer_with_a_key(primary, network, issuer, jwt_kid, ISSUER_TENANT) @@ -443,7 +474,7 @@ def test_jwt_auth_msft_multitenancy(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name="https://login.microsoftonline.com") + issuer = JwtIssuer(name="https://login.microsoftonline.com") jwt_kid_1 = "my_key_id_1" jwt_kid_2 = "my_key_id_2" @@ -520,8 +551,8 @@ def test_jwt_auth_msft_same_kids_different_issuers(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name=ISSUER_TENANT) - another = infra.jwt_issuer.JwtIssuer(name=ISSUER_ANOTHER) + issuer = JwtIssuer(name=ISSUER_TENANT) + another = JwtIssuer(name=ISSUER_ANOTHER) # Immitate same key sharing another.cert_pem, another.key_priv_pem = issuer.cert_pem, issuer.key_priv_pem @@ -582,7 +613,7 @@ def test_jwt_auth_msft_same_kids_overwrite_constraint(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name=ISSUER_TENANT) + issuer = JwtIssuer(name=ISSUER_TENANT) jwt_kid = "my_key_id" set_issuer_with_a_key(primary, network, issuer, jwt_kid, COMMNON_ISSUER) @@ -708,6 +739,7 @@ def run_authn(args): network.start_and_open(args) network = test_cert_auth(network, args) network = test_jwt_auth(network, args) + network = test_jwt_auth_raw_key(network, args) network = test_jwt_auth_msft_single_tenant(network, args) network = test_jwt_auth_msft_multitenancy(network, args) network = test_jwt_auth_msft_same_kids_different_issuers(network, args) diff --git a/tests/jwt_test.py b/tests/jwt_test.py index ef0e861fd3f6..bca66e1e9b96 100644 --- a/tests/jwt_test.py +++ b/tests/jwt_test.py @@ -4,6 +4,7 @@ import tempfile import json import time +import base64 import infra.network import infra.path import infra.proc @@ -12,33 +13,16 @@ import infra.e2e_args import infra.proposal import suite.test_requirements as reqs -import infra.jwt_issuer +from infra.jwt_issuer import get_jwt_issuers, get_jwt_keys from infra.runner import ConcurrentRunner import ca_certs import ccf.ledger from ccf.tx_id import TxID import infra.clients -import http from loguru import logger as LOG -def get_jwt_issuers(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["issuers"] - - -def get_jwt_keys(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["keys"] - - def set_issuer_with_keys(network, primary, issuer, kids): with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as metadata_fp: json.dump({"issuer": issuer.name}, metadata_fp) @@ -213,7 +197,8 @@ def test_jwt_endpoint(network, args): assert kid in service_keys, service_keys assert service_keys[kid][0]["issuer"] == issuer.name assert service_keys[kid][0]["constraint"] == issuer.name - assert service_keys[kid][0]["certificate"] == issuer.cert_pem + assert service_keys[kid][0]["publicKey"] == issuer.key_pub_pem + assert "certificate" not in service_keys[kid][0] @reqs.description("JWT without key policy") @@ -246,7 +231,12 @@ def test_jwt_without_key_policy(network, args): LOG.info("Try to add a public key instead of a certificate") with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as jwks_fp: - json.dump(issuer.create_jwks(kid, test_invalid_is_key=True), jwks_fp) + jwks = issuer.create_jwks(kid) + der_b64 = base64.b64encode( + infra.crypto.pub_key_pem_to_der(issuer.key_pub_pem) + ).decode("ascii") + jwks["keys"][0]["x5c"] = [der_b64] + json.dump(jwks, jwks_fp) jwks_fp.flush() try: network.consortium.set_jwt_public_signing_keys( @@ -266,9 +256,9 @@ def test_jwt_without_key_policy(network, args): ) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" LOG.info("Remove JWT issuer") network.consortium.remove_jwt_issuer(primary, issuer.name) @@ -285,9 +275,9 @@ def test_jwt_without_key_policy(network, args): network.consortium.set_jwt_issuer(primary, metadata_fp.name) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" return network @@ -320,18 +310,18 @@ def make_attested_cert(network, args): return pem -def check_kv_jwt_key_matches(args, network, kid, cert_pem): +def check_kv_jwt_key_matches(args, network, kid, key_pem): primary, _ = network.find_nodes() latest_jwt_signing_keys = get_jwt_keys(args, primary) - if cert_pem is None: + if key_pem is None: assert kid not in latest_jwt_signing_keys else: # Necessary to get an AssertionError if the key is not found yet, # when used from with_timeout() assert kid in latest_jwt_signing_keys - stored_cert = latest_jwt_signing_keys[kid][0]["certificate"] - assert stored_cert == cert_pem, "input cert is not equal to stored cert" + stored_key = latest_jwt_signing_keys[kid][0]["publicKey"] + assert stored_key == key_pem, "input cert is not equal to stored cert" def check_kv_jwt_keys_not_empty(args, network, issuer): @@ -405,7 +395,9 @@ def test_jwt_key_auto_refresh(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -438,7 +430,7 @@ def check_has_failures(): with_timeout( lambda: check_kv_jwt_key_matches(args, network, kid, None), timeout=5 ) - check_kv_jwt_key_matches(args, network, kid2, issuer.cert_pem) + check_kv_jwt_key_matches(args, network, kid2, issuer.key_pub_pem) return network @@ -482,7 +474,9 @@ def test_jwt_key_auto_refresh_entries(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -512,8 +506,10 @@ def test_jwt_key_auto_refresh_entries(network, args): for tx in chunk: txid = TxID(tx.gcm_header.view, tx.gcm_header.seqno) tables = tx.get_public_domain().get_tables() - if "public:ccf.gov.jwt.public_signing_keys_metadata" in tables: - pub_keys = tables["public:ccf.gov.jwt.public_signing_keys_metadata"] + if "public:ccf.gov.jwt.public_signing_keys_metadata_v2" in tables: + pub_keys = tables[ + "public:ccf.gov.jwt.public_signing_keys_metadata_v2" + ] if kid.encode() in pub_keys: if last_key_refresh is None: LOG.info(f"Refresh found for kid: {kid} at {txid}") @@ -567,7 +563,7 @@ def test_jwt_key_initial_refresh(network, args): # Auto-refresh interval has been set to a large value so that it doesn't happen within the timeout. # This is testing the one-off refresh after adding a new issuer. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches(args, network, kid, issuer.key_pub_pem), timeout=5, )