Skip to content

Commit

Permalink
Add Test checking GroupContextExtensionProposal Validation: Commit co…
Browse files Browse the repository at this point in the history
…ntains up to one GCE Proposal (openmls#1590)

Co-authored-by: Jan Winkelmann (keks) <[email protected]>
  • Loading branch information
keks and keks authored Jun 20, 2024
1 parent 70d5767 commit 2f835f7
Show file tree
Hide file tree
Showing 16 changed files with 830 additions and 40 deletions.
2 changes: 2 additions & 0 deletions openmls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ openmls_memory_storage = { path = "../memory_storage", features = [
], optional = true }
openmls_test = { path = "../openmls_test", optional = true }
openmls_libcrux_crypto = { path = "../libcrux_crypto", optional = true }
once_cell = { version = "1.19.0", optional = true }

[features]
default = ["backtrace"]
Expand All @@ -49,6 +50,7 @@ test-utils = [
"dep:openmls_basic_credential",
"dep:openmls_memory_storage",
"dep:openmls_test",
"dep:once_cell",
]
libcrux-provider = [
"dep:openmls_libcrux_crypto",
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/framing/mls_auth_content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub(crate) struct FramedContentAuthData {
}

impl FramedContentAuthData {
pub(super) fn deserialize<R: Read>(
pub(crate) fn deserialize<R: Read>(
bytes: &mut R,
content_type: ContentType,
) -> Result<Self, tls_codec::Error> {
Expand Down
7 changes: 6 additions & 1 deletion openmls/src/group/core_group/new_from_welcome.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ pub(in crate::group) fn build_staged_welcome<Provider: OpenMlsProvider>(
log_crypto!(trace, " Got: {:x?}", confirmation_tag);
log_crypto!(trace, " Expected: {:x?}", public_group.confirmation_tag());
debug_assert!(false, "Confirmation tag mismatch");
return Err(WelcomeError::ConfirmationTagMismatch);

// in some tests we need to be able to proceed despite the tag being wrong,
// e.g. to test whether a later validation check is performed correctly.
if !crate::skip_validation::is_disabled::confirmation_tag() {
return Err(WelcomeError::ConfirmationTagMismatch);
}
}

let message_secrets_store = MessageSecretsStore::new_with_secret(0, message_secrets);
Expand Down
7 changes: 6 additions & 1 deletion openmls/src/group/core_group/staged_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,12 @@ impl CoreGroup {
// TODO: We have tests expecting this error.
// They need to be rewritten.
// debug_assert!(false, "Confirmation tag mismatch");
return Err(StageCommitError::ConfirmationTagMismatch);

// in some tests we need to be able to proceed despite the tag being wrong,
// e.g. to test whether a later validation check is performed correctly.
if !crate::skip_validation::is_disabled::confirmation_tag() {
return Err(StageCommitError::ConfirmationTagMismatch);
}
}

diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?;
Expand Down
161 changes: 143 additions & 18 deletions openmls/src/group/mls_group/test_mls_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ use crate::{
key_packages::*,
messages::proposals::*,
prelude::Capabilities,
test_utils::test_framework::{
errors::ClientError, noop_authentication_service, ActionType::Commit, CodecUse,
MlsGroupTestSetup,
test_utils::{
frankenstein::{self, FrankenMlsMessage},
test_framework::{
errors::ClientError, noop_authentication_service, ActionType::Commit, CodecUse,
MlsGroupTestSetup,
},
},
tree::sender_ratchet::SenderRatchetConfiguration,
};
Expand Down Expand Up @@ -1143,15 +1146,57 @@ fn remove_prosposal_by_ref(
// Test that the builder pattern accurately configures the new group.
#[openmls_test]
fn group_context_extensions_proposal() {
let alice_provider = &mut Provider::default();
let bob_provider = &mut Provider::default();
let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) =
setup_client("Alice", ciphersuite, provider);
setup_client("Alice", ciphersuite, alice_provider);
let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) =
setup_client("bob", ciphersuite, bob_provider);

// === Alice creates a group ===
let mut alice_group = MlsGroup::builder()
.ciphersuite(ciphersuite)
.build(provider, &alice_signer, alice_credential_with_key)
.with_wire_format_policy(WireFormatPolicy::new(
OutgoingWireFormatPolicy::AlwaysPlaintext,
IncomingWireFormatPolicy::Mixed,
))
.build(alice_provider, &alice_signer, alice_credential_with_key)
.expect("error creating group using builder");

// === Alice adds Bob ===
let bob_key_package = KeyPackage::builder()
.build(
ciphersuite,
bob_provider,
&bob_signer,
bob_credential_with_key,
)
.expect("error building key package");

let (_, welcome, _) = alice_group
.add_members(
alice_provider,
&alice_signer,
&[bob_key_package.key_package().clone()],
)
.unwrap();
alice_group.merge_pending_commit(alice_provider).unwrap();

let welcome: MlsMessageIn = welcome.into();
let welcome = welcome
.into_welcome()
.expect("expected message to be a welcome");

let mut bob_group = StagedWelcome::new_from_welcome(
bob_provider,
alice_group.configuration(),
welcome,
Some(alice_group.export_ratchet_tree().into()),
)
.expect("Error creating staged join from Welcome")
.into_group(bob_provider)
.expect("Error creating group from staged join");

// No required capabilities, so no specifically required extensions.
assert!(alice_group
.group()
Expand All @@ -1168,20 +1213,104 @@ fn group_context_extensions_proposal() {
RequiredCapabilitiesExtension::new(&[ExtensionType::RatchetTree], &[], &[]),
));

alice_group
.propose_group_context_extensions(provider, new_extensions.clone(), &alice_signer)
let (proposal, _) = alice_group
.propose_group_context_extensions(alice_provider, new_extensions.clone(), &alice_signer)
.expect("failed to build group context extensions proposal");

let proc_msg = bob_group
.process_message(bob_provider, proposal.into_protocol_message().unwrap())
.unwrap();
match proc_msg.into_content() {
ProcessedMessageContent::ProposalMessage(proposal) => bob_group
.store_pending_proposal(bob_provider.storage(), *proposal)
.unwrap(),
_ => unreachable!(),
};

assert_eq!(alice_group.pending_proposals().count(), 1);

alice_group
.commit_to_pending_proposals(provider, &alice_signer)
let (commit, _, _) = alice_group
.commit_to_pending_proposals(alice_provider, &alice_signer)
.expect("failed to commit to pending proposals");

// we'll change the commit we feed to bob to include two GCE proposals
let mut franken_commit = FrankenMlsMessage::tls_deserialize(
&mut commit.tls_serialize_detached().unwrap().as_slice(),
)
.unwrap();

// Craft a commit that has two GroupContextExtension proposals. This is forbidden by the RFC.
// Change the commit before alice commits, so alice's state is still in the old epoch and we can
// use her state to forge the macs and signatures
match &mut franken_commit.body {
frankenstein::FrankenMlsMessageBody::PublicMessage(msg) => {
match &mut msg.content.body {
frankenstein::FrankenFramedContentBody::Commit(commit) => {
let second_gces = frankenstein::FrankenProposalOrRef::Proposal(
frankenstein::FrankenProposal::GroupContextExtensions(vec![
frankenstein::FrankenExtension::LastResort,
]),
);

commit.proposals.push(second_gces);
}
_ => unreachable!(),
}

let group_context = alice_group.export_group_context().clone();

let bob_group_context = bob_group.export_group_context();
assert_eq!(
bob_group_context.confirmed_transcript_hash(),
group_context.confirmed_transcript_hash()
);

let secrets = alice_group.group.message_secrets();
let membership_key = secrets.membership_key().as_slice();

*msg = frankenstein::FrankenPublicMessage::auth(
alice_provider,
group_context.ciphersuite(),
&alice_signer,
msg.content.clone(),
Some(&group_context.into()),
Some(membership_key),
// this is a dummy confirmation_tag:
Some(vec![0u8; 32].into()),
);
}
_ => unreachable!(),
}

// alice merges the unmodified commit
alice_group
.merge_pending_commit(provider)
.merge_pending_commit(alice_provider)
.expect("error merging pending commit");

let fake_commit = MlsMessageIn::tls_deserialize(
&mut franken_commit.tls_serialize_detached().unwrap().as_slice(),
)
.unwrap();

let fake_commit_protocol_msg = fake_commit.into_protocol_message().unwrap();

let err = {
let validation_skip_handle = crate::skip_validation::checks::confirmation_tag::handle();
validation_skip_handle.with_disabled(|| {
bob_group.process_message(bob_provider, fake_commit_protocol_msg.clone())
})
}
.expect_err("expected an error");

assert!(matches!(
err,
ProcessMessageError::InvalidCommit(
StageCommitError::GroupContextExtensionsProposalValidationError(
GroupContextExtensionsProposalValidationError::TooManyGCEProposals
)
)
));

let required_capabilities = alice_group
.group()
.context()
Expand All @@ -1195,18 +1324,18 @@ fn group_context_extensions_proposal() {
// === committing to two group context extensions should fail

alice_group
.propose_group_context_extensions(provider, new_extensions, &alice_signer)
.propose_group_context_extensions(alice_provider, new_extensions, &alice_signer)
.expect("failed to build group context extensions proposal");

// the proposals need to be different or they will be deduplicated
alice_group
.propose_group_context_extensions(provider, new_extensions_2, &alice_signer)
.propose_group_context_extensions(alice_provider, new_extensions_2, &alice_signer)
.expect("failed to build group context extensions proposal");

assert_eq!(alice_group.pending_proposals().count(), 2);

alice_group
.commit_to_pending_proposals(provider, &alice_signer)
.commit_to_pending_proposals(alice_provider, &alice_signer)
.expect_err(
"expected error when committing to multiple group context extensions proposals",
);
Expand All @@ -1220,12 +1349,8 @@ fn group_context_extensions_proposal() {
));

alice_group
.propose_group_context_extensions(provider, new_extensions, &alice_signer)
.propose_group_context_extensions(alice_provider, new_extensions, &alice_signer)
.expect_err("expected an error building GCE proposal with bad required_capabilities");

// TODO: we need to test that processing a commit with multiple group context extensions
// proposal also fails. however, we can't generate this commit, because our functions for
// constructing commits does not permit it. See #1476
}

// Test that the builder pattern accurately configures the new group.
Expand Down
1 change: 1 addition & 0 deletions openmls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ pub mod storage;

// Private
mod binary_tree;
mod skip_validation;
mod tree;

/// Single place, re-exporting the most used public functions.
Expand Down
100 changes: 100 additions & 0 deletions openmls/src/skip_validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! This module contains helpers for skipping validation. It is built such that setting the flag to
//! disable validation can only by set when the "test-utils" feature is enabled.
//! This module is used in two places, and they use different parts of it.
//! Code that performs validation and wants to check whether a check is disabled only uses the
//! [`is_disabled`] submodule. It contains getter functions that read the current state of the
//! flag.
//! Test code that disables checks uses the code in the [`checks`] submodule. It contains a module
//! for each check that can be disabled, and a getter for a handle, protected by a [`Mutex`]. This
//! is done because the flag state is shared between tests, and tests that set and unset the same
//! checks are not safe to run concurrently.
//! For example, a test could cann [`checks::confirmation_tag::handle`] to get a handle to disable
//! and re-enable the validation of confirmation tags.
pub(crate) mod is_disabled {
use super::checks::*;

pub(crate) fn confirmation_tag() -> bool {
confirmation_tag::FLAG.load(core::sync::atomic::Ordering::Relaxed)
}
}

#[cfg(test)]
use std::sync::atomic::AtomicBool;

/// Contains a reference to a flag. Provides convenience functions to set and clear the flag.
#[cfg(test)]
#[derive(Clone, Copy, Debug)]
pub struct SkipValidationHandle {
// we keep this field so we can see which handle this is when printing it. we don't need it otherwise
#[allow(dead_code)]
name: &'static str,
flag: &'static AtomicBool,
}

/// Contains the flags and functions that return handles to control them.
pub(crate) mod checks {
/// Disables validation of the confirmation_tag.
pub(crate) mod confirmation_tag {
use std::sync::atomic::AtomicBool;

/// A way of disabling verification and validation of confirmation tags.
pub(in crate::skip_validation) static FLAG: AtomicBool = AtomicBool::new(false);

#[cfg(test)]
pub(crate) use lock::handle;

#[cfg(test)]
mod lock {
use super::FLAG;
use crate::skip_validation::SkipValidationHandle;
use once_cell::sync::Lazy;
use std::sync::{Mutex, MutexGuard};

/// The name of the check that can be skipped here
const NAME: &str = "confirmation_tag";

/// A mutex needed to run tests that use this flag sequentially
static MUTEX: Lazy<Mutex<SkipValidationHandle>> =
Lazy::new(|| Mutex::new(SkipValidationHandle::new_confirmation_tag_handle()));

/// Takes the mutex and returns the control handle to the validation skipper
pub(crate) fn handle() -> MutexGuard<'static, SkipValidationHandle> {
MUTEX.lock().unwrap_or_else(|e| {
panic!("error taking skip-validation mutex for '{NAME}': {e}")
})
}

impl SkipValidationHandle {
pub fn new_confirmation_tag_handle() -> Self {
Self {
name: NAME,
flag: &FLAG,
}
}
}
}
}
}

#[cfg(test)]
impl SkipValidationHandle {
/// Disables validation for the check controlled by this handle
pub fn disable_validation(self) {
self.flag.store(true, core::sync::atomic::Ordering::Relaxed);
}

/// Enables validation for the check controlled by this handle
pub fn enable_validation(self) {
self.flag
.store(false, core::sync::atomic::Ordering::Relaxed);
}

/// Runs function `f` with validation disabled
pub fn with_disabled<R, F: FnMut() -> R>(self, mut f: F) -> R {
self.disable_validation();
let r = f();
self.enable_validation();
r
}
}
4 changes: 3 additions & 1 deletion openmls/src/test_utils/frankenstein/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ impl Size for FrankenExtension {
impl Serialize for FrankenExtension {
fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
let written = self.extension_type().tls_serialize(writer)?;
let extension_data_len = self.tls_serialized_len();

// subtract the two bytes for the type header
let extension_data_len = self.tls_serialized_len() - 2;
let mut extension_data = Vec::with_capacity(extension_data_len);

let _ = match self {
Expand Down
Loading

0 comments on commit 2f835f7

Please sign in to comment.