From ee6fddc15568ac148b53dee2eca202162d25be11 Mon Sep 17 00:00:00 2001 From: Marta Mularczyk Date: Fri, 17 Jan 2025 09:42:13 +0100 Subject: [PATCH] Remove lifetime from ReceivedMessage --- mls-rs/examples/api_1x.rs | 5 +- mls-rs/examples/basic_server_usage.rs | 19 +++- mls-rs/examples/basic_usage.rs | 4 +- mls-rs/examples/custom.rs | 11 +-- mls-rs/src/client.rs | 2 +- mls-rs/src/external_client/group.rs | 99 +++++++++++-------- mls-rs/src/group/commit/processor.rs | 15 +-- .../interop_test_vectors/passive_client.rs | 13 +-- mls-rs/src/group/message_processor.rs | 70 ++++++++----- mls-rs/src/group/mod.rs | 57 +++++++---- mls-rs/src/group/test_utils.rs | 15 +-- mls-rs/test_harness_integration/src/main.rs | 15 +-- mls-rs/tests/client_tests.rs | 13 ++- 13 files changed, 212 insertions(+), 126 deletions(-) diff --git a/mls-rs/examples/api_1x.rs b/mls-rs/examples/api_1x.rs index 55950245..aeab131e 100644 --- a/mls-rs/examples/api_1x.rs +++ b/mls-rs/examples/api_1x.rs @@ -42,7 +42,10 @@ fn main() -> Result<(), MlsError> { // Alice and bob can chat let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?; - let msg = bob_group.process_incoming_message(msg)?; + + let msg = bob_group + .process_incoming_message(msg)? + .into_received_message(); println!("Received message: {:?}", msg); diff --git a/mls-rs/examples/basic_server_usage.rs b/mls-rs/examples/basic_server_usage.rs index ebe2f272..4033b749 100644 --- a/mls-rs/examples/basic_server_usage.rs +++ b/mls-rs/examples/basic_server_usage.rs @@ -60,9 +60,12 @@ impl BasicServer { let mut group = server.load_group(group_state)?; let proposal_msg = MlsMessage::from_bytes(&proposal)?; - let res = group.process_incoming_message(proposal_msg)?; - let ExternalReceivedMessage::Proposal(proposal_desc) = res else { + let res = group + .process_incoming_message(proposal_msg)? + .into_received_message(); + + let Some(ExternalReceivedMessage::Proposal(proposal_desc)) = res else { panic!("expected proposal message!") }; @@ -86,9 +89,12 @@ impl BasicServer { } let commit_msg = MlsMessage::from_bytes(&commit)?; - let res = group.process_incoming_message(commit_msg)?; - let ExternalReceivedMessage::Commit(_commit_desc) = res else { + let res = group + .process_incoming_message(commit_msg)? + .into_received_message(); + + let Some(ExternalReceivedMessage::Commit(_commit_desc)) = res else { panic!("expected commit message!") }; @@ -184,7 +190,10 @@ fn main() -> Result<(), MlsError> { // Bob downloads the commit let message = server.download_messages(1).first().unwrap(); - let res = bob_group.process_incoming_message(MlsMessage::from_bytes(message)?)?; + let res = bob_group + .process_incoming_message(MlsMessage::from_bytes(message)?)? + .into_received_message() + .unwrap(); let ReceivedMessage::Commit(commit_desc) = res else { panic!("expected commit message") diff --git a/mls-rs/examples/basic_usage.rs b/mls-rs/examples/basic_usage.rs index 98f8de62..27eede7e 100644 --- a/mls-rs/examples/basic_usage.rs +++ b/mls-rs/examples/basic_usage.rs @@ -72,7 +72,9 @@ fn main() -> Result<(), MlsError> { let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?; // Bob decrypts the application message from Alice. - let msg = bob_group.process_incoming_message(msg)?; + let msg = bob_group + .process_incoming_message(msg)? + .into_received_message(); println!("Received message: {:?}", msg); diff --git a/mls-rs/examples/custom.rs b/mls-rs/examples/custom.rs index 61869c80..950140db 100644 --- a/mls-rs/examples/custom.rs +++ b/mls-rs/examples/custom.rs @@ -29,7 +29,7 @@ use mls_rs::{ error::MlsError, group::{ proposal::{MlsCustomProposal, Proposal}, - GroupContext, ReceivedMessage, Sender, + GroupContext, Sender, }, mls_rules::{ProposalBundle, ProposalSource}, CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, IdentityProvider, @@ -408,11 +408,10 @@ fn main() -> Result<(), CustomError> { alice_tablet_group.apply_pending_commit()?; - let ReceivedMessage::CommitProcessor(mut processor) = - alice_pc_group.process_incoming_message(commit.commit_message)? - else { - return Err(CustomError); - }; + let mut processor = alice_pc_group + .process_incoming_message(commit.commit_message)? + .into_processor() + .ok_or(CustomError)?; handle_custom_proposals(&processor.context().clone(), processor.proposals_mut())?; processor.process()?; diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs index b5c60cf8..3b49c5da 100644 --- a/mls-rs/src/client.rs +++ b/mls-rs/src/client.rs @@ -958,7 +958,7 @@ mod tests { .unwrap(); assert_matches!( - message, + message.into_received_message().unwrap(), ReceivedMessage::Proposal(ProposalMessageDescription { proposal: Proposal::Add(p), ..} ) if p.key_package.leaf_node.signing_identity == bob_identity diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index 322ec817..9ef45f35 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -2,7 +2,7 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) -use core::fmt::{self, Debug}; +use core::fmt::Debug; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::{ @@ -85,9 +85,9 @@ use alloc::boxed::Box; /// The result of processing an [ExternalGroup](ExternalGroup) message using /// [process_incoming_message](ExternalGroup::process_incoming_message) -#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))] #[allow(clippy::large_enum_variant)] -pub enum ExternalReceivedMessage<'a, C: ExternalClientConfig> { +#[derive(Clone, Debug)] +pub enum ExternalReceivedMessage { /// A new commit was processed creating a new group state. Commit(CommitMessageDescription), /// A proposal was received. @@ -100,30 +100,12 @@ pub enum ExternalReceivedMessage<'a, C: ExternalClientConfig> { Welcome, /// Validated key package KeyPackage(KeyPackage), - /// A new commit can be processed to create a new group state. - CommitProcessor(ExternalCommitProcessor<'a, C>), } -impl Debug for ExternalReceivedMessage<'_, C> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ExternalReceivedMessage::Commit(value) => f.write_str(&format!("Commit({value:?})")), - ExternalReceivedMessage::Proposal(value) => { - f.write_str(&format!("Proposal({value:?})")) - } - ExternalReceivedMessage::Ciphertext(value) => { - f.write_str(&format!("Ciphertext({value:?})")) - } - ExternalReceivedMessage::GroupInfo(value) => { - f.write_str(&format!("GroupInfo({value:?})")) - } - ExternalReceivedMessage::KeyPackage(value) => { - f.write_str(&format!("KeyPackage({value:?})")) - } - ExternalReceivedMessage::Welcome => f.write_str("Welcome"), - ExternalReceivedMessage::CommitProcessor(_) => f.write_str("CommitProcessor"), - } - } +pub enum ExternalReceivedMessageOrProcessor<'a, C: ExternalClientConfig> { + ReceivedMessage(ExternalReceivedMessage), + /// A new commit can be processed to create a new group state. + CommitProcessor(ExternalCommitProcessor<'a, C>), } /// A handle to an observed group that can track plaintext control messages @@ -214,7 +196,7 @@ impl ExternalGroup { pub async fn process_incoming_message( &mut self, message: MlsMessage, - ) -> Result, MlsError> { + ) -> Result, MlsError> { MessageProcessor::process_incoming_message( self, message, @@ -228,7 +210,7 @@ impl ExternalGroup { pub async fn process_incoming_message_oneshot( &mut self, message: MlsMessage, - ) -> Result, MlsError> { + ) -> Result { let received_message = MessageProcessor::process_incoming_message( self, message, @@ -238,10 +220,10 @@ impl ExternalGroup { .await?; match received_message { - ExternalReceivedMessage::CommitProcessor(p) => { + ExternalReceivedMessageOrProcessor::CommitProcessor(p) => { p.process().await.map(ExternalReceivedMessage::Commit) } - other => Ok(other), + ExternalReceivedMessageOrProcessor::ReceivedMessage(m) => Ok(m), } } @@ -630,7 +612,7 @@ where { type MlsRules = C::MlsRules; type IdentityProvider = C::IdentityProvider; - type OutputType = ExternalReceivedMessage<'a, C>; + type OutputType = ExternalReceivedMessageOrProcessor<'a, C>; type CipherSuiteProvider = ::CipherSuiteProvider; #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -654,9 +636,11 @@ where &mut self, cipher_text: &PrivateMessage, ) -> Result, MlsError> { - Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext( - cipher_text.content_type, - ))) + Ok(EventOrContent::Event( + ExternalReceivedMessageOrProcessor::ReceivedMessage( + ExternalReceivedMessage::Ciphertext(cipher_text.content_type), + ), + )) } async fn update_key_schedule( @@ -791,8 +775,24 @@ where } } +impl<'a, C: ExternalClientConfig> ExternalReceivedMessageOrProcessor<'a, C> { + pub fn into_received_message(self) -> Option { + match self { + ExternalReceivedMessageOrProcessor::ReceivedMessage(m) => Some(m), + _ => None, + } + } + + pub fn into_processor(self) -> Option> { + match self { + ExternalReceivedMessageOrProcessor::CommitProcessor(c) => Some(c), + _ => None, + } + } +} + impl TryFrom - for ExternalReceivedMessage<'_, C> + for ExternalReceivedMessageOrProcessor<'_, C> { type Error = MlsError; @@ -802,38 +802,46 @@ impl TryFrom } impl<'a, C: ExternalClientConfig> From>> - for ExternalReceivedMessage<'a, C> + for ExternalReceivedMessageOrProcessor<'a, C> { fn from(value: InternalCommitProcessor<'a, ExternalGroup>) -> Self { - ExternalReceivedMessage::CommitProcessor(ExternalCommitProcessor(value)) + ExternalReceivedMessageOrProcessor::CommitProcessor(ExternalCommitProcessor(value)) } } -impl From for ExternalReceivedMessage<'_, C> { +impl<'a, C: ExternalClientConfig, T: Into> From + for ExternalReceivedMessageOrProcessor<'a, C> +{ + fn from(value: T) -> Self { + Self::ReceivedMessage(value.into()) + } +} + +impl From for ExternalReceivedMessage { fn from(value: CommitMessageDescription) -> Self { ExternalReceivedMessage::Commit(value) } } -impl From for ExternalReceivedMessage<'_, C> { +impl From for ExternalReceivedMessage { fn from(value: ProposalMessageDescription) -> Self { ExternalReceivedMessage::Proposal(value) } } -impl From for ExternalReceivedMessage<'_, C> { +impl From for ExternalReceivedMessage { fn from(value: GroupInfo) -> Self { ExternalReceivedMessage::GroupInfo(value) } } -impl From for ExternalReceivedMessage<'_, C> { +impl From for ExternalReceivedMessage { fn from(_: Welcome) -> Self { ExternalReceivedMessage::Welcome } } -impl From for ExternalReceivedMessage<'_, C> { +impl From for ExternalReceivedMessage { fn from(value: KeyPackage) -> Self { ExternalReceivedMessage::KeyPackage(value) } @@ -843,6 +851,15 @@ pub struct ExternalCommitProcessor<'a, C: ExternalClientConfig>( InternalCommitProcessor<'a, ExternalGroup>, ); +impl<'a, C: ExternalClientConfig> Debug for ExternalReceivedMessageOrProcessor<'a, C> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::ReceivedMessage(m) => f.write_str(&format!("ExternalReceivedMessage({m:?})")), + Self::CommitProcessor(_) => f.write_str("CommitProcessor"), + } + } +} + impl ExternalCommitProcessor<'_, C> { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn process(self) -> Result { diff --git a/mls-rs/src/group/commit/processor.rs b/mls-rs/src/group/commit/processor.rs index b7866d81..46ed7b55 100644 --- a/mls-rs/src/group/commit/processor.rs +++ b/mls-rs/src/group/commit/processor.rs @@ -2,6 +2,8 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) +use core::fmt::Debug; + use alloc::boxed::Box; use alloc::vec::Vec; @@ -359,7 +361,7 @@ mod tests { use crate::{ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, crypto::test_utils::TestCryptoProvider, - group::{ReceivedMessage, Sender}, + group::Sender, mls_rules::{CommitSource, ProposalInfo}, test_utils::get_test_groups, }; @@ -388,11 +390,12 @@ mod tests { let member_0 = groups[0].roster().member_with_index(0).unwrap(); - let ReceivedMessage::CommitProcessor(processor) = - groups[1].process_incoming_message(commit).await.unwrap() - else { - panic!("expected commit processor") - }; + let processor = groups[1] + .process_incoming_message(commit) + .await + .unwrap() + .into_processor() + .unwrap(); assert_eq!( &processor.proposals().removals, diff --git a/mls-rs/src/group/interop_test_vectors/passive_client.rs b/mls-rs/src/group/interop_test_vectors/passive_client.rs index 9ade40e4..9abc217e 100644 --- a/mls-rs/src/group/interop_test_vectors/passive_client.rs +++ b/mls-rs/src/group/interop_test_vectors/passive_client.rs @@ -20,7 +20,7 @@ use rand::{seq::IteratorRandom, Rng, SeedableRng}; use crate::{ client_builder::{ClientBuilder, MlsConfig}, crypto::test_utils::TestCryptoProvider, - group::{ClientConfig, CommitBuilder, ExportedTree, ReceivedMessage}, + group::{ClientConfig, CommitBuilder, ExportedTree}, identity::basic::BasicIdentityProvider, mls_rules::CommitOptions, test_utils::{ @@ -214,11 +214,12 @@ async fn interop_passive_client() { let group_clone = group.clone(); - let ReceivedMessage::CommitProcessor(mut processor) = - group.process_incoming_message(message).await.unwrap() - else { - panic!("expected commit") - }; + let mut processor = group + .process_incoming_message(message) + .await + .unwrap() + .into_processor() + .unwrap(); processor = test_case .external_psks diff --git a/mls-rs/src/group/message_processor.rs b/mls-rs/src/group/message_processor.rs index ff4df7bf..0e98f438 100644 --- a/mls-rs/src/group/message_processor.rs +++ b/mls-rs/src/group/message_processor.rs @@ -36,7 +36,6 @@ use itertools::Itertools; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use alloc::boxed::Box; -use alloc::format; use alloc::vec::Vec; use core::fmt::{self, Debug}; use mls_rs_core::{ @@ -199,7 +198,8 @@ impl MlsDecode for CommitEffect { #[allow(clippy::large_enum_variant)] /// An event generated as a result of processing a message for a group with /// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message). -pub enum ReceivedMessage<'a, C: ClientConfig> { +#[derive(Clone, Debug)] +pub enum ReceivedMessage { /// An application message was decrypted. ApplicationMessage(ApplicationMessageDescription), /// A new commit was processed creating a new group state. @@ -212,65 +212,91 @@ pub enum ReceivedMessage<'a, C: ClientConfig> { Welcome, /// Validated key package KeyPackage(KeyPackage), +} + +pub enum ReceivedMessageOrProcessor<'a, C: ClientConfig> { + /// An application message was decrypted. + ReceivedMessage(ReceivedMessage), /// A new commit can be processed to create a new group state. CommitProcessor(CommitProcessor<'a, C>), } -impl Debug for ReceivedMessage<'_, C> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl<'a, C: ClientConfig> Debug for ReceivedMessageOrProcessor<'a, C> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - ReceivedMessage::ApplicationMessage(value) => { - f.write_str(&format!("ApplicationMessage({value:?})")) - } - ReceivedMessage::Commit(value) => f.write_str(&format!("Commit({value:?})")), - ReceivedMessage::Proposal(value) => f.write_str(&format!("Proposal({value:?})")), - ReceivedMessage::GroupInfo(value) => f.write_str(&format!("GroupInfo({value:?})")), - ReceivedMessage::KeyPackage(value) => f.write_str(&format!("KeyPackage({value:?})")), - ReceivedMessage::Welcome => f.write_str("Welcome"), - ReceivedMessage::CommitProcessor(_) => f.write_str("CommitProcessor"), + Self::ReceivedMessage(m) => f.write_str(&format!("ReceivedMessage({m:?})")), + Self::CommitProcessor(_) => f.write_str("CommitProcessor"), } } } -impl TryFrom for ReceivedMessage<'_, C> { +impl<'a, C: ClientConfig> ReceivedMessageOrProcessor<'a, C> { + pub fn into_received_message(self) -> Option { + match self { + ReceivedMessageOrProcessor::ReceivedMessage(m) => Some(m), + _ => None, + } + } + + pub fn into_processor(self) -> Option> { + match self { + ReceivedMessageOrProcessor::CommitProcessor(c) => Some(c), + _ => None, + } + } +} + +impl<'a, C: ClientConfig, T: Into> From for ReceivedMessageOrProcessor<'a, C> { + fn from(value: T) -> Self { + Self::ReceivedMessage(value.into()) + } +} + +impl<'a, C: ClientConfig> TryFrom + for ReceivedMessageOrProcessor<'a, C> +{ type Error = MlsError; fn try_from(value: ApplicationMessageDescription) -> Result { - Ok(ReceivedMessage::ApplicationMessage(value)) + Ok(ReceivedMessageOrProcessor::ReceivedMessage( + ReceivedMessage::ApplicationMessage(value), + )) } } -impl<'a, C: ClientConfig> From>> for ReceivedMessage<'a, C> { +impl<'a, C: ClientConfig> From>> + for ReceivedMessageOrProcessor<'a, C> +{ fn from(value: InternalCommitProcessor<'a, Group>) -> Self { - ReceivedMessage::CommitProcessor(CommitProcessor(value)) + ReceivedMessageOrProcessor::CommitProcessor(CommitProcessor(value)) } } -impl From for ReceivedMessage<'_, C> { +impl From for ReceivedMessage { fn from(value: CommitMessageDescription) -> Self { ReceivedMessage::Commit(value) } } -impl From for ReceivedMessage<'_, C> { +impl From for ReceivedMessage { fn from(value: ProposalMessageDescription) -> Self { ReceivedMessage::Proposal(value) } } -impl From for ReceivedMessage<'_, C> { +impl From for ReceivedMessage { fn from(value: GroupInfo) -> Self { ReceivedMessage::GroupInfo(value) } } -impl From for ReceivedMessage<'_, C> { +impl From for ReceivedMessage { fn from(_: Welcome) -> Self { ReceivedMessage::Welcome } } -impl From for ReceivedMessage<'_, C> { +impl From for ReceivedMessage { fn from(value: KeyPackage) -> Self { ReceivedMessage::KeyPackage(value) } diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 79aaf741..e7dccafe 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -5,6 +5,7 @@ use alloc::vec; use alloc::vec::Vec; use core::fmt::{self, Debug}; +use message_processor::ReceivedMessageOrProcessor; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; use mls_rs_core::identity::MemberValidationContext; @@ -1168,14 +1169,16 @@ where pub async fn process_incoming_message( &mut self, message: MlsMessage, - ) -> Result, MlsError> { + ) -> Result, MlsError> { if let Some(pending) = self.pending_commit.commit_hash()? { let message_hash = MessageHash::compute(&self.cipher_suite_provider, &message).await?; if message_hash == pending { let message_description = self.apply_pending_commit().await?; - return Ok(ReceivedMessage::Commit(message_description)); + return Ok(ReceivedMessageOrProcessor::ReceivedMessage( + ReceivedMessage::Commit(message_description), + )); } } @@ -1188,7 +1191,9 @@ where .await?; if let Some(cached) = cached_own_proposal { - return Ok(ReceivedMessage::Proposal(cached)); + return Ok(ReceivedMessageOrProcessor::ReceivedMessage( + ReceivedMessage::Proposal(cached), + )); } } @@ -1206,12 +1211,14 @@ where pub async fn process_incoming_message_oneshot( &mut self, message: MlsMessage, - ) -> Result, MlsError> { + ) -> Result { let received_message = self.process_incoming_message(message).await?; match received_message { - ReceivedMessage::CommitProcessor(p) => p.process().await.map(ReceivedMessage::Commit), - other => Ok(other), + ReceivedMessageOrProcessor::CommitProcessor(p) => { + p.process().await.map(ReceivedMessage::Commit) + } + ReceivedMessageOrProcessor::ReceivedMessage(m) => Ok(m), } } @@ -1550,7 +1557,7 @@ where { type MlsRules = C::MlsRules; type IdentityProvider = C::IdentityProvider; - type OutputType = ReceivedMessage<'a, C>; + type OutputType = ReceivedMessageOrProcessor<'a, C>; type CipherSuiteProvider = ::CipherSuiteProvider; #[cfg(feature = "private_message")] @@ -2798,7 +2805,7 @@ mod tests { let received_by_alice = alice_group.process_incoming_message(msg).await.unwrap(); assert_matches!( - received_by_alice, + received_by_alice.into_received_message().unwrap(), ReceivedMessage::ApplicationMessage(ApplicationMessageDescription { sender_index, .. }) if sender_index == bob_group.current_member_index() ); @@ -2832,7 +2839,7 @@ mod tests { .unwrap(); assert_matches!( - received_message, + received_message.into_received_message().unwrap(), ReceivedMessage::ApplicationMessage(m) if m.data() == b"foobar" ); @@ -3538,13 +3545,15 @@ mod tests { groups[1].process_incoming_message(proposal).await.unwrap(); - let received_message = groups[1].process_incoming_message(commit).await.unwrap(); - - let ReceivedMessage::CommitProcessor(p) = received_message else { - panic!("expected commit processor"); - }; - - let res = p.time_sent(future_time).process().await; + let res = groups[1] + .process_incoming_message(commit) + .await + .unwrap() + .into_processor() + .unwrap() + .time_sent(future_time) + .process() + .await; assert_matches!(res, Err(MlsError::InvalidLifetime)); } @@ -3612,7 +3621,12 @@ mod tests { .await .unwrap(); - let recv_prop = bob.process_incoming_message(proposal).await.unwrap(); + let recv_prop = bob + .process_incoming_message(proposal) + .await + .unwrap() + .into_received_message() + .unwrap(); assert_matches!(recv_prop, ReceivedMessage::Proposal(ProposalMessageDescription { proposal: Proposal::Custom(c), ..}) if c == custom_proposal); @@ -3622,7 +3636,12 @@ mod tests { let ReceivedMessage::Commit(CommitMessageDescription { effect: CommitEffect::NewEpoch(new_epoch), .. - }) = bob.process_incoming_message(commit).await.unwrap() + }) = bob + .process_incoming_message(commit) + .await + .unwrap() + .into_received_message() + .unwrap() else { panic!("unexpected commit effect"); }; @@ -4089,6 +4108,8 @@ mod tests { let update = group .process_incoming_message(commit.commit_message) .await + .unwrap() + .into_received_message() .unwrap(); let ReceivedMessage::Commit(update) = update else { diff --git a/mls-rs/src/group/test_utils.rs b/mls-rs/src/group/test_utils.rs index 8e90aaaa..b7a099d0 100644 --- a/mls-rs/src/group/test_utils.rs +++ b/mls-rs/src/group/test_utils.rs @@ -133,7 +133,7 @@ impl TestGroup { pub(crate) async fn process_message( &mut self, message: MlsMessage, - ) -> Result, MlsError> { + ) -> Result { self.process_incoming_message_oneshot(message).await } @@ -160,10 +160,11 @@ impl TestGroup { &mut self, commit: MlsMessage, ) -> CommitProcessor<'_, TestClientConfig> { - match self.process_incoming_message(commit).await.unwrap() { - ReceivedMessage::CommitProcessor(p) => p, - _ => panic!("expected commit"), - } + self.process_incoming_message(commit) + .await + .unwrap() + .into_processor() + .unwrap() } } @@ -504,7 +505,7 @@ impl<'a> MessageProcessor<'a> for GroupWithoutKeySchedule { } impl<'a> From> - for ReceivedMessage<'a, TestClientConfig> + for ReceivedMessageOrProcessor<'a, TestClientConfig> { fn from(value: InternalCommitProcessor<'a, GroupWithoutKeySchedule>) -> Self { let value = InternalCommitProcessor { @@ -521,6 +522,6 @@ impl<'a> From> psks: value.psks, }; - ReceivedMessage::CommitProcessor(CommitProcessor(value)) + ReceivedMessageOrProcessor::CommitProcessor(CommitProcessor(value)) } } diff --git a/mls-rs/test_harness_integration/src/main.rs b/mls-rs/test_harness_integration/src/main.rs index e0f0dc4d..c483cf28 100644 --- a/mls-rs/test_harness_integration/src/main.rs +++ b/mls-rs/test_harness_integration/src/main.rs @@ -507,10 +507,11 @@ impl MlsClient for MlsClientImpl { .as_mut() .ok_or_else(|| Status::aborted("no group with such index."))? .process_incoming_message(ciphertext) - .map_err(abort)?; + .map_err(abort)? + .into_received_message(); let app_msg = match message { - ReceivedMessage::ApplicationMessage(app_msg) => app_msg, + Some(ReceivedMessage::ApplicationMessage(app_msg)) => app_msg, _ => return Err(Status::aborted("message type is not application data.")), }; @@ -914,11 +915,11 @@ impl MlsClientImpl { #[cfg(feature = "psk")] let group_clone = group.clone(); - let ReceivedMessage::CommitProcessor(processor) = - group.process_incoming_message(commit).map_err(abort)? - else { - return Err(Status::aborted("expected commit message")); - }; + let processor = group + .process_incoming_message(commit) + .map_err(abort)? + .into_processor() + .ok_or(Status::aborted("expected commit message"))?; #[cfg(feature = "psk")] let required_external_psks = processor diff --git a/mls-rs/tests/client_tests.rs b/mls-rs/tests/client_tests.rs index d79711cd..025a4c26 100644 --- a/mls-rs/tests/client_tests.rs +++ b/mls-rs/tests/client_tests.rs @@ -376,6 +376,8 @@ async fn test_application_messages( let decrypted = g .process_incoming_message(ciphertext.clone()) .await + .unwrap() + .into_received_message() .unwrap(); assert_matches!(decrypted, ReceivedMessage::ApplicationMessage(m) if m.data() == test_message); @@ -435,6 +437,8 @@ async fn test_out_of_order_application_messages() { let res = bob_group .process_incoming_message(ciphertexts[i].clone()) .await + .unwrap() + .into_received_message() .unwrap(); assert_matches!( @@ -467,12 +471,9 @@ async fn processing_message_from_self_returns_error( .await .unwrap(); - let error = creator_group - .process_incoming_message(msg) - .await - .unwrap_err(); + let res = creator_group.process_incoming_message(msg).await; - assert_matches!(error, MlsError::CantProcessMessageFromSelf); + assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf)); } #[cfg(feature = "private_message")] @@ -542,6 +543,8 @@ async fn external_commits_work( let processed = group .process_incoming_message(message.clone()) .await + .unwrap() + .into_received_message() .unwrap(); if let ReceivedMessage::Proposal(p) = &processed {