Skip to content

Commit

Permalink
make mls commit lock per-client
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Feb 19, 2025
1 parent 555d80d commit 4f20dc3
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 40 deletions.
1 change: 0 additions & 1 deletion xmtp_api/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ where

id_cursor = Some(paging_info.id_cursor);
}

Ok(out)
}

Expand Down
8 changes: 7 additions & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use xmtp_proto::xmtp::mls::api::v1::{welcome_message, GroupMessage, WelcomeMessa
#[cfg(any(test, feature = "test-utils"))]
use crate::groups::device_sync::WorkerHandle;

use crate::groups::group_mutable_metadata::MessageDisappearingSettings;
use crate::groups::{ConversationListItem, DMMetadataOptions};
use crate::{groups::group_mutable_metadata::MessageDisappearingSettings, GroupCommitLock};
use crate::{
groups::{
device_sync::preference_sync::UserPreferenceUpdate, group_metadata::DmMembers,
Expand Down Expand Up @@ -174,6 +174,7 @@ pub struct XmtpMlsLocalContext {
/// XMTP Local Storage
store: EncryptedMessageStore,
pub(crate) mutexes: MutexRegistry,
pub(crate) mls_commit_lock: std::sync::Arc<GroupCommitLock>,
}

impl XmtpMlsLocalContext {
Expand Down Expand Up @@ -213,6 +214,10 @@ impl XmtpMlsLocalContext {
) -> Result<Vec<u8>, IdentityError> {
self.identity.sign_with_public_context(text)
}

pub fn mls_commit_lock(&self) -> &Arc<GroupCommitLock> {
&self.mls_commit_lock
}
}

impl<ApiClient, V> Client<ApiClient, V>
Expand All @@ -238,6 +243,7 @@ where
identity,
store,
mutexes: MutexRegistry::new(),
mls_commit_lock: Arc::new(GroupCommitLock::new()),
});
let (tx, _) = broadcast::channel(32);

Expand Down
10 changes: 0 additions & 10 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,6 @@ where
intent.kind,
message_epoch
);
#[cfg(test)]
{
let mut w = crate::PROCESSED.lock();
w.push((*cursor, intent.clone()));
}

if let Some((staged_commit, validated_commit)) = commit {
tracing::info!(
Expand Down Expand Up @@ -1280,11 +1275,6 @@ where
intent.id,
intent.kind
);
#[cfg(test)]
{
let mut w = crate::PUBLISHED.lock();
w.push(intent.clone());
}
if has_staged_commit {
tracing::info!("Commit sent. Stopping further publishes for this round");
return Ok(());
Expand Down
11 changes: 8 additions & 3 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use crate::storage::{
refresh_state::EntityKind,
NotFound, ProviderTransactions, StorageError,
};
use crate::GroupCommitLock;
use crate::{
client::{ClientError, XmtpMlsLocalContext},
configuration::{
Expand All @@ -67,7 +68,7 @@ use crate::{
},
subscriptions::{LocalEventError, LocalEvents},
utils::id::calculate_message_id,
Store, MLS_COMMIT_LOCK,
Store,
};
use device_sync::preference_sync::UserPreferenceUpdate;
use intents::SendMessageIntentData;
Expand Down Expand Up @@ -281,6 +282,7 @@ pub struct MlsGroup<C> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
pub client: Arc<C>,
mls_commit_lock: Arc<GroupCommitLock>,
mutex: Arc<Mutex<()>>,
}

Expand Down Expand Up @@ -309,6 +311,7 @@ impl<C> Clone for MlsGroup<C> {
created_at_ns: self.created_at_ns,
client: self.client.clone(),
mutex: self.mutex.clone(),
mls_commit_lock: self.mls_commit_lock.clone(),
}
}
}
Expand Down Expand Up @@ -415,11 +418,13 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
created_at_ns: i64,
) -> Self {
let mut mutexes = client.context().mutexes.clone();
let context = client.context();
Self {
group_id: group_id.clone(),
created_at_ns,
mutex: mutexes.get_mutex(group_id),
client,
mls_commit_lock: Arc::clone(context.mls_commit_lock()),
}
}

Expand Down Expand Up @@ -447,7 +452,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
let group_id = self.group_id.clone();

// Acquire the lock synchronously using blocking_lock
let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone());
let _lock = self.mls_commit_lock.get_lock_sync(group_id.clone());
// Load the MLS group
let mls_group =
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
Expand Down Expand Up @@ -478,7 +483,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
hex::encode(&group_id)
);
// Acquire the lock asynchronously
let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await;
let _lock = self.mls_commit_lock.get_lock_async(group_id.clone()).await;
tracing::info!("LOADING GROUP");

// Load the MLS group
Expand Down
10 changes: 1 addition & 9 deletions xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@ use tokio::sync::Mutex as TokioMutex;
pub use xmtp_id::InboxOwner;
pub use xmtp_proto::api_client::trait_impls::*;

#[cfg(test)]
pub static PUBLISHED: once_cell::sync::Lazy<Mutex<Vec<storage::group_intent::StoredGroupIntent>>> =
once_cell::sync::Lazy::new(|| Mutex::new(Vec::new()));
#[cfg(test)]
pub static PROCESSED: once_cell::sync::Lazy<
Mutex<Vec<(u64, storage::group_intent::StoredGroupIntent)>>,
> = once_cell::sync::Lazy::new(|| Mutex::new(Vec::new()));

/// A manager for group-specific semaphores
#[derive(Debug)]
pub struct GroupCommitLock {
Expand Down Expand Up @@ -93,7 +85,7 @@ pub struct MlsGroupGuard {
}

// Static instance of `GroupCommitLock`
pub static MLS_COMMIT_LOCK: LazyLock<GroupCommitLock> = LazyLock::new(GroupCommitLock::new);
// pub static MLS_COMMIT_LOCK: LazyLock<GroupCommitLock> = LazyLock::new(GroupCommitLock::new);

/// Inserts a model to the underlying data store, erroring if it already exists
pub trait Store<StorageConnection> {
Expand Down
16 changes: 0 additions & 16 deletions xmtp_mls/src/subscriptions/stream_all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,22 +377,6 @@ mod tests {
})
.await;

tracing::info!("Total Messages: {}", messages.len());
tracing::info!("--------------------------");
tracing::info!("PUBLISHED");
tracing::info!("--------------------------");
let published = crate::PUBLISHED.lock();
let processed = crate::PROCESSED.lock();
for i in published.iter() {
tracing::info!("{:?}", i);
}
tracing::info!("--------------------------");
tracing::info!("PROCESSED");
tracing::info!("--------------------------");

for (cursor, i) in processed.iter() {
tracing::info!("cursor = {}, Intent={:?}", cursor, i);
}
assert_eq!(messages.len(), 6);
}

Expand Down

0 comments on commit 4f20dc3

Please sign in to comment.