Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor CertValidity to hold ArrayVec instead of &str #318

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dpe/src/commands/certify_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl CommandExecution for CertifyKeyCmd {
&subject_name,
&pub_key,
&measurements,
cert_validity,
&cert_validity,
)?;
if bytes_written > MAX_CERT_SIZE {
return Err(DpeErrorCode::InternalError);
Expand Down
46 changes: 31 additions & 15 deletions dpe/src/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ impl CertWriter<'_> {
}

/// If `tagged`, include the tag and size fields
fn get_validity_size(validity: CertValidity<'_>, tagged: bool) -> Result<usize, DpeErrorCode> {
let len = Self::get_bytes_size(validity.not_before.as_bytes(), true)?
+ Self::get_bytes_size(validity.not_after.as_bytes(), true)?;
fn get_validity_size(validity: &CertValidity, tagged: bool) -> Result<usize, DpeErrorCode> {
let len = Self::get_bytes_size(validity.not_before.as_slice(), true)?
+ Self::get_bytes_size(validity.not_after.as_slice(), true)?;
Self::get_structure_size(len, tagged)
}

Expand Down Expand Up @@ -470,7 +470,7 @@ impl CertWriter<'_> {
subject_name: &Name,
pubkey: &EcdsaPub,
measurements: &MeasurementData,
validity: CertValidity<'_>,
validity: &CertValidity,
tagged: bool,
) -> Result<usize, DpeErrorCode> {
let tbs_size = Self::get_version_size(/*tagged=*/ true)?
Expand Down Expand Up @@ -886,19 +886,19 @@ impl CertWriter<'_> {
}

// Encode ASN.1 Validity according to Platform
fn encode_validity(&mut self, validity: CertValidity<'_>) -> Result<usize, DpeErrorCode> {
fn encode_validity(&mut self, validity: &CertValidity) -> Result<usize, DpeErrorCode> {
let seq_size = Self::get_validity_size(validity, /*tagged=*/ false)?;

let mut bytes_written = self.encode_tag_field(Self::SEQUENCE_TAG)?;
bytes_written += self.encode_size_field(seq_size)?;

bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?;
bytes_written += self.encode_size_field(validity.not_before.len())?;
bytes_written += self.encode_bytes(validity.not_before.as_bytes())?;
bytes_written += self.encode_bytes(validity.not_before.as_slice())?;

bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?;
bytes_written += self.encode_size_field(validity.not_after.len())?;
bytes_written += self.encode_bytes(validity.not_after.as_bytes())?;
bytes_written += self.encode_bytes(validity.not_after.as_slice())?;

Ok(bytes_written)
}
Expand Down Expand Up @@ -1647,7 +1647,7 @@ impl CertWriter<'_> {
subject_name: &Name,
pubkey: &EcdsaPub,
measurements: &MeasurementData,
validity: CertValidity<'_>,
validity: &CertValidity,
) -> Result<usize, DpeErrorCode> {
let tbs_size = Self::get_tbs_size(
serial_number,
Expand Down Expand Up @@ -1840,7 +1840,7 @@ mod tests {
use crate::x509::{CertWriter, DirectoryString, MeasurementData, Name};
use crate::DPE_PROFILE;
use crypto::{CryptoBuf, EcdsaPub, EcdsaSig};
use platform::CertValidity;
use platform::{ArrayVec, CertValidity};
use std::str;
use x509_parser::certificate::X509CertificateParser;
use x509_parser::nom::Parser;
Expand Down Expand Up @@ -2086,9 +2086,17 @@ mod tests {
supports_recursive: true,
};

let mut not_before = ArrayVec::new();
not_before
.try_extend_from_slice("20230227000000Z".as_bytes())
.unwrap();
let mut not_after = ArrayVec::new();
not_after
.try_extend_from_slice("99991231235959Z".as_bytes())
.unwrap();
let validity = CertValidity {
not_before: "20230227000000Z",
not_after: "99991231235959Z",
not_before,
not_after,
};

let bytes_written = w
Expand All @@ -2098,7 +2106,7 @@ mod tests {
&test_subject_name,
&test_pub,
&measurements,
validity,
&validity,
)
.unwrap();

Expand Down Expand Up @@ -2152,9 +2160,17 @@ mod tests {
supports_recursive: true,
};

let mut not_before = ArrayVec::new();
not_before
.try_extend_from_slice("20230227000000Z".as_bytes())
.unwrap();
let mut not_after = ArrayVec::new();
not_after
.try_extend_from_slice("99991231235959Z".as_bytes())
.unwrap();
let validity = CertValidity {
not_before: "20230227000000Z",
not_after: "99991231235959Z",
not_before,
not_after,
};

let mut tbs_writer = CertWriter::new(cert_buf, true);
Expand All @@ -2165,7 +2181,7 @@ mod tests {
&TEST_SUBJECT_NAME,
&test_pub,
&measurements,
validity,
&validity,
)
.unwrap();

Expand Down
14 changes: 11 additions & 3 deletions platform/src/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,18 @@ impl Platform for DefaultPlatform {
Ok(())
}

fn get_cert_validity<'a>(&mut self) -> Result<CertValidity<'a>, PlatformError> {
fn get_cert_validity(&mut self) -> Result<CertValidity, PlatformError> {
let mut not_before_vec = ArrayVec::new();
not_before_vec
.try_extend_from_slice(NOT_BEFORE.as_bytes())
.map_err(|_| PlatformError::CertValidityError(0))?;
let mut not_after_vec = ArrayVec::new();
not_after_vec
.try_extend_from_slice(NOT_AFTER.as_bytes())
.map_err(|_| PlatformError::CertValidityError(0))?;
Ok(CertValidity {
not_before: NOT_BEFORE,
not_after: NOT_AFTER,
not_before: not_before_vec,
not_after: not_after_vec,
})
}
}
15 changes: 9 additions & 6 deletions platform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ pub const MAX_CHUNK_SIZE: usize = 2048;
pub const MAX_ISSUER_NAME_SIZE: usize = 128;
pub const MAX_SN_SIZE: usize = 20;
pub const MAX_SKI_SIZE: usize = 20;
pub const MAX_VALIDITY_SIZE: usize = 24;

#[allow(variant_size_differences)]
#[derive(Debug, PartialEq, Eq)]
pub enum SignerIdentifier {
IssuerAndSerialNumber {
issuer_name: ArrayVec<u8, { MAX_ISSUER_NAME_SIZE }>,
Expand All @@ -29,10 +30,10 @@ pub enum SignerIdentifier {
SubjectKeyIdentifier(ArrayVec<u8, { MAX_SKI_SIZE }>),
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct CertValidity<'a> {
pub not_before: &'a str,
pub not_after: &'a str,
#[derive(Debug, PartialEq, Eq)]
pub struct CertValidity {
pub not_before: ArrayVec<u8, { MAX_VALIDITY_SIZE }>,
pub not_after: ArrayVec<u8, { MAX_VALIDITY_SIZE }>,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
Expand All @@ -44,6 +45,7 @@ pub enum PlatformError {
PrintError(u32) = 0x4,
SerialNumberError(u32) = 0x5,
SubjectKeyIdentifierError(u32) = 0x6,
CertValidityError(u32) = 0x7,
}

impl PlatformError {
Expand All @@ -62,6 +64,7 @@ impl PlatformError {
PlatformError::PrintError(code) => Some(*code),
PlatformError::SerialNumberError(code) => Some(*code),
PlatformError::SubjectKeyIdentifierError(code) => Some(*code),
PlatformError::CertValidityError(code) => Some(*code),
}
}
}
Expand Down Expand Up @@ -111,5 +114,5 @@ pub trait Platform {
/// in the yyyyMMddHHmmss format followed by a timezone.
///
/// Example: 99991231235959Z is December 31st, 9999 23:59:59 UTC
fn get_cert_validity<'a>(&mut self) -> Result<CertValidity<'a>, PlatformError>;
fn get_cert_validity(&mut self) -> Result<CertValidity, PlatformError>;
}
Loading