Skip to content

Commit

Permalink
Rework based on Cow to avoid unnecessary cloning of the tree
Browse files Browse the repository at this point in the history
  • Loading branch information
tomleavy committed Dec 26, 2023
1 parent 0ca5d20 commit f5c2bae
Show file tree
Hide file tree
Showing 17 changed files with 86 additions and 41 deletions.
2 changes: 1 addition & 1 deletion mls-rs-codec/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mls-rs-codec"
version = "0.5.0"
version = "0.5.1"
edition = "2021"
description = "TLS codec and MLS specific encoding used by mls-rs"
homepage = "https://github.com/awslabs/mls-rs"
Expand Down
31 changes: 31 additions & 0 deletions mls-rs-codec/src/cow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use alloc::borrow::Cow;

use crate::{Error, MlsDecode, MlsEncode, MlsSize};

impl<'a, T> MlsSize for Cow<'a, T>
where
T: MlsSize + ToOwned,
{
fn mls_encoded_len(&self) -> usize {
self.as_ref().mls_encoded_len()
}
}

impl<'a, T> MlsEncode for Cow<'a, T>
where
T: MlsEncode + ToOwned,
{
#[inline]
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), Error> {
self.as_ref().mls_encode(writer)
}
}

impl<'a, T> MlsDecode for Cow<'a, T>
where
T: MlsDecode + ToOwned<Owned = T>,
{
fn mls_decode(reader: &mut &[u8]) -> Result<Self, Error> {
T::mls_decode(reader).map(Cow::Owned)
}
}
1 change: 1 addition & 0 deletions mls-rs-codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod byte_vec;

pub mod iter;

mod cow;
mod map;
mod option;
mod stdint;
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/examples/basic_server_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl BasicServer {
fn create_group(group_info: &[u8], tree: &[u8]) -> Result<Self, MlsError> {
let server = make_server();
let group_info = MlsMessage::from_bytes(group_info)?;
let tree = ExportedTree::from_bytes(tree).unwrap();
let tree = ExportedTree::from_bytes(tree)?;

let group = server.observe_group(group_info, Some(tree))?;

Expand Down Expand Up @@ -157,7 +157,7 @@ fn main() -> Result<(), MlsError> {

// Server starts observing Alice's group
let group_info = alice_group.group_info_message(true)?.to_bytes()?;
let tree = alice_group.export_tree().to_bytes().unwrap();
let tree = alice_group.export_tree()?;

let mut server = BasicServer::create_group(&group_info, &tree)?;

Expand Down
8 changes: 4 additions & 4 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ where
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn join_group(
&self,
tree_data: Option<ExportedTree>,
tree_data: Option<ExportedTree<'_>>,
welcome_message: MlsMessage,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
Group::join(
Expand Down Expand Up @@ -624,7 +624,7 @@ where
pub async fn external_add_proposal(
&self,
group_info: MlsMessage,
tree_data: Option<crate::group::ExportedTree>,
tree_data: Option<crate::group::ExportedTree<'_>>,
authenticated_data: Vec<u8>,
) -> Result<MlsMessage, MlsError> {
let protocol_version = group_info.version;
Expand Down Expand Up @@ -902,7 +902,7 @@ mod tests {
let mut builder = new_client
.external_commit_builder()
.unwrap()
.with_tree_data(alice_group.group.export_tree());
.with_tree_data(alice_group.group.export_tree().into_owned());

if do_remove {
builder = builder.with_removal(1);
Expand Down Expand Up @@ -1008,7 +1008,7 @@ mod tests {
let (_, external_commit) = carol
.external_commit_builder()
.unwrap()
.with_tree_data(bob_group.group.export_tree())
.with_tree_data(bob_group.group.export_tree().into_owned())
.build(group_info_msg)
.await
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/src/extension/built_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ impl MlsCodecExtension for ApplicationIdExt {
)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct RatchetTreeExt {
pub tree_data: ExportedTree,
pub tree_data: ExportedTree<'static>,
}

#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl RatchetTreeExt {
/// Required custom extension types.
#[cfg(feature = "ffi")]
pub fn tree_data(&self) -> &ExportedTree {
pub fn tree_data(&self) -> &ExportedTree<'static> {
&self.tree_data
}
}
Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/external_client/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
pub fn export_tree(&self) -> Result<Vec<u8>, MlsError> {
self.group_state()
.public_tree
.export_node_data()
.nodes
.mls_encode_to_vec()
.map_err(Into::into)
}
Expand Down
8 changes: 4 additions & 4 deletions mls-rs/src/group/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub struct CommitOutput {
/// Ratchet tree that can be sent out of band if
/// `ratchet_tree_extension` is not used according to
/// [`MlsRules::encryption_options`].
pub ratchet_tree: Option<ExportedTree>,
pub ratchet_tree: Option<ExportedTree<'static>>,
}

#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
Expand All @@ -110,7 +110,7 @@ impl CommitOutput {
/// `ratchet_tree_extension` is not used according to
/// [`MlsRules::encryption_options`].
#[cfg(feature = "ffi")]
pub fn ratchet_tree(&self) -> Option<&ExportedTree> {
pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
self.ratchet_tree.as_ref()
}
}
Expand Down Expand Up @@ -560,7 +560,7 @@ where

if commit_options.ratchet_tree_extension {
let ratchet_tree_ext = RatchetTreeExt {
tree_data: ExportedTree::new(provisional_state.public_tree.export_node_data()),
tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
};

extensions.set_from(ratchet_tree_ext)?;
Expand Down Expand Up @@ -675,7 +675,7 @@ where
self.pending_commit = Some(pending_commit);

let ratchet_tree = (!commit_options.ratchet_tree_extension)
.then(|| ExportedTree::new(provisional_state.public_tree.export_node_data()));
.then_some(ExportedTree::new(provisional_state.public_tree.nodes));

if let Some(signer) = new_signer {
self.signer = signer;
Expand Down
28 changes: 22 additions & 6 deletions mls-rs/src/group/exported_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use alloc::vec::Vec;
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::{borrow::Cow, vec::Vec};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

use crate::{client::MlsError, tree_kem::node::NodeVec};
Expand All @@ -8,12 +12,16 @@ use crate::{client::MlsError, tree_kem::node::NodeVec};
safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, MlsSize, MlsEncode, MlsDecode, PartialEq, Clone)]
pub struct ExportedTree(pub(crate) NodeVec);
pub struct ExportedTree<'a>(pub(crate) Cow<'a, NodeVec>);

#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
impl ExportedTree {
impl<'a> ExportedTree<'a> {
pub(crate) fn new(node_data: NodeVec) -> Self {
Self(node_data)
Self(Cow::Owned(node_data))
}

pub(crate) fn new_borrowed(node_data: &'a NodeVec) -> Self {
Self(Cow::Borrowed(node_data))
}

pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
Expand All @@ -23,10 +31,18 @@ impl ExportedTree {
pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
self.mls_encode_to_vec().map_err(Into::into)
}

pub fn byte_size(&self) -> usize {
self.mls_encoded_len()
}

pub fn into_owned(self) -> ExportedTree<'static> {
ExportedTree(Cow::Owned(self.0.into_owned()))
}
}

impl From<ExportedTree> for NodeVec {
impl From<ExportedTree<'_>> for NodeVec {
fn from(value: ExportedTree) -> Self {
value.0
value.0.into_owned()
}
}
4 changes: 2 additions & 2 deletions mls-rs/src/group/external_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub struct ExternalCommitBuilder<C: ClientConfig> {
signer: SignatureSecretKey,
signing_identity: SigningIdentity,
config: C,
tree_data: Option<ExportedTree>,
tree_data: Option<ExportedTree<'static>>,
to_remove: Option<u32>,
#[cfg(feature = "psk")]
external_psks: Vec<ExternalPskId>,
Expand Down Expand Up @@ -70,7 +70,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
#[must_use]
/// Use external tree data if the GroupInfo message does not contain a
/// [`RatchetTreeExt`](crate::extension::built_in::RatchetTreeExt)
pub fn with_tree_data(self, tree_data: ExportedTree) -> Self {
pub fn with_tree_data(self, tree_data: ExportedTree<'static>) -> Self {
Self {
tree_data: Some(tree_data),
..self
Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/group/interop_test_vectors/passive_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ pub async fn add_random_members<C: MlsConfig>(
tc.epochs.push(epoch)
};

let tree_data = groups[committer].export_tree();
let tree_data = groups[committer].export_tree().into_owned();

for client in &clients {
let commit = commit_output.welcome_messages[0].clone();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn tree_modifications_interop() {

let tree_after = apply_proposal(proposal, test_case.proposal_sender, &tree_before).await;

let tree_after = tree_after.export_node_data().mls_encode_to_vec().unwrap();
let tree_after = tree_after.nodes.mls_encode_to_vec().unwrap();

assert_eq!(tree_after, test_case.tree_after);
}
Expand Down
8 changes: 4 additions & 4 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1438,8 +1438,8 @@ where
///
/// This function is used to provide the current group tree to new members
/// when the `ratchet_tree_extension` is not used according to [`MlsRules::commit_options`].
pub fn export_tree(&self) -> ExportedTree {
ExportedTree::new(self.current_epoch_tree().export_node_data())
pub fn export_tree(&self) -> ExportedTree<'_> {
ExportedTree::new_borrowed(&self.current_epoch_tree().nodes)
}

/// Current version of the MLS protocol in use by this group.
Expand Down Expand Up @@ -2637,7 +2637,7 @@ mod tests {
let (bob_group, commit) = bob
.external_commit_builder()
.unwrap()
.with_tree_data(alice_group.group.export_tree())
.with_tree_data(alice_group.group.export_tree().into_owned())
.build(
alice_group
.group
Expand Down Expand Up @@ -2683,7 +2683,7 @@ mod tests {
let (_, commit) = bob
.external_commit_builder()
.unwrap()
.with_tree_data(alice_group.group.export_tree())
.with_tree_data(alice_group.group.export_tree().into_owned())
.build(
alice_group
.group
Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/test_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub async fn get_test_groups<C: CryptoProvider + Clone>(

creator_group.apply_pending_commit().await.unwrap();

let tree_data = creator_group.export_tree();
let tree_data = creator_group.export_tree().into_owned();

let mut groups = vec![creator_group];

Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/tree_kem/interop_test_vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl ValidationTestCase {

Self {
cipher_suite: cs.cipher_suite().into(),
tree: tree.export_node_data().mls_encode_to_vec().unwrap(),
tree: tree.nodes.mls_encode_to_vec().unwrap(),
tree_hashes: tree
.tree_hashes
.current
Expand Down
17 changes: 7 additions & 10 deletions mls-rs/src/tree_kem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ impl TreeKemPublic {
Ok(None)
}

pub(crate) fn export_node_data(&self) -> NodeVec {
self.nodes.clone()
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn derive<I: IdentityProvider>(
leaf_node: LeafNode,
Expand Down Expand Up @@ -1008,12 +1004,13 @@ mod tests {
.await
.unwrap();

let exported = test_tree.public.export_node_data();

let imported =
TreeKemPublic::import_node_data(exported, &BasicIdentityProvider, &Default::default())
.await
.unwrap();
let imported = TreeKemPublic::import_node_data(
test_tree.public.nodes.clone(),
&BasicIdentityProvider,
&Default::default(),
)
.await
.unwrap();

assert_eq!(test_tree.public.nodes, imported.nodes);

Expand Down
2 changes: 1 addition & 1 deletion mls-rs/src/tree_kem/tree_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ mod tests {

test_cases.push(TestCase {
cipher_suite: cipher_suite.into(),
tree_data: tree.export_node_data().mls_encode_to_vec().unwrap(),
tree_data: tree.nodes.mls_encode_to_vec().unwrap(),
tree_hash: tree
.tree_hash(&test_cipher_suite_provider(cipher_suite))
.await
Expand Down

0 comments on commit f5c2bae

Please sign in to comment.