Skip to content

Commit

Permalink
Remove lifetime from ReceivedMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
Marta Mularczyk committed Jan 17, 2025
1 parent 8571ef5 commit ee6fddc
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 126 deletions.
5 changes: 4 additions & 1 deletion mls-rs/examples/api_1x.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ fn main() -> Result<(), MlsError> {

// Alice and bob can chat
let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?;
let msg = bob_group.process_incoming_message(msg)?;

let msg = bob_group
.process_incoming_message(msg)?
.into_received_message();

println!("Received message: {:?}", msg);

Expand Down
19 changes: 14 additions & 5 deletions mls-rs/examples/basic_server_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ impl BasicServer {
let mut group = server.load_group(group_state)?;

let proposal_msg = MlsMessage::from_bytes(&proposal)?;
let res = group.process_incoming_message(proposal_msg)?;

let ExternalReceivedMessage::Proposal(proposal_desc) = res else {
let res = group
.process_incoming_message(proposal_msg)?
.into_received_message();

let Some(ExternalReceivedMessage::Proposal(proposal_desc)) = res else {
panic!("expected proposal message!")
};

Expand All @@ -86,9 +89,12 @@ impl BasicServer {
}

let commit_msg = MlsMessage::from_bytes(&commit)?;
let res = group.process_incoming_message(commit_msg)?;

let ExternalReceivedMessage::Commit(_commit_desc) = res else {
let res = group
.process_incoming_message(commit_msg)?
.into_received_message();

let Some(ExternalReceivedMessage::Commit(_commit_desc)) = res else {
panic!("expected commit message!")
};

Expand Down Expand Up @@ -184,7 +190,10 @@ fn main() -> Result<(), MlsError> {
// Bob downloads the commit
let message = server.download_messages(1).first().unwrap();

let res = bob_group.process_incoming_message(MlsMessage::from_bytes(message)?)?;
let res = bob_group
.process_incoming_message(MlsMessage::from_bytes(message)?)?
.into_received_message()
.unwrap();

let ReceivedMessage::Commit(commit_desc) = res else {
panic!("expected commit message")
Expand Down
4 changes: 3 additions & 1 deletion mls-rs/examples/basic_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ fn main() -> Result<(), MlsError> {
let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?;

// Bob decrypts the application message from Alice.
let msg = bob_group.process_incoming_message(msg)?;
let msg = bob_group
.process_incoming_message(msg)?
.into_received_message();

println!("Received message: {:?}", msg);

Expand Down
11 changes: 5 additions & 6 deletions mls-rs/examples/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use mls_rs::{
error::MlsError,
group::{
proposal::{MlsCustomProposal, Proposal},
GroupContext, ReceivedMessage, Sender,
GroupContext, Sender,
},
mls_rules::{ProposalBundle, ProposalSource},
CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, IdentityProvider,
Expand Down Expand Up @@ -408,11 +408,10 @@ fn main() -> Result<(), CustomError> {

alice_tablet_group.apply_pending_commit()?;

let ReceivedMessage::CommitProcessor(mut processor) =
alice_pc_group.process_incoming_message(commit.commit_message)?
else {
return Err(CustomError);
};
let mut processor = alice_pc_group
.process_incoming_message(commit.commit_message)?
.into_processor()
.ok_or(CustomError)?;

handle_custom_proposals(&processor.context().clone(), processor.proposals_mut())?;
processor.process()?;
Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ mod tests {
.unwrap();

assert_matches!(
message,
message.into_received_message().unwrap(),
ReceivedMessage::Proposal(ProposalMessageDescription {
proposal: Proposal::Add(p), ..}
) if p.key_package.leaf_node.signing_identity == bob_identity
Expand Down
99 changes: 58 additions & 41 deletions mls-rs/src/external_client/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use core::fmt::{self, Debug};
use core::fmt::Debug;

use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::{
Expand Down Expand Up @@ -85,9 +85,9 @@ use alloc::boxed::Box;

/// The result of processing an [ExternalGroup](ExternalGroup) message using
/// [process_incoming_message](ExternalGroup::process_incoming_message)
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
#[allow(clippy::large_enum_variant)]
pub enum ExternalReceivedMessage<'a, C: ExternalClientConfig> {
#[derive(Clone, Debug)]
pub enum ExternalReceivedMessage {
/// A new commit was processed creating a new group state.
Commit(CommitMessageDescription),
/// A proposal was received.
Expand All @@ -100,30 +100,12 @@ pub enum ExternalReceivedMessage<'a, C: ExternalClientConfig> {
Welcome,
/// Validated key package
KeyPackage(KeyPackage),
/// A new commit can be processed to create a new group state.
CommitProcessor(ExternalCommitProcessor<'a, C>),
}

impl<C: ExternalClientConfig> Debug for ExternalReceivedMessage<'_, C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ExternalReceivedMessage::Commit(value) => f.write_str(&format!("Commit({value:?})")),
ExternalReceivedMessage::Proposal(value) => {
f.write_str(&format!("Proposal({value:?})"))
}
ExternalReceivedMessage::Ciphertext(value) => {
f.write_str(&format!("Ciphertext({value:?})"))
}
ExternalReceivedMessage::GroupInfo(value) => {
f.write_str(&format!("GroupInfo({value:?})"))
}
ExternalReceivedMessage::KeyPackage(value) => {
f.write_str(&format!("KeyPackage({value:?})"))
}
ExternalReceivedMessage::Welcome => f.write_str("Welcome"),
ExternalReceivedMessage::CommitProcessor(_) => f.write_str("CommitProcessor"),
}
}
pub enum ExternalReceivedMessageOrProcessor<'a, C: ExternalClientConfig> {
ReceivedMessage(ExternalReceivedMessage),
/// A new commit can be processed to create a new group state.
CommitProcessor(ExternalCommitProcessor<'a, C>),
}

/// A handle to an observed group that can track plaintext control messages
Expand Down Expand Up @@ -214,7 +196,7 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
pub async fn process_incoming_message(
&mut self,
message: MlsMessage,
) -> Result<ExternalReceivedMessage<'_, C>, MlsError> {
) -> Result<ExternalReceivedMessageOrProcessor<'_, C>, MlsError> {
MessageProcessor::process_incoming_message(
self,
message,
Expand All @@ -228,7 +210,7 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
pub async fn process_incoming_message_oneshot(
&mut self,
message: MlsMessage,
) -> Result<ExternalReceivedMessage<'_, C>, MlsError> {
) -> Result<ExternalReceivedMessage, MlsError> {
let received_message = MessageProcessor::process_incoming_message(
self,
message,
Expand All @@ -238,10 +220,10 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
.await?;

match received_message {
ExternalReceivedMessage::CommitProcessor(p) => {
ExternalReceivedMessageOrProcessor::CommitProcessor(p) => {
p.process().await.map(ExternalReceivedMessage::Commit)
}
other => Ok(other),
ExternalReceivedMessageOrProcessor::ReceivedMessage(m) => Ok(m),
}
}

Expand Down Expand Up @@ -630,7 +612,7 @@ where
{
type MlsRules = C::MlsRules;
type IdentityProvider = C::IdentityProvider;
type OutputType = ExternalReceivedMessage<'a, C>;
type OutputType = ExternalReceivedMessageOrProcessor<'a, C>;
type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
Expand All @@ -654,9 +636,11 @@ where
&mut self,
cipher_text: &PrivateMessage,
) -> Result<EventOrContent<Self::OutputType>, MlsError> {
Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext(
cipher_text.content_type,
)))
Ok(EventOrContent::Event(
ExternalReceivedMessageOrProcessor::ReceivedMessage(
ExternalReceivedMessage::Ciphertext(cipher_text.content_type),
),
))
}

async fn update_key_schedule(
Expand Down Expand Up @@ -791,8 +775,24 @@ where
}
}

impl<'a, C: ExternalClientConfig> ExternalReceivedMessageOrProcessor<'a, C> {
pub fn into_received_message(self) -> Option<ExternalReceivedMessage> {
match self {
ExternalReceivedMessageOrProcessor::ReceivedMessage(m) => Some(m),
_ => None,
}
}

pub fn into_processor(self) -> Option<ExternalCommitProcessor<'a, C>> {
match self {
ExternalReceivedMessageOrProcessor::CommitProcessor(c) => Some(c),
_ => None,
}
}
}

impl<C: ExternalClientConfig> TryFrom<ApplicationMessageDescription>
for ExternalReceivedMessage<'_, C>
for ExternalReceivedMessageOrProcessor<'_, C>
{
type Error = MlsError;

Expand All @@ -802,38 +802,46 @@ impl<C: ExternalClientConfig> TryFrom<ApplicationMessageDescription>
}

impl<'a, C: ExternalClientConfig> From<InternalCommitProcessor<'a, ExternalGroup<C>>>
for ExternalReceivedMessage<'a, C>
for ExternalReceivedMessageOrProcessor<'a, C>
{
fn from(value: InternalCommitProcessor<'a, ExternalGroup<C>>) -> Self {
ExternalReceivedMessage::CommitProcessor(ExternalCommitProcessor(value))
ExternalReceivedMessageOrProcessor::CommitProcessor(ExternalCommitProcessor(value))
}
}

impl<C: ExternalClientConfig> From<CommitMessageDescription> for ExternalReceivedMessage<'_, C> {
impl<'a, C: ExternalClientConfig, T: Into<ExternalReceivedMessage>> From<T>
for ExternalReceivedMessageOrProcessor<'a, C>
{
fn from(value: T) -> Self {
Self::ReceivedMessage(value.into())
}
}

impl From<CommitMessageDescription> for ExternalReceivedMessage {
fn from(value: CommitMessageDescription) -> Self {
ExternalReceivedMessage::Commit(value)
}
}

impl<C: ExternalClientConfig> From<ProposalMessageDescription> for ExternalReceivedMessage<'_, C> {
impl From<ProposalMessageDescription> for ExternalReceivedMessage {
fn from(value: ProposalMessageDescription) -> Self {
ExternalReceivedMessage::Proposal(value)
}
}

impl<C: ExternalClientConfig> From<GroupInfo> for ExternalReceivedMessage<'_, C> {
impl From<GroupInfo> for ExternalReceivedMessage {
fn from(value: GroupInfo) -> Self {
ExternalReceivedMessage::GroupInfo(value)
}
}

impl<C: ExternalClientConfig> From<Welcome> for ExternalReceivedMessage<'_, C> {
impl From<Welcome> for ExternalReceivedMessage {
fn from(_: Welcome) -> Self {
ExternalReceivedMessage::Welcome
}
}

impl<C: ExternalClientConfig> From<KeyPackage> for ExternalReceivedMessage<'_, C> {
impl From<KeyPackage> for ExternalReceivedMessage {
fn from(value: KeyPackage) -> Self {
ExternalReceivedMessage::KeyPackage(value)
}
Expand All @@ -843,6 +851,15 @@ pub struct ExternalCommitProcessor<'a, C: ExternalClientConfig>(
InternalCommitProcessor<'a, ExternalGroup<C>>,
);

impl<'a, C: ExternalClientConfig> Debug for ExternalReceivedMessageOrProcessor<'a, C> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::ReceivedMessage(m) => f.write_str(&format!("ExternalReceivedMessage({m:?})")),
Self::CommitProcessor(_) => f.write_str("CommitProcessor"),
}
}
}

impl<C: ExternalClientConfig> ExternalCommitProcessor<'_, C> {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn process(self) -> Result<CommitMessageDescription, MlsError> {
Expand Down
15 changes: 9 additions & 6 deletions mls-rs/src/group/commit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use core::fmt::Debug;

use alloc::boxed::Box;
use alloc::vec::Vec;

Expand Down Expand Up @@ -359,7 +361,7 @@ mod tests {
use crate::{
client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
crypto::test_utils::TestCryptoProvider,
group::{ReceivedMessage, Sender},
group::Sender,
mls_rules::{CommitSource, ProposalInfo},
test_utils::get_test_groups,
};
Expand Down Expand Up @@ -388,11 +390,12 @@ mod tests {

let member_0 = groups[0].roster().member_with_index(0).unwrap();

let ReceivedMessage::CommitProcessor(processor) =
groups[1].process_incoming_message(commit).await.unwrap()
else {
panic!("expected commit processor")
};
let processor = groups[1]
.process_incoming_message(commit)
.await
.unwrap()
.into_processor()
.unwrap();

assert_eq!(
&processor.proposals().removals,
Expand Down
13 changes: 7 additions & 6 deletions mls-rs/src/group/interop_test_vectors/passive_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use rand::{seq::IteratorRandom, Rng, SeedableRng};
use crate::{
client_builder::{ClientBuilder, MlsConfig},
crypto::test_utils::TestCryptoProvider,
group::{ClientConfig, CommitBuilder, ExportedTree, ReceivedMessage},
group::{ClientConfig, CommitBuilder, ExportedTree},
identity::basic::BasicIdentityProvider,
mls_rules::CommitOptions,
test_utils::{
Expand Down Expand Up @@ -214,11 +214,12 @@ async fn interop_passive_client() {

let group_clone = group.clone();

let ReceivedMessage::CommitProcessor(mut processor) =
group.process_incoming_message(message).await.unwrap()
else {
panic!("expected commit")
};
let mut processor = group
.process_incoming_message(message)
.await
.unwrap()
.into_processor()
.unwrap();

processor = test_case
.external_psks
Expand Down
Loading

0 comments on commit ee6fddc

Please sign in to comment.