Skip to content

Commit

Permalink
Fix simple_scenario_sync and remove async for now
Browse files Browse the repository at this point in the history
  • Loading branch information
tomleavy committed Feb 26, 2024
1 parent cf32ac9 commit 65a090f
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 34 deletions.
18 changes: 9 additions & 9 deletions mls-rs-uniffi/src/config/group_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode};

use super::FFICallbackError;

#[derive(Clone, Debug, uniffi::Object)]
#[derive(Clone, Debug, uniffi::Record)]
pub struct GroupState {
pub id: Vec<u8>,
pub data: Vec<u8>,
Expand All @@ -16,7 +16,7 @@ impl mls_rs_core::group::GroupState for GroupState {
}
}

#[derive(Clone, Debug, uniffi::Object)]
#[derive(Clone, Debug, uniffi::Record)]
pub struct EpochRecord {
pub id: u64,
pub data: Vec<u8>,
Expand All @@ -41,9 +41,9 @@ pub trait GroupStateStorage: Send + Sync + Debug {

async fn write(
&self,
state: Arc<GroupState>,
epoch_inserts: Vec<Arc<EpochRecord>>,
epoch_updates: Vec<Arc<EpochRecord>>,
state: GroupState,
epoch_inserts: Vec<EpochRecord>,
epoch_updates: Vec<EpochRecord>,
) -> Result<(), FFICallbackError>;

async fn max_epoch_id(&self, group_id: Vec<u8>) -> Result<Option<u64>, FFICallbackError>;
Expand Down Expand Up @@ -99,16 +99,16 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper {
ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync,
ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync,
{
let state = Arc::new(GroupState {
let state = GroupState {
id: state.id(),
data: state.mls_encode_to_vec()?,
});
};

let epoch_to_record = |v: ET| -> Result<_, Self::Error> {
Ok(Arc::new(EpochRecord {
Ok(EpochRecord {
id: v.id(),
data: v.mls_encode_to_vec()?,
}))
})
};

let inserts = epoch_inserts
Expand Down
24 changes: 22 additions & 2 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ fn arc_unwrap_or_clone<T: Clone>(arc: Arc<T>) -> T {
#[uniffi(flat_error)]
#[non_exhaustive]
pub enum Error {
#[error("A mls-rs error occurred")]
#[error("A mls-rs error occurred: {inner}")]
MlsError {
#[from]
inner: mls_rs::error::MlsError,
},
#[error("An unknown error occurred")]
#[error("An unknown error occurred: {inner}")]
AnyError {
#[from]
inner: mls_rs::error::AnyError,
Expand Down Expand Up @@ -329,6 +329,19 @@ impl Client {
group_info_extensions,
})
}

/// Load an existing group.
///
/// See [`mls_rs::Client::load_group`] for details.
pub async fn load_group(&self, group_id: Vec<u8>) -> Result<Group, Error> {
self.inner
.load_group(&group_id)
.await
.map(|g| Group {
inner: Arc::new(Mutex::new(g)),
})
.map_err(Into::into)
}
}

#[derive(Clone, Debug, uniffi::Object)]
Expand Down Expand Up @@ -423,6 +436,13 @@ async fn signing_identity_to_identifier(
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[uniffi::export]
impl Group {
/// Write the current state of the group to storage defined by
/// [`ClientConfig::group_state_storage`]
pub async fn write_to_storage(&self) -> Result<(), Error> {
let mut group = self.inner().await;
group.write_to_storage().await.map_err(Into::into)
}

/// Perform a commit of received proposals (or an empty commit).
///
/// TODO: ensure `path_required` is always set in
Expand Down
19 changes: 0 additions & 19 deletions mls-rs-uniffi/test_bindings/simple_scenario_async.py

This file was deleted.

75 changes: 71 additions & 4 deletions mls-rs-uniffi/test_bindings/simple_scenario_sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,74 @@
from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client
from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client, GroupStateStorage, ClientConfig

class EpochData:
def __init__(self, id: "int", data: "bytes"):
self.id = id
self.data = data

class GroupStateData:
def __init__(self, state: "bytes"):
self.state = state
self.epoch_data = []

class PythonGroupStateStorage(GroupStateStorage):
def __init__(self):
self.groups = {}

def state(self, group_id: "bytes"):
group = self.groups.get(group_id.hex())

if group == None:
return None

group.state

def epoch(self, group_id: "bytes",epoch_id: "int"):
group = self.groups[group_id.hex()]

if group == None:
return None

for epoch in group.epoch_data:
if epoch.id == epoch_id:
return epoch

return None

def write(self, state: "GroupState",epoch_inserts: "typing.List[EpochRecord]",epoch_updates: "typing.List[EpochRecord]"):
if self.groups.get(state.id.hex()) == None:
self.groups[state.id.hex()] = GroupStateData(state.data)

group = self.groups[state.id.hex()]

for insert in epoch_inserts:
group.epoch_data.append(insert)

for update in epoch_updates:
for i in range(len(group.epoch_data)):
if group.epoch_data[i].id == update.id:
group.epoch_data[i] = update

def max_epoch_id(self, group_id: "bytes"):
group = self.groups.get(group_id.hex())

if group == None:
return None

last = group.epoch_data.last()

if last == None:
return None

return last.id

group_state_storage = PythonGroupStateStorage()
client_config = ClientConfig(group_state_storage)

key = generate_signature_keypair(CipherSuite.CURVE25519_AES128)
alice = Client(b'alice', key)
alice = Client(b'alice', key, client_config)

key = generate_signature_keypair(CipherSuite.CURVE25519_AES128)
bob = Client(b'bob', key)
bob = Client(b'bob', key, client_config)

alice = alice.create_group(None)
kp = bob.generate_key_package_message()
Expand All @@ -15,4 +79,7 @@
msg = alice.encrypt_application_message(b'hello, bob')
output = bob.process_incoming_message(msg)

assert output.data == b'hello, bob'
alice.write_to_storage()

assert output.data == b'hello, bob'
assert len(group_state_storage.groups) == 1

0 comments on commit 65a090f

Please sign in to comment.