Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove unchecked unwrap. #101

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions spdmlib/src/common/key_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::handshake_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str1,
)?
};
Expand Down Expand Up @@ -162,10 +167,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::handshake_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str2,
)?
};
Expand Down Expand Up @@ -295,10 +305,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str3,
)?
};
Expand Down Expand Up @@ -337,10 +352,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str4,
)?
};
Expand Down Expand Up @@ -378,10 +398,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str8,
)?
};
Expand Down
27 changes: 19 additions & 8 deletions spdmlib/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ impl SpdmContext {
crypto::cert_operation::get_cert_from_cert_chain(
&cert_chain.data[..(cert_chain.data_size as usize)],
0,
)
.unwrap();
)?;
let root_cert = &cert_chain.data[root_cert_begin..root_cert_end];
if let Some(root_hash) =
crypto::hash::hash_all(self.negotiate_info.base_hash_sel, root_cert)
Expand Down Expand Up @@ -509,7 +508,9 @@ impl SpdmContext {
}

pub fn append_message_k(&mut self, session_id: u32, new_message: &[u8]) -> SpdmResult {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

#[cfg(not(feature = "hashed-transcript-data"))]
{
Expand Down Expand Up @@ -574,7 +575,9 @@ impl SpdmContext {
session_id: u32,
new_message: &[u8],
) -> SpdmResult {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
let _ = session
.runtime_info
.message_f
Expand All @@ -590,7 +593,9 @@ impl SpdmContext {
session_id: u32,
new_message: &[u8],
) -> SpdmResult {
let session = self.get_immutable_session_via_id(session_id).unwrap();
let session = self
.get_immutable_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
if session.runtime_info.digest_context_th.is_none() {
return Err(SPDM_STATUS_INVALID_STATE_LOCAL);
}
Expand Down Expand Up @@ -631,18 +636,24 @@ impl SpdmContext {
};

if let Some(mut_cert_digest) = mut_cert_digest {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

crypto::hash::hash_ctx_update(
session.runtime_info.digest_context_th.as_mut().unwrap(),
&mut_cert_digest.data[..mut_cert_digest.data_size as usize],
)?;
}
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.runtime_info.message_f_initialized = true;
}

let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
crypto::hash::hash_ctx_update(
session.runtime_info.digest_context_th.as_mut().unwrap(),
new_message,
Expand Down
2 changes: 1 addition & 1 deletion spdmlib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ impl Codec for SpdmStatus {
let mut sc = 0u32;
sc += (((self.severity as u8) & 0x0F) as u32) << 28;
sc += <StatusCode as TryInto<u24>>::try_into(self.status_code)
.unwrap() //due to the design of encode, panic is allowed
.map_err(|_| codec::EncodeErr)?
.get();
sc.encode(bytes)?;
Ok(4)
Expand Down
25 changes: 15 additions & 10 deletions spdmlib/src/requester/finish_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return Err(res.err().unwrap());
}
Expand All @@ -92,7 +92,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return res;
}
Expand All @@ -107,7 +107,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return Err(res.err().unwrap());
}
Expand Down Expand Up @@ -186,13 +186,16 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let transcript_hash =
self.common
.calc_req_transcript_hash(false, req_slot_id, is_mut_auth, session)?;

let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let hmac = session.generate_hmac_with_request_finished_key(transcript_hash.as_ref())?;

Expand Down Expand Up @@ -253,7 +256,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let transcript_hash = self.common.calc_req_transcript_hash(
false,
Expand Down Expand Up @@ -291,7 +294,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// generate the data secret
let th2 = self.common.calc_req_transcript_hash(
Expand All @@ -303,7 +306,10 @@ impl RequesterContext {

debug!("!!! th2 : {:02x?}\n", th2.as_ref());
let spdm_version_sel = self.common.negotiate_info.spdm_version_sel;
let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
match session.generate_data_secret(spdm_version_sel, &th2) {
Ok(_) => {}
Err(e) => {
Expand Down Expand Up @@ -424,8 +430,7 @@ impl RequesterContext {
peer_cert,
transcript_sign.as_ref(),
&signature,
)
.unwrap();
)?;

Ok(signature)
}
Expand Down
31 changes: 20 additions & 11 deletions spdmlib/src/requester/key_exchange_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use crate::error::SPDM_STATUS_CRYPTO_ERROR;
use crate::error::SPDM_STATUS_ERROR_PEER;
use crate::error::SPDM_STATUS_INVALID_MSG_FIELD;
use crate::error::SPDM_STATUS_INVALID_PARAMETER;
#[cfg(feature = "hashed-transcript-data")]
use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL;
use crate::error::SPDM_STATUS_SESSION_NUMBER_EXCEED;
use crate::error::SPDM_STATUS_VERIF_FAIL;
Expand Down Expand Up @@ -296,7 +295,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// verify signature
if self
Expand All @@ -321,22 +320,25 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// generate the handshake secret (including finished_key) before verify HMAC
let th1 = self
.common
.calc_req_transcript_hash(false, slot_id, false, session)?;
debug!("!!! th1 : {:02x?}\n", th1.as_ref());

let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.generate_handshake_secret(spdm_version_sel, &th1)?;

if !in_clear_text {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// verify HMAC with finished_key
let transcript_hash = self
Expand All @@ -346,7 +348,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

if session
.verify_hmac_with_response_finished_key(
Expand All @@ -356,8 +358,10 @@ impl RequesterContext {
.is_err()
{
error!("verify_hmac_with_response_finished_key fail");
let session =
self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.teardown();
return Err(SPDM_STATUS_VERIF_FAIL);
} else {
Expand All @@ -373,15 +377,20 @@ impl RequesterContext {
)
.is_err()
{
let session =
self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.teardown();
return Err(SPDM_STATUS_BUFFER_FULL);
}
}

// append verify_data after TH1
let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

session.secure_spdm_version_sel = secure_spdm_version_sel;
session.heartbeat_period = key_exchange_rsp.heartbeat_period;
Expand Down
Loading
Loading