Skip to content

Commit

Permalink
Upgrade zeroize to 1.5 (#286)
Browse files Browse the repository at this point in the history
* Upgrade zeroize to 1.5

* Update the MSRV to 1.51

* Fix clippy warnings

Co-authored-by: Valentin Tolmer <[email protected]>
  • Loading branch information
nitnelave and nitnelave authored Nov 25, 2022
1 parent 1012439 commit 902a605
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 111 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- u32_backend
toolchain:
- nightly
- 1.41.0
- 1.51.0
name: test
steps:
- name: Checkout sources
Expand Down Expand Up @@ -94,7 +94,7 @@ jobs:
matrix:
toolchain:
- nightly
- 1.41.0
- 1.51.0
name: test simple_login command-line example
steps:
- name: install expect
Expand All @@ -118,7 +118,7 @@ jobs:
matrix:
toolchain:
- nightly
- 1.41.0
- 1.51.0
name: test digital_locker command-line example
steps:
- name: install expect
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ rand = "0.8"
serde = { version = "1", features = ["derive"], optional = true }
subtle = { version = "2.3.0", default-features = false }
thiserror = "1.0.22"
zeroize = { version = "~1.1", features = ["zeroize_derive"] }
zeroize = { version = "~1.5", features = ["zeroize_derive"] }

[dev-dependencies]
anyhow = "1.0.35"
Expand Down
18 changes: 9 additions & 9 deletions src/key_exchange/tripledh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ impl<D: Hash, G: Group> KeyExchange<D, G> for TripleDH {

let mut transcript_hasher = D::new()
.chain(STR_RFC)
.chain(&serialize(&context, 2))
.chain(&id_u)
.chain(serialize(&context, 2))
.chain(id_u)
.chain(&serialized_credential_request[..])
.chain(&id_s)
.chain(id_s)
.chain(&l2_bytes[..])
.chain(&server_nonce[..])
.chain(&server_e_kp.public().to_arr());
.chain(server_e_kp.public().to_arr());

let (session_key, km2, km3) = derive_3dh_keys::<D, G>(
TripleDHComponents {
Expand Down Expand Up @@ -141,12 +141,12 @@ impl<D: Hash, G: Group> KeyExchange<D, G> for TripleDH {
) -> Result<(Vec<u8>, Self::KE3Message), ProtocolError> {
let mut transcript_hasher = D::new()
.chain(STR_RFC)
.chain(&serialize(&context, 2))
.chain(&id_u)
.chain(&serialized_credential_request)
.chain(&id_s)
.chain(serialize(&context, 2))
.chain(id_u)
.chain(serialized_credential_request)
.chain(id_s)
.chain(&l2_component[..])
.chain(&ke2_message.to_bytes_without_info_or_mac());
.chain(ke2_message.to_bytes_without_info_or_mac());

let (session_key, km2, km3) = derive_3dh_keys::<D, G>(
TripleDHComponents {
Expand Down
6 changes: 3 additions & 3 deletions src/oprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fn finalize_after_unblind<G: GroupWithMapToCurve, H: Hash>(
let finalize_dst = [STR_VOPRF_FINALIZE, &G::get_context_string(MODE_BASE)].concat();
let hash_input = [
serialize(input, 2),
serialize(&unblinded_element.to_arr().to_vec(), 2),
serialize(&unblinded_element.to_arr(), 2),
serialize(&finalize_dst, 2),
]
.concat();
Expand Down Expand Up @@ -130,7 +130,7 @@ mod tests {
RistrettoPoint::from_scalar_slice(GenericArray::from_slice(&oprf_key[..])).unwrap();
let res = point * scalar;

finalize_after_unblind::<RistrettoPoint, sha2::Sha512>(&input, res)
finalize_after_unblind::<RistrettoPoint, sha2::Sha512>(input, res)
}

#[test]
Expand All @@ -145,7 +145,7 @@ mod tests {
let oprf_key = RistrettoPoint::from_scalar_slice(&oprf_key_bytes)?;
let beta = evaluate::<RistrettoPoint>(alpha, &oprf_key);
let res = finalize::<RistrettoPoint, sha2::Sha512>(&token.data, &token.blind, beta);
let res2 = prf(&input[..], &oprf_key.as_bytes());
let res2 = prf(&input[..], oprf_key.as_bytes());
assert_eq!(res, res2);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::errors::PakeError;
// Corresponds to the I2OSP() function from RFC8017
pub(crate) fn i2osp(input: usize, length: usize) -> Vec<u8> {
if length <= std::mem::size_of::<usize>() {
return (&input.to_be_bytes()[std::mem::size_of::<usize>() - length..]).to_vec();
return input.to_be_bytes()[std::mem::size_of::<usize>() - length..].to_vec();
}

let mut output = vec![0u8; length];
Expand Down
76 changes: 40 additions & 36 deletions src/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ fn registration_request_roundtrip() {
let identity = RistrettoPoint::identity();
let identity_bytes = identity.to_arr().to_vec();

assert!(
match RegistrationRequest::<Default>::deserialize(identity_bytes.as_slice()) {
Err(ProtocolError::VerificationError(PakeError::IdentityGroupElementError)) => true,
_ => false,
}
);
assert!(matches!(
RegistrationRequest::<Default>::deserialize(identity_bytes.as_slice()),
Err(ProtocolError::VerificationError(
PakeError::IdentityGroupElementError
))
));
}

#[test]
Expand All @@ -122,7 +122,7 @@ fn registration_response_roundtrip() {

let mut input = Vec::new();
input.extend_from_slice(beta_bytes.as_slice());
input.extend_from_slice(&pubkey_bytes.as_slice());
input.extend_from_slice(pubkey_bytes.as_slice());

let r2 = RegistrationResponse::<Default>::deserialize(input.as_slice()).unwrap();
let r2_bytes = r2.serialize();
Expand All @@ -132,12 +132,14 @@ fn registration_response_roundtrip() {
let identity = RistrettoPoint::identity();
let identity_bytes = identity.to_arr().to_vec();

assert!(match RegistrationResponse::<Default>::deserialize(
&[identity_bytes, pubkey_bytes.to_vec()].concat()
) {
Err(ProtocolError::VerificationError(PakeError::IdentityGroupElementError)) => true,
_ => false,
});
assert!(matches!(
RegistrationResponse::<Default>::deserialize(
&[identity_bytes, pubkey_bytes.to_vec()].concat()
),
Err(ProtocolError::VerificationError(
PakeError::IdentityGroupElementError
))
));
}

#[test]
Expand Down Expand Up @@ -179,7 +181,7 @@ fn credential_request_roundtrip() {
let mut client_nonce = vec![0u8; NonceLen::to_usize()];
rng.fill_bytes(&mut client_nonce);

let ke1m: Vec<u8> = [&client_nonce[..], &client_e_kp.public()].concat();
let ke1m: Vec<u8> = [&client_nonce[..], client_e_kp.public()].concat();

let mut input = Vec::new();
input.extend_from_slice(&alpha_bytes);
Expand All @@ -193,12 +195,12 @@ fn credential_request_roundtrip() {
let identity = RistrettoPoint::identity();
let identity_bytes = identity.to_arr().to_vec();

assert!(match CredentialRequest::<Default>::deserialize(
&[identity_bytes, ke1m.to_vec()].concat()
) {
Err(ProtocolError::VerificationError(PakeError::IdentityGroupElementError)) => true,
_ => false,
});
assert!(matches!(
CredentialRequest::<Default>::deserialize(&[identity_bytes, ke1m.to_vec()].concat()),
Err(ProtocolError::VerificationError(
PakeError::IdentityGroupElementError
))
));
}

#[test]
Expand All @@ -221,7 +223,7 @@ fn credential_response_roundtrip() {
let mut server_nonce = vec![0u8; NonceLen::to_usize()];
rng.fill_bytes(&mut server_nonce);

let ke2m: Vec<u8> = [&server_nonce[..], &server_e_kp.public(), &mac[..]].concat();
let ke2m: Vec<u8> = [&server_nonce[..], server_e_kp.public(), &mac[..]].concat();

let mut input = Vec::new();
input.extend_from_slice(pt_bytes.as_slice());
Expand All @@ -237,18 +239,20 @@ fn credential_response_roundtrip() {
let identity = RistrettoPoint::identity();
let identity_bytes = identity.to_arr().to_vec();

assert!(match CredentialResponse::<Default>::deserialize(
&[
identity_bytes,
masking_nonce.to_vec(),
masked_response,
ke2m.to_vec()
]
.concat()
) {
Err(ProtocolError::VerificationError(PakeError::IdentityGroupElementError)) => true,
_ => false,
});
assert!(matches!(
CredentialResponse::<Default>::deserialize(
&[
identity_bytes,
masking_nonce.to_vec(),
masked_response,
ke2m.to_vec()
]
.concat()
),
Err(ProtocolError::VerificationError(
PakeError::IdentityGroupElementError
))
));
}

#[test]
Expand Down Expand Up @@ -298,7 +302,7 @@ fn ke1_message_roundtrip() {
let mut client_nonce = vec![0u8; NonceLen::to_usize()];
rng.fill_bytes(&mut client_nonce);

let ke1m: Vec<u8> = [&client_nonce[..], &client_e_kp.public()].concat();
let ke1m: Vec<u8> = [&client_nonce[..], client_e_kp.public()].concat();
let reg = <TripleDH as KeyExchange<sha2::Sha512, RistrettoPoint>>::KE1Message::from_bytes::<
Default,
>(&ke1m[..])
Expand All @@ -317,7 +321,7 @@ fn ke2_message_roundtrip() {
let mut server_nonce = vec![0u8; NonceLen::to_usize()];
rng.fill_bytes(&mut server_nonce);

let ke2m: Vec<u8> = [&server_nonce[..], &server_e_kp.public(), &mac[..]].concat();
let ke2m: Vec<u8> = [&server_nonce[..], server_e_kp.public(), &mac[..]].concat();

let reg = <TripleDH as KeyExchange<sha2::Sha512, RistrettoPoint>>::KE2Message::from_bytes::<
Default,
Expand Down Expand Up @@ -347,7 +351,7 @@ proptest! {

#[test]
fn test_i2osp_os2ip(ref bytes in vec(prop::num::u8::ANY, 0..std::mem::size_of::<usize>())) {
assert_eq!(&i2osp(os2ip(&bytes)?, bytes.len()), bytes);
assert_eq!(&i2osp(os2ip(bytes)?, bytes.len()), bytes);
}

#[test]
Expand Down
82 changes: 41 additions & 41 deletions src/tests/full_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,45 +111,43 @@ static TEST_VECTOR: &str = r#"
"#;

fn decode(values: &Value, key: &str) -> Option<Vec<u8>> {
values[key]
.as_str()
.and_then(|s| hex::decode(&s.to_string()).ok())
values[key].as_str().and_then(|s| hex::decode(s).ok())
}

fn populate_test_vectors(values: &Value) -> TestVectorParameters {
TestVectorParameters {
client_s_pk: decode(&values, "client_s_pk").unwrap(),
client_s_sk: decode(&values, "client_s_sk").unwrap(),
client_e_pk: decode(&values, "client_e_pk").unwrap(),
client_e_sk: decode(&values, "client_e_sk").unwrap(),
server_s_pk: decode(&values, "server_s_pk").unwrap(),
server_s_sk: decode(&values, "server_s_sk").unwrap(),
server_e_pk: decode(&values, "server_e_pk").unwrap(),
server_e_sk: decode(&values, "server_e_sk").unwrap(),
fake_sk: decode(&values, "fake_sk").unwrap(),
credential_identifier: decode(&values, "credential_identifier").unwrap(),
id_u: decode(&values, "id_u").unwrap(),
id_s: decode(&values, "id_s").unwrap(),
password: decode(&values, "password").unwrap(),
blinding_factor: decode(&values, "blinding_factor").unwrap(),
oprf_seed: decode(&values, "oprf_seed").unwrap(),
masking_nonce: decode(&values, "masking_nonce").unwrap(),
envelope_nonce: decode(&values, "envelope_nonce").unwrap(),
client_nonce: decode(&values, "client_nonce").unwrap(),
server_nonce: decode(&values, "server_nonce").unwrap(),
context: decode(&values, "context").unwrap(),
registration_request: decode(&values, "registration_request").unwrap(),
registration_response: decode(&values, "registration_response").unwrap(),
registration_upload: decode(&values, "registration_upload").unwrap(),
credential_request: decode(&values, "credential_request").unwrap(),
credential_response: decode(&values, "credential_response").unwrap(),
credential_finalization: decode(&values, "credential_finalization").unwrap(),
client_registration_state: decode(&values, "client_registration_state").unwrap(),
client_login_state: decode(&values, "client_login_state").unwrap(),
server_login_state: decode(&values, "server_login_state").unwrap(),
password_file: decode(&values, "password_file").unwrap(),
export_key: decode(&values, "export_key").unwrap(),
session_key: decode(&values, "session_key").unwrap(),
client_s_pk: decode(values, "client_s_pk").unwrap(),
client_s_sk: decode(values, "client_s_sk").unwrap(),
client_e_pk: decode(values, "client_e_pk").unwrap(),
client_e_sk: decode(values, "client_e_sk").unwrap(),
server_s_pk: decode(values, "server_s_pk").unwrap(),
server_s_sk: decode(values, "server_s_sk").unwrap(),
server_e_pk: decode(values, "server_e_pk").unwrap(),
server_e_sk: decode(values, "server_e_sk").unwrap(),
fake_sk: decode(values, "fake_sk").unwrap(),
credential_identifier: decode(values, "credential_identifier").unwrap(),
id_u: decode(values, "id_u").unwrap(),
id_s: decode(values, "id_s").unwrap(),
password: decode(values, "password").unwrap(),
blinding_factor: decode(values, "blinding_factor").unwrap(),
oprf_seed: decode(values, "oprf_seed").unwrap(),
masking_nonce: decode(values, "masking_nonce").unwrap(),
envelope_nonce: decode(values, "envelope_nonce").unwrap(),
client_nonce: decode(values, "client_nonce").unwrap(),
server_nonce: decode(values, "server_nonce").unwrap(),
context: decode(values, "context").unwrap(),
registration_request: decode(values, "registration_request").unwrap(),
registration_response: decode(values, "registration_response").unwrap(),
registration_upload: decode(values, "registration_upload").unwrap(),
credential_request: decode(values, "credential_request").unwrap(),
credential_response: decode(values, "credential_response").unwrap(),
credential_finalization: decode(values, "credential_finalization").unwrap(),
client_registration_state: decode(values, "client_registration_state").unwrap(),
client_login_state: decode(values, "client_login_state").unwrap(),
server_login_state: decode(values, "server_login_state").unwrap(),
password_file: decode(values, "password_file").unwrap(),
export_key: decode(values, "export_key").unwrap(),
session_key: decode(values, "session_key").unwrap(),
}
}

Expand Down Expand Up @@ -550,7 +548,7 @@ fn test_registration_upload() -> Result<(), ProtocolError> {
);
assert_eq!(
hex::encode(parameters.export_key),
hex::encode(result.export_key.to_vec())
hex::encode(result.export_key)
);

Ok(())
Expand Down Expand Up @@ -665,7 +663,7 @@ fn test_credential_finalization() -> Result<(), ProtocolError> {

assert_eq!(
hex::encode(&parameters.server_s_pk),
hex::encode(&client_login_finish_result.server_s_pk.to_arr().to_vec())
hex::encode(client_login_finish_result.server_s_pk.to_arr())
);
assert_eq!(
hex::encode(&parameters.session_key),
Expand Down Expand Up @@ -758,10 +756,12 @@ fn test_complete_flow(
hex::encode(client_login_finish_result.export_key)
);
} else {
assert!(match client_login_result {
Err(ProtocolError::VerificationError(PakeError::InvalidLoginError)) => true,
_ => false,
});
assert!(matches!(
client_login_result,
Err(ProtocolError::VerificationError(
PakeError::InvalidLoginError
))
));
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/tests/mock_rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl RngCore for CycleRng {
#[inline]
fn fill_bytes(&mut self, dest: &mut [u8]) {
let len = min(self.v.len(), dest.len());
(&mut dest[..len]).copy_from_slice(&self.v[..len]);
dest[..len].copy_from_slice(&self.v[..len]);
rotate_left(&mut self.v, len);
}

Expand Down
Loading

0 comments on commit 902a605

Please sign in to comment.