Skip to content

Commit

Permalink
Refactor certificate file handling for enhanced modularity and extens…
Browse files Browse the repository at this point in the history
…ibility.

- Introduced `CertFileFactory::createReader` to simplify certificate and key file reading.
- Replaced explicit PKCS#12 and PEM file handling functions with a flexible factory-based approach.
- Removed redundant PKCS#12 and PEM certification functions in favor of common interface.
- Migrated cross-cutting concerns such as error checking and private key handling within factory methods.
- Added required inclusions and build configurations for the new modular setup in Makefile and source files.
  • Loading branch information
george-mcintyre committed Dec 3, 2024
1 parent f527c07 commit 46351c4
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 199 deletions.
20 changes: 16 additions & 4 deletions certs/certfilefactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ class CertFileFactory {
const std::shared_ptr<KeyPair>& key_pair = nullptr, X509* cert_ptr = nullptr,
STACK_OF(X509) * certs_ptr = nullptr, const std::string& usage = "certificate",
const std::string& pem_string = "",
bool certs_only = false); // Move implementation to cpp file
bool certs_only = false);

static std::unique_ptr<CertFileFactory> createReader(const std::string& filename, const std::string& password="", const std::string& key_filename="", const std::string& key_password="") {
auto cert_file_factory = create(filename, password);
if ( !key_filename.empty() )
cert_file_factory->key_file_ = create(key_filename, key_password);

return cert_file_factory;
}

virtual ~CertFileFactory() = default;

Expand Down Expand Up @@ -113,16 +121,20 @@ class CertFileFactory {
const std::string& pem_string = "", bool certs_only = false)
: filename_(filename), cert_ptr_(cert_ptr), certs_ptr_(certs_ptr), usage_(usage), pem_string_(pem_string), certs_only_(certs_only) {}

const std::string filename_;
const std::string filename_{};
X509* cert_ptr_{nullptr};
STACK_OF(X509) * certs_ptr_ { nullptr };
const std::string usage_;
const std::string pem_string_;
const std::string usage_{};
const std::string pem_string_{};
const bool certs_only_{false};
std::unique_ptr<CertFileFactory> key_file_;

static void backupFileIfExists(const std::string& filename);
static void chainFromRootCertPtr(STACK_OF(X509) * &chain, X509* root_cert_ptr);
static std::string getExtension(const std::string& filename) { return filename.substr(filename.find_last_of(".") + 1); };

private:
std::string password_{};
};

} // namespace certs
Expand Down
29 changes: 18 additions & 11 deletions certs/p12filefactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,33 +79,40 @@ std::shared_ptr<KeyPair> P12FileFactory::getKeyFromFile() {
* The P12 file is parsed to extract the certificate and chain.
* If it contains a private key too then it is read and returned in the CertData object.
*
* @param filename the path to the P12 file
* @param password the optional password for the file. If blank then the password is not used.
* @return a CertData object
* @throw std::runtime_error if the file cannot be opened or parsed
*/
CertData P12FileFactory::getCertDataFromFile(std::string filename, std::string password) {
CertData P12FileFactory::getCertDataFromFile() {
ossl_ptr<X509> cert;
STACK_OF(X509) *chain_ptr = nullptr;
std::shared_ptr<KeyPair> key_pair;
ossl_ptr<EVP_PKEY> pkey;

auto file(fopen(filename.c_str(), "rb"));
// Get cert from configured file
auto file(fopen(filename_.c_str(), "rb"));
if (!file) {
throw std::runtime_error(SB() << "Error opening certificate file for reading binary contents: \"" << filename << "\"");
throw std::runtime_error(SB() << "Error opening certificate file for reading binary contents: \"" << filename_ << "\"");
}
file_ptr fp(file);

ossl_ptr<PKCS12> p12(d2i_PKCS12_fp(fp.get(), NULL));
if (!p12) {
throw std::runtime_error(SB() << "Error opening certificate file as a PKCS#12 object: " << filename);
throw std::runtime_error(SB() << "Error opening certificate file as a PKCS#12 object: " << filename_);
}

// Try to get private key and certificate but if we can't then try only certs
ossl_ptr<EVP_PKEY> pkey;
if (!PKCS12_parse(p12.get(), password.c_str(), pkey.acquire(), cert.acquire(), &chain_ptr)) {
if (!PKCS12_parse(p12.get(), password.c_str(), nullptr, cert.acquire(), &chain_ptr)) {
throw std::runtime_error(SB() << "Error parsing certificate file: " << filename);
if (!PKCS12_parse(p12.get(), password_.c_str(), pkey.acquire(), cert.acquire(), &chain_ptr)) {
if (!PKCS12_parse(p12.get(), password_.c_str(), nullptr, cert.acquire(), &chain_ptr)) {
throw std::runtime_error(SB() << "Error parsing certificate file: " << filename_);
}
}

// Try to get key from file if we didn't already get it and it is configured
if (!pkey && key_file_) {
key_pair = key_file_->getKeyFromFile();
pkey = std::move(key_pair->pkey);
}

ossl_shared_ptr<STACK_OF(X509)> chain;
if (chain_ptr)
chain = ossl_shared_ptr<STACK_OF(X509)>(chain_ptr);
Expand Down Expand Up @@ -139,7 +146,7 @@ ossl_ptr<PKCS12> P12FileFactory::pemStringToP12(std::string password, EVP_PKEY *
}

// Get first Cert as Certificate
ossl_ptr<X509> cert(PEM_read_bio_X509_AUX(bio.get(), NULL, NULL, (void *)password.c_str()));
ossl_ptr<X509> cert(PEM_read_bio_X509_AUX(bio.get(), NULL, NULL, (void *)password.c_str()), false);
if (!cert) {
throw std::runtime_error("Unable to read certificate");
}
Expand Down
3 changes: 1 addition & 2 deletions certs/p12filefactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@ class P12FileFactory : public CertFileFactory {
P12FileFactory(const std::string &filename, const std::string &password, const std::shared_ptr<KeyPair> &key_pair, PKCS12 *p12_ptr, bool certs_only = false)
: CertFileFactory(filename, nullptr, nullptr, "certificate", "", certs_only), password_(password), key_pair_(key_pair), p12_ptr_(p12_ptr) {}

static CertData getCertDataFromFile(std::string filename, std::string password);
void writePKCS12File();

void writeCertFile() override { writePKCS12File(); }

CertData getCertDataFromFile() override ;
std::shared_ptr<KeyPair> getKeyFromFile() override;
CertData getCertDataFromFile() override { return getCertDataFromFile(filename_, password_); }

private:
const std::string password_{};
Expand Down
117 changes: 62 additions & 55 deletions certs/pemfilefactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,6 @@ namespace certs {

DEFINE_LOGGER(pemcerts, "pvxs.certs.pem");

/**
* @brief Get the certificate data from a PEM file
*
* @param filename the path to the PEM file
* @return a CertData object containing the certificate and the chain
* @throw std::runtime_error if the file cannot be opened or read
*/
CertData PEMFileFactory::getCertDataFromFile(const std::string& filename) {
file_ptr fp(fopen(filename.c_str(), "r"));
if (!fp) {
throw std::runtime_error(SB() << "Error opening certificate file: " << filename);
}

// Read the first certificate (main cert)
ossl_ptr<X509> cert(PEM_read_X509(fp.get(), nullptr, nullptr, nullptr));
if (!cert) {
throw std::runtime_error(SB() << "Error reading certificate from file: " << filename);
}

// Read any additional certificates (chain)
ossl_shared_ptr<STACK_OF(X509)> chain(sk_X509_new_null());
if (!chain) {
throw std::runtime_error("Unable to allocate certificate chain");
}

ossl_ptr<X509> ca;
while (X509* ca_ptr = PEM_read_X509(fp.get(), nullptr, nullptr, nullptr)) {
ca = ossl_ptr<X509>(ca_ptr);
if (sk_X509_push(chain.get(), ca.get()) != 1) {
throw std::runtime_error("Failed to add certificate to chain");
}
ca.release();
}

// Clear any end-of-file errors
ERR_clear_error();

return CertData(cert, chain);
}

/**
* @brief Create a root PEM file from a PEM string
*
Expand Down Expand Up @@ -215,6 +175,62 @@ void PEMFileFactory::writePEMFile() {
log_info_printf(pemcerts, "Certificate file created: %s\n", filename_.c_str());
}

/**
* @brief Get the certificate data from a PEM file
*
* @param filename the path to the PEM file
* @return a CertData object containing the certificate and the chain
* @throw std::runtime_error if the file cannot be opened or read
*/
CertData PEMFileFactory::getCertDataFromFile() {
file_ptr fp(fopen(filename_.c_str(), "r"));
if (!fp) {
throw std::runtime_error(SB() << "Error opening certificate file: " << filename_);
}

// Read the first certificate (main cert)
ossl_ptr<X509> cert(PEM_read_X509(fp.get(), nullptr, nullptr, nullptr), false);
if (!cert) {
throw std::runtime_error(SB() << "Error reading certificate from file: " << filename_);
}

// Read any additional certificates (chain)
ossl_shared_ptr<STACK_OF(X509)> chain(sk_X509_new_null());
if (!chain) {
throw std::runtime_error("Unable to allocate certificate chain");
}

ossl_ptr<X509> ca;
while (X509* ca_ptr = PEM_read_X509(fp.get(), nullptr, nullptr, nullptr)) {
ca = ossl_ptr<X509>(ca_ptr);
if (sk_X509_push(chain.get(), ca.get()) != 1) {
throw std::runtime_error("Failed to add certificate to chain");
}
ca.release();
}

// Clear any end-of-file errors
ERR_clear_error();

// Read any private key
std::shared_ptr<KeyPair> key_pair;

// Try to read the private key
try {
ossl_ptr<EVP_PKEY> pkey(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr), false);

// Try to get key from file if it is configured
if (!pkey && key_file_) {
key_pair = key_file_->getKeyFromFile();
pkey = std::move(key_pair->pkey);
}
if (pkey)
return CertData(cert, chain, std::make_shared<KeyPair>(std::move(pkey)));
} catch (...) {}

return CertData(cert, chain);
}

/**
* @brief Get a key pair from a PEM file
*
Expand All @@ -227,23 +243,14 @@ std::shared_ptr<KeyPair> PEMFileFactory::getKeyFromFile() {
throw std::runtime_error(SB() << "Error opening private key file: \"" << filename_ << "\"");
}

// Read through the file looking for PEM objects
char line[256];
while (fgets(line, sizeof(line), fp.get())) {
if (strstr(line, "-----BEGIN PRIVATE KEY-----") || strstr(line, "-----BEGIN RSA PRIVATE KEY-----") || strstr(line, "-----BEGIN EC PRIVATE KEY-----")) {
// Found a private key header, rewind to start of this PEM block
fseek(fp.get(), -strlen(line), SEEK_CUR);

// Try to read the private key
ossl_ptr<EVP_PKEY> pkey(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr));
if (pkey) {
return std::make_shared<KeyPair>(std::move(pkey));
}
ERR_clear_error();
}
// Try to read the private key
ossl_ptr<EVP_PKEY> pkey(PEM_read_PrivateKey(fp.get(), nullptr, nullptr, nullptr));
if (!pkey) {
ERR_clear_error();
throw std::runtime_error(SB() << "No private key found in file: " << filename_);
}

throw std::runtime_error(SB() << "No private key found in file: " << filename_);
return std::make_shared<KeyPair>(std::move(pkey));
}

} // namespace certs
Expand Down
3 changes: 1 addition & 2 deletions certs/pemfilefactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ class PEMFileFactory : public CertFileFactory {

PEMFileFactory(const std::string& filename, const std::string& pem_string, bool certs_only = false) : CertFileFactory(filename, nullptr, nullptr, "certificate", pem_string, certs_only) {}

static CertData getCertDataFromFile(const std::string& filename);
static bool createRootPemFile(const std::string& pemString, bool overwrite = false);

std::shared_ptr<KeyPair> getKeyFromFile() override;
CertData getCertDataFromFile() override { return getCertDataFromFile(filename_); }
CertData getCertDataFromFile() override;

void writeCertFile() override { writePEMFile(); }
void writePEMFile();
Expand Down
1 change: 0 additions & 1 deletion certs/pvacms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
#include "evhelper.h"
#include "openssl.h"
#include "ownedptr.h"
#include "p12filefactory.h"
#include "sqlite3.h"
#include "utilpvt.h"

Expand Down
6 changes: 6 additions & 0 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ SHRLIB_VERSION = $(PVXS_MAJOR_VERSION).$(PVXS_MINOR_VERSION)

# Access to certs specific headers
USR_CPPFLAGS += -I$(TOP)/certs
SRC_DIRS += $(TOP)/certs

INC += pvxs/client.h
INC += pvxs/config.h
Expand Down Expand Up @@ -115,6 +116,11 @@ LIB_SRCS += udp_collector.cpp
LIB_SRCS += unittest.cpp
LIB_SRCS += util.cpp

LIB_SRCS += certfactory.cpp
LIB_SRCS += certfilefactory.cpp
LIB_SRCS += p12filefactory.cpp
LIB_SRCS += pemfilefactory.cpp

LIB_LIBS += Com

# special case matching configure/RULES_PVXS_MODULE
Expand Down
Loading

0 comments on commit 46351c4

Please sign in to comment.