diff --git a/spdmlib/src/crypto/x509v3.rs b/spdmlib/src/crypto/x509v3.rs index 0f0d8de..c8ccace 100644 --- a/spdmlib/src/crypto/x509v3.rs +++ b/spdmlib/src/crypto/x509v3.rs @@ -26,6 +26,8 @@ const ASN1_TAG_SEQUENCE: u8 = const ASN1_TAG_EXPLICIT_EXTENSION: u8 = 0xA3; const ASN1_TAG_EXTN_VALUE: u8 = 0x04; const ASN1_LENGTH_MULTI_OCTET_MASK: u8 = 0x80; +const ASN1_LENGTH_ONE_OCTET: u8 = 0x81; +const ASN1_LENGTH_TWO_OCTET: u8 = 0x82; const X509V3_VERSION: u8 = 2; const OID_RSA_SHA256RSA: &[u8] = &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0bu8]; @@ -263,19 +265,43 @@ fn check_length(data: &[u8]) -> SpdmResult<(usize, usize)> { let len = data.len(); if len < 1 { Err(SPDM_STATUS_VERIF_FAIL) - } else if data[0] & ASN1_LENGTH_MULTI_OCTET_MASK == 0 { - Ok((data[0] as usize, 1)) } else { - let length_count = data[0] - ASN1_LENGTH_MULTI_OCTET_MASK; - if len < (length_count as usize + 1) || length_count == 0 || length_count > 8 { - Err(SPDM_STATUS_VERIF_FAIL) - } else { - let mut length = [0u8; 8]; - for (i, b) in data[1..length_count as usize + 1].iter().rev().enumerate() { - length[i] = *b; + let length_byte0 = data[0]; + let (length, byte_comsumed) = match length_byte0 { + n if (n & ASN1_LENGTH_MULTI_OCTET_MASK) == 0 => (n as usize, 1), + ASN1_LENGTH_ONE_OCTET => { + if len < 2 { + return Err(SPDM_STATUS_VERIF_FAIL); + } else { + let second_byte = data[1] as usize; + if second_byte < 0x80 { + return Err(SPDM_STATUS_VERIF_FAIL); // Not the canonical encoding. + } else { + (second_byte, 2) + } + } } - Ok((usize::from_le_bytes(length), length_count as usize + 1)) - } + ASN1_LENGTH_TWO_OCTET => { + if len < 3 { + return Err(SPDM_STATUS_VERIF_FAIL); + } else { + let second_byte = data[1] as usize; + let third_byte = data[2] as usize; + + let combined = (second_byte << 8) | third_byte; + if combined < 0x100 { + return Err(SPDM_STATUS_VERIF_FAIL); // Not the canonical encoding. + } else { + (combined, 3) + } + } + } + _ => { + return Err(SPDM_STATUS_VERIF_FAIL); // spdm-rs don't support longer lengths. + } + }; + + Ok((length, byte_comsumed)) } } @@ -998,13 +1024,13 @@ mod tests { #[test] fn test_case0_check_length() { let l1 = [0x03]; - let l2 = [0x81, 0x12]; + let l2 = [0x81, 0x82]; let l3 = [0x82, 0x01, 0xD7]; let l1_wrong = [0x80]; let l2_wrong = [0x81]; let l3_wrong = [0x82, 0x01]; assert_eq!(check_length(&l1), Ok((3, 1))); - assert_eq!(check_length(&l2), Ok((0x12, 2))); + assert_eq!(check_length(&l2), Ok((0x82, 2))); assert_eq!(check_length(&l3), Ok((0x1D7, 3))); assert_eq!(check_length(&l1_wrong), Err(SPDM_STATUS_VERIF_FAIL)); assert_eq!(check_length(&l2_wrong), Err(SPDM_STATUS_VERIF_FAIL));