Skip to content

Commit

Permalink
remove unchecked unwrap.
Browse files Browse the repository at this point in the history
fix #99

Signed-off-by: Yang, Longlong <[email protected]>
  • Loading branch information
longlongyang committed Jun 12, 2024
1 parent 9916d65 commit ea50be1
Show file tree
Hide file tree
Showing 16 changed files with 480 additions and 153 deletions.
35 changes: 30 additions & 5 deletions spdmlib/src/common/key_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::handshake_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str1,
)?
};
Expand Down Expand Up @@ -162,10 +167,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::handshake_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str2,
)?
};
Expand Down Expand Up @@ -295,10 +305,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str3,
)?
};
Expand Down Expand Up @@ -337,10 +352,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str4,
)?
};
Expand Down Expand Up @@ -378,10 +398,15 @@ impl SpdmKeySchedule {
return None;
}
} else {
let empty_pskhint = SpdmPskHintStruct::default();
secret::psk::master_secret_hkdf_expand(
spdm_version,
hash_algo,
psk_hint.unwrap(),
if let Some(hint) = psk_hint {
hint
} else {
&empty_pskhint
},
bin_str8,
)?
};
Expand Down
27 changes: 19 additions & 8 deletions spdmlib/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ impl SpdmContext {
crypto::cert_operation::get_cert_from_cert_chain(
&cert_chain.data[..(cert_chain.data_size as usize)],
0,
)
.unwrap();
)?;
let root_cert = &cert_chain.data[root_cert_begin..root_cert_end];
if let Some(root_hash) =
crypto::hash::hash_all(self.negotiate_info.base_hash_sel, root_cert)
Expand Down Expand Up @@ -509,7 +508,9 @@ impl SpdmContext {
}

pub fn append_message_k(&mut self, session_id: u32, new_message: &[u8]) -> SpdmResult {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

#[cfg(not(feature = "hashed-transcript-data"))]
{
Expand Down Expand Up @@ -574,7 +575,9 @@ impl SpdmContext {
session_id: u32,
new_message: &[u8],
) -> SpdmResult {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
let _ = session
.runtime_info
.message_f
Expand All @@ -590,7 +593,9 @@ impl SpdmContext {
session_id: u32,
new_message: &[u8],
) -> SpdmResult {
let session = self.get_immutable_session_via_id(session_id).unwrap();
let session = self
.get_immutable_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
if session.runtime_info.digest_context_th.is_none() {
return Err(SPDM_STATUS_INVALID_STATE_LOCAL);
}
Expand Down Expand Up @@ -631,18 +636,24 @@ impl SpdmContext {
};

if let Some(mut_cert_digest) = mut_cert_digest {
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

crypto::hash::hash_ctx_update(
session.runtime_info.digest_context_th.as_mut().unwrap(),
&mut_cert_digest.data[..mut_cert_digest.data_size as usize],
)?;
}
let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.runtime_info.message_f_initialized = true;
}

let session = self.get_session_via_id(session_id).unwrap();
let session = self
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
crypto::hash::hash_ctx_update(
session.runtime_info.digest_context_th.as_mut().unwrap(),
new_message,
Expand Down
2 changes: 1 addition & 1 deletion spdmlib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ impl Codec for SpdmStatus {
let mut sc = 0u32;
sc += (((self.severity as u8) & 0x0F) as u32) << 28;
sc += <StatusCode as TryInto<u24>>::try_into(self.status_code)
.unwrap() //due to the design of encode, panic is allowed
.map_err(|_| codec::EncodeErr)?
.get();
sc.encode(bytes)?;
Ok(4)
Expand Down
25 changes: 15 additions & 10 deletions spdmlib/src/requester/finish_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return Err(res.err().unwrap());
}
Expand All @@ -92,7 +92,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return res;
}
Expand All @@ -107,7 +107,7 @@ impl RequesterContext {
if res.is_err() {
self.common
.get_session_via_id(session_id)
.unwrap()
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?
.teardown();
return Err(res.err().unwrap());
}
Expand Down Expand Up @@ -186,13 +186,16 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let transcript_hash =
self.common
.calc_req_transcript_hash(false, req_slot_id, is_mut_auth, session)?;

let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let hmac = session.generate_hmac_with_request_finished_key(transcript_hash.as_ref())?;

Expand Down Expand Up @@ -253,7 +256,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

let transcript_hash = self.common.calc_req_transcript_hash(
false,
Expand Down Expand Up @@ -291,7 +294,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// generate the data secret
let th2 = self.common.calc_req_transcript_hash(
Expand All @@ -303,7 +306,10 @@ impl RequesterContext {

debug!("!!! th2 : {:02x?}\n", th2.as_ref());
let spdm_version_sel = self.common.negotiate_info.spdm_version_sel;
let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
match session.generate_data_secret(spdm_version_sel, &th2) {
Ok(_) => {}
Err(e) => {
Expand Down Expand Up @@ -424,8 +430,7 @@ impl RequesterContext {
peer_cert,
transcript_sign.as_ref(),
&signature,
)
.unwrap();
)?;

Ok(signature)
}
Expand Down
30 changes: 20 additions & 10 deletions spdmlib/src/requester/key_exchange_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// verify signature
if self
Expand All @@ -321,22 +321,25 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// generate the handshake secret (including finished_key) before verify HMAC
let th1 = self
.common
.calc_req_transcript_hash(false, slot_id, false, session)?;
debug!("!!! th1 : {:02x?}\n", th1.as_ref());

let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.generate_handshake_secret(spdm_version_sel, &th1)?;

if !in_clear_text {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

// verify HMAC with finished_key
let transcript_hash = self
Expand All @@ -346,7 +349,7 @@ impl RequesterContext {
let session = self
.common
.get_immutable_session_via_id(session_id)
.unwrap();
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

if session
.verify_hmac_with_response_finished_key(
Expand All @@ -356,8 +359,10 @@ impl RequesterContext {
.is_err()
{
error!("verify_hmac_with_response_finished_key fail");
let session =
self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.teardown();
return Err(SPDM_STATUS_VERIF_FAIL);
} else {
Expand All @@ -373,15 +378,20 @@ impl RequesterContext {
)
.is_err()
{
let session =
self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
session.teardown();
return Err(SPDM_STATUS_BUFFER_FULL);
}
}

// append verify_data after TH1
let session = self.common.get_session_via_id(session_id).unwrap();
let session = self
.common
.get_session_via_id(session_id)
.ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;

session.secure_spdm_version_sel = secure_spdm_version_sel;
session.heartbeat_period = key_exchange_rsp.heartbeat_period;
Expand Down
Loading

0 comments on commit ea50be1

Please sign in to comment.