From e33e311f8f5ccba0e87aa5bafd5ca087404cead4 Mon Sep 17 00:00:00 2001 From: Sree Revoori Date: Wed, 21 Feb 2024 18:01:07 +0000 Subject: [PATCH] Refactor CertValidity to hold ArrayVec instead of &str This will allow platforms which use non-const cert-validities to return the data correctly, since we cannot copy out a &str from a function because it lives on the stack. --- dpe/src/commands/certify_key.rs | 2 +- dpe/src/x509.rs | 46 ++++++++++++++++++++++----------- platform/src/default.rs | 14 +++++++--- platform/src/lib.rs | 15 ++++++----- 4 files changed, 52 insertions(+), 25 deletions(-) diff --git a/dpe/src/commands/certify_key.rs b/dpe/src/commands/certify_key.rs index bbb35fe6..ba5b4180 100644 --- a/dpe/src/commands/certify_key.rs +++ b/dpe/src/commands/certify_key.rs @@ -150,7 +150,7 @@ impl CommandExecution for CertifyKeyCmd { &subject_name, &pub_key, &measurements, - cert_validity, + &cert_validity, )?; if bytes_written > MAX_CERT_SIZE { return Err(DpeErrorCode::InternalError); diff --git a/dpe/src/x509.rs b/dpe/src/x509.rs index 7d4b0c1c..02e02344 100644 --- a/dpe/src/x509.rs +++ b/dpe/src/x509.rs @@ -258,9 +258,9 @@ impl CertWriter<'_> { } /// If `tagged`, include the tag and size fields - fn get_validity_size(validity: CertValidity<'_>, tagged: bool) -> Result { - let len = Self::get_bytes_size(validity.not_before.as_bytes(), true)? - + Self::get_bytes_size(validity.not_after.as_bytes(), true)?; + fn get_validity_size(validity: &CertValidity, tagged: bool) -> Result { + let len = Self::get_bytes_size(validity.not_before.as_slice(), true)? + + Self::get_bytes_size(validity.not_after.as_slice(), true)?; Self::get_structure_size(len, tagged) } @@ -470,7 +470,7 @@ impl CertWriter<'_> { subject_name: &Name, pubkey: &EcdsaPub, measurements: &MeasurementData, - validity: CertValidity<'_>, + validity: &CertValidity, tagged: bool, ) -> Result { let tbs_size = Self::get_version_size(/*tagged=*/ true)? @@ -886,7 +886,7 @@ impl CertWriter<'_> { } // Encode ASN.1 Validity according to Platform - fn encode_validity(&mut self, validity: CertValidity<'_>) -> Result { + fn encode_validity(&mut self, validity: &CertValidity) -> Result { let seq_size = Self::get_validity_size(validity, /*tagged=*/ false)?; let mut bytes_written = self.encode_tag_field(Self::SEQUENCE_TAG)?; @@ -894,11 +894,11 @@ impl CertWriter<'_> { bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?; bytes_written += self.encode_size_field(validity.not_before.len())?; - bytes_written += self.encode_bytes(validity.not_before.as_bytes())?; + bytes_written += self.encode_bytes(validity.not_before.as_slice())?; bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?; bytes_written += self.encode_size_field(validity.not_after.len())?; - bytes_written += self.encode_bytes(validity.not_after.as_bytes())?; + bytes_written += self.encode_bytes(validity.not_after.as_slice())?; Ok(bytes_written) } @@ -1647,7 +1647,7 @@ impl CertWriter<'_> { subject_name: &Name, pubkey: &EcdsaPub, measurements: &MeasurementData, - validity: CertValidity<'_>, + validity: &CertValidity, ) -> Result { let tbs_size = Self::get_tbs_size( serial_number, @@ -1840,7 +1840,7 @@ mod tests { use crate::x509::{CertWriter, DirectoryString, MeasurementData, Name}; use crate::DPE_PROFILE; use crypto::{CryptoBuf, EcdsaPub, EcdsaSig}; - use platform::CertValidity; + use platform::{ArrayVec, CertValidity}; use std::str; use x509_parser::certificate::X509CertificateParser; use x509_parser::nom::Parser; @@ -2086,9 +2086,17 @@ mod tests { supports_recursive: true, }; + let mut not_before = ArrayVec::new(); + not_before + .try_extend_from_slice("20230227000000Z".as_bytes()) + .unwrap(); + let mut not_after = ArrayVec::new(); + not_after + .try_extend_from_slice("99991231235959Z".as_bytes()) + .unwrap(); let validity = CertValidity { - not_before: "20230227000000Z", - not_after: "99991231235959Z", + not_before, + not_after, }; let bytes_written = w @@ -2098,7 +2106,7 @@ mod tests { &test_subject_name, &test_pub, &measurements, - validity, + &validity, ) .unwrap(); @@ -2152,9 +2160,17 @@ mod tests { supports_recursive: true, }; + let mut not_before = ArrayVec::new(); + not_before + .try_extend_from_slice("20230227000000Z".as_bytes()) + .unwrap(); + let mut not_after = ArrayVec::new(); + not_after + .try_extend_from_slice("99991231235959Z".as_bytes()) + .unwrap(); let validity = CertValidity { - not_before: "20230227000000Z", - not_after: "99991231235959Z", + not_before, + not_after, }; let mut tbs_writer = CertWriter::new(cert_buf, true); @@ -2165,7 +2181,7 @@ mod tests { &TEST_SUBJECT_NAME, &test_pub, &measurements, - validity, + &validity, ) .unwrap(); diff --git a/platform/src/default.rs b/platform/src/default.rs index a92bbcbd..b28818f6 100644 --- a/platform/src/default.rs +++ b/platform/src/default.rs @@ -152,10 +152,18 @@ impl Platform for DefaultPlatform { Ok(()) } - fn get_cert_validity<'a>(&mut self) -> Result, PlatformError> { + fn get_cert_validity(&mut self) -> Result { + let mut not_before_vec = ArrayVec::new(); + not_before_vec + .try_extend_from_slice(NOT_BEFORE.as_bytes()) + .map_err(|_| PlatformError::CertValidityError(0))?; + let mut not_after_vec = ArrayVec::new(); + not_after_vec + .try_extend_from_slice(NOT_AFTER.as_bytes()) + .map_err(|_| PlatformError::CertValidityError(0))?; Ok(CertValidity { - not_before: NOT_BEFORE, - not_after: NOT_AFTER, + not_before: not_before_vec, + not_after: not_after_vec, }) } } diff --git a/platform/src/lib.rs b/platform/src/lib.rs index 0a1bcf58..922b56dd 100644 --- a/platform/src/lib.rs +++ b/platform/src/lib.rs @@ -19,8 +19,9 @@ pub const MAX_CHUNK_SIZE: usize = 2048; pub const MAX_ISSUER_NAME_SIZE: usize = 128; pub const MAX_SN_SIZE: usize = 20; pub const MAX_SKI_SIZE: usize = 20; +pub const MAX_VALIDITY_SIZE: usize = 24; -#[allow(variant_size_differences)] +#[derive(Debug, PartialEq, Eq)] pub enum SignerIdentifier { IssuerAndSerialNumber { issuer_name: ArrayVec, @@ -29,10 +30,10 @@ pub enum SignerIdentifier { SubjectKeyIdentifier(ArrayVec), } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct CertValidity<'a> { - pub not_before: &'a str, - pub not_after: &'a str, +#[derive(Debug, PartialEq, Eq)] +pub struct CertValidity { + pub not_before: ArrayVec, + pub not_after: ArrayVec, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -44,6 +45,7 @@ pub enum PlatformError { PrintError(u32) = 0x4, SerialNumberError(u32) = 0x5, SubjectKeyIdentifierError(u32) = 0x6, + CertValidityError(u32) = 0x7, } impl PlatformError { @@ -62,6 +64,7 @@ impl PlatformError { PlatformError::PrintError(code) => Some(*code), PlatformError::SerialNumberError(code) => Some(*code), PlatformError::SubjectKeyIdentifierError(code) => Some(*code), + PlatformError::CertValidityError(code) => Some(*code), } } } @@ -111,5 +114,5 @@ pub trait Platform { /// in the yyyyMMddHHmmss format followed by a timezone. /// /// Example: 99991231235959Z is December 31st, 9999 23:59:59 UTC - fn get_cert_validity<'a>(&mut self) -> Result, PlatformError>; + fn get_cert_validity(&mut self) -> Result; }