From ff2a93c668afb8c57daf252d260eafd41171683b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:25:09 +0000 Subject: [PATCH] fix: MLS groups are not usable if I have a proteus client registered to my account but have MLS on current client [WPB-15192] (#3197) (#3204) * fix: MLS groups are not usable if I have a proteus client registered to my account but have MLS on current client [WPB-15192] * detekt * detekt * fix: pr comments --------- Co-authored-by: Mohamad Jaara Co-authored-by: yamilmedina --- .../kalium/logic/data/user/UserRepository.kt | 6 ++ .../feature/conversation/ConversationScope.kt | 7 ++- ...versationInteractionAvailabilityUseCase.kt | 50 ++++++++------- ...ationInteractionAvailabilityUseCaseTest.kt | 51 +++++++++------ .../repository/UserRepositoryArrangement.kt | 16 +++++ .../com/wire/kalium/persistence/Clients.sq | 3 + .../persistence/dao/client/ClientDAO.kt | 1 + .../persistence/dao/client/ClientDAOImpl.kt | 5 ++ .../persistence/dao/client/ClientDAOTest.kt | 63 +++++++++++++++++++ 9 files changed, 160 insertions(+), 42 deletions(-) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt index 984af58ca65..b283f3158b6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt @@ -23,6 +23,7 @@ import com.wire.kalium.logger.obfuscateDomain import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.MemberMapper import com.wire.kalium.logic.data.conversation.Recipient import com.wire.kalium.logic.data.conversation.mls.NameAndHandle @@ -167,6 +168,7 @@ interface UserRepository { suspend fun getNameAndHandle(userId: UserId): Either suspend fun migrateUserToTeam(teamName: String): Either suspend fun updateTeamId(userId: UserId, teamId: TeamId): Either + suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either } @Suppress("LongParameterList", "TooManyFunctions") @@ -668,6 +670,10 @@ internal class UserDataSource internal constructor( userDAO.updateTeamId(userId.toDao(), teamId.value) } + override suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either = wrapStorageRequest { + clientDAO.isMLSCapable(userId.toDao(), clientId.value) + } + companion object { internal const val SELF_USER_ID_KEY = "selfUserID" diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt index 988f49480e3..7f73c1c6cda 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt @@ -164,7 +164,12 @@ class ConversationScope internal constructor( get() = ObserveIsSelfUserMemberUseCaseImpl(conversationRepository, selfUserId) val observeConversationInteractionAvailabilityUseCase: ObserveConversationInteractionAvailabilityUseCase - get() = ObserveConversationInteractionAvailabilityUseCase(conversationRepository, userRepository) + get() = ObserveConversationInteractionAvailabilityUseCase( + conversationRepository, + selfUserId = selfUserId, + selfClientIdProvider = currentClientIdProvider, + userRepository = userRepository + ) val deleteTeamConversation: DeleteTeamConversationUseCase get() = DeleteTeamConversationUseCaseImpl(selfTeamIdProvider, teamRepository, conversationRepository) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCase.kt index 9dab4f5fa45..8d25d6c82d2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCase.kt @@ -25,15 +25,18 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.InteractionAvailability import com.wire.kalium.logic.data.conversation.interactionAvailability import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.message.MessageContent -import com.wire.kalium.logic.data.user.SelfUser -import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.getOrElse +import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.util.KaliumDispatcher import com.wire.kalium.util.KaliumDispatcherImpl import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext @@ -48,6 +51,8 @@ import kotlinx.coroutines.withContext class ObserveConversationInteractionAvailabilityUseCase internal constructor( private val conversationRepository: ConversationRepository, private val userRepository: UserRepository, + private val selfUserId: UserId, + private val selfClientIdProvider: CurrentClientIdProvider, private val dispatcher: KaliumDispatcher = KaliumDispatcherImpl, ) { @@ -56,13 +61,21 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor( * @return an [IsInteractionAvailableResult] containing Success or Failure cases */ suspend operator fun invoke(conversationId: ConversationId): Flow = withContext(dispatcher.io) { - conversationRepository.observeConversationDetailsById(conversationId).combine( - userRepository.observeSelfUser() - ) { conversation, selfUser -> - conversation to selfUser - }.map { (eitherConversation, selfUser) -> + + val isSelfClientMlsCapable = selfClientIdProvider().flatMap { + userRepository.isClientMlsCapable(selfUserId, it) + }.getOrElse { + return@withContext flow { IsInteractionAvailableResult.Failure(it) } + } + + kaliumLogger.withTextTag("ObserveConversationInteractionAvailabilityUseCase").d("isSelfClientMlsCapable $isSelfClientMlsCapable") + + conversationRepository.observeConversationDetailsById(conversationId).map { eitherConversation -> eitherConversation.fold({ failure -> IsInteractionAvailableResult.Failure(failure) }, { conversationDetails -> - val isProtocolSupported = doesUserSupportConversationProtocol(conversationDetails, selfUser) + val isProtocolSupported = doesUserSupportConversationProtocol( + conversationDetails = conversationDetails, + isSelfClientMlsCapable = isSelfClientMlsCapable + ) if (!isProtocolSupported) { // short-circuit to Unsupported Protocol if it's the case return@fold IsInteractionAvailableResult.Success(InteractionAvailability.UNSUPPORTED_PROTOCOL) } @@ -74,19 +87,12 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor( private fun doesUserSupportConversationProtocol( conversationDetails: ConversationDetails, - selfUser: SelfUser - ): Boolean { - val protocolInfo = conversationDetails.conversation.protocol - val acceptableProtocols = when (protocolInfo) { - is Conversation.ProtocolInfo.MLS -> setOf(SupportedProtocol.MLS) - // Messages in mixed conversations are sent through Proteus - is Conversation.ProtocolInfo.Mixed -> setOf(SupportedProtocol.PROTEUS) - Conversation.ProtocolInfo.Proteus -> setOf(SupportedProtocol.PROTEUS) - } - val isProtocolSupported = selfUser.supportedProtocols?.any { supported -> - acceptableProtocols.contains(supported) - } ?: false - return isProtocolSupported + isSelfClientMlsCapable: Boolean + ): Boolean = when (conversationDetails.conversation.protocol) { + is Conversation.ProtocolInfo.MLS -> isSelfClientMlsCapable + // Messages in mixed conversations are sent through Proteus + is Conversation.ProtocolInfo.Mixed, + Conversation.ProtocolInfo.Proteus -> true } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCaseTest.kt index cc721631f91..2feef65b5cd 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationInteractionAvailabilityUseCaseTest.kt @@ -20,16 +20,21 @@ package com.wire.kalium.logic.feature.conversation import app.cash.turbine.test import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.InteractionAvailability import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestConversationDetails import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.right import com.wire.kalium.logic.test_util.TestKaliumDispatcher import com.wire.kalium.logic.test_util.testKaliumDispatcher +import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement +import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement @@ -41,6 +46,7 @@ import io.mockative.once import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.test.runTest +import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertIs @@ -52,6 +58,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val conversationId = TestConversation.ID val (arrangement, observeConversationInteractionAvailability) = arrange { + withIsClientMlsCapable(false.right()) dispatcher = testKaliumDispatcher withSelfUserBeingMemberOfConversation(isMember = true) } @@ -76,6 +83,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val (arrangement, observeConversationInteractionAvailability) = arrange { dispatcher = testKaliumDispatcher withSelfUserBeingMemberOfConversation(isMember = false) + withIsClientMlsCapable(false.right()) } observeConversationInteractionAvailability(conversationId).test { @@ -96,6 +104,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val conversationId = TestConversation.ID val (arrangement, observeConversationInteractionAvailability) = arrange { + withIsClientMlsCapable(false.right()) dispatcher = testKaliumDispatcher withGroupConversationError() } @@ -118,6 +127,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val conversationId = TestConversation.ID val (arrangement, observeConversationInteractionAvailability) = arrange { + withIsClientMlsCapable(false.right()) dispatcher = testKaliumDispatcher withBlockedUserConversation() } @@ -132,7 +142,6 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { awaitComplete() } - } @Test @@ -140,6 +149,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val conversationId = TestConversation.ID val (arrangement, observeConversationInteractionAvailability) = arrange { + withIsClientMlsCapable(false.right()) dispatcher = testKaliumDispatcher withDeletedUserConversation() } @@ -156,11 +166,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { } } + @Ignore // is this really a case that a client does not support Proteus @Test fun givenProteusConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest { testProtocolSupport( conversationProtocolInfo = Conversation.ProtocolInfo.Proteus, - userSupportedProtocols = setOf(SupportedProtocol.MLS), + isMlsCapable = true.right(), expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL ) } @@ -169,7 +180,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenMLSConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest { testProtocolSupport( conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO, - userSupportedProtocols = setOf(SupportedProtocol.PROTEUS), + isMlsCapable = false.right(), expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL ) } @@ -178,7 +189,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenMixedConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest { testProtocolSupport( conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO, - userSupportedProtocols = setOf(SupportedProtocol.PROTEUS), + isMlsCapable = false.right(), expectedResult = InteractionAvailability.ENABLED ) } @@ -187,7 +198,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenMixedConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest { testProtocolSupport( conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO, - userSupportedProtocols = setOf(SupportedProtocol.PROTEUS), + isMlsCapable = false.right(), expectedResult = InteractionAvailability.ENABLED ) } @@ -196,8 +207,8 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenMLSConversationAndUserSupportsMLS_whenObserving_thenShouldReturnEnabled() = runTest { testProtocolSupport( conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO, - userSupportedProtocols = setOf(SupportedProtocol.MLS), - expectedResult = InteractionAvailability.ENABLED + expectedResult = InteractionAvailability.ENABLED, + isMlsCapable = true.right() ) } @@ -205,18 +216,19 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenProteusConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest { testProtocolSupport( conversationProtocolInfo = TestConversation.PROTEUS_PROTOCOL_INFO, - userSupportedProtocols = setOf(SupportedProtocol.PROTEUS), - expectedResult = InteractionAvailability.ENABLED + expectedResult = InteractionAvailability.ENABLED, + isMlsCapable = false.right() ) } private suspend fun CoroutineScope.testProtocolSupport( conversationProtocolInfo: Conversation.ProtocolInfo, - userSupportedProtocols: Set, + isMlsCapable: Either, expectedResult: InteractionAvailability ) { val convId = TestConversationDetails.CONVERSATION_GROUP.conversation.id val (_, observeConversationInteractionAvailabilityUseCase) = arrange { + withIsClientMlsCapable(isMlsCapable) dispatcher = testKaliumDispatcher val proteusGroupDetails = TestConversationDetails.CONVERSATION_GROUP.copy( conversation = TestConversationDetails.CONVERSATION_GROUP.conversation.copy( @@ -224,7 +236,6 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { ) ) withObserveConversationDetailsByIdReturning(Either.Right(proteusGroupDetails)) - withObservingSelfUserReturning(flowOf(TestUser.SELF.copy(supportedProtocols = userSupportedProtocols))) } observeConversationInteractionAvailabilityUseCase(convId).test { @@ -241,6 +252,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { val (_, observeConversationInteractionAvailability) = arrange { dispatcher = testKaliumDispatcher withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.ENABLED) + withIsClientMlsCapable(false.right()) } observeConversationInteractionAvailability(conversationId).test { val interactionResult = awaitItem() @@ -253,6 +265,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { fun givenConversationLegalHoldIsDegraded_whenInvokingInteractionForConversation_thenInteractionShouldBeLegalHold() = runTest { val conversationId = TestConversation.ID val (_, observeConversationInteractionAvailability) = arrange { + withIsClientMlsCapable(false.right()) dispatcher = testKaliumDispatcher withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.DEGRADED) } @@ -266,10 +279,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { private class Arrangement( private val configure: suspend Arrangement.() -> Unit ) : UserRepositoryArrangement by UserRepositoryArrangementImpl(), - ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() { + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + CurrentClientIdProviderArrangement by CurrentClientIdProviderArrangementImpl() { var dispatcher: KaliumDispatcher = TestKaliumDispatcher + val selfUser = UserId("self_value", "self_domain") suspend fun withSelfUserBeingMemberOfConversation(isMember: Boolean) = apply { withObserveConversationDetailsByIdReturning( Either.Right(TestConversationDetails.CONVERSATION_GROUP.copy(isSelfUserMember = isMember)) @@ -315,17 +330,15 @@ class ObserveConversationInteractionAvailabilityUseCaseTest { } suspend fun arrange(): Pair = run { - withObservingSelfUserReturning( - flowOf( - TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS)) - ) - ) + withCurrentClientIdSuccess(ClientId("client_id")) configure() this@Arrangement to ObserveConversationInteractionAvailabilityUseCase( conversationRepository = conversationRepository, userRepository = userRepository, - dispatcher = dispatcher - ) + dispatcher = dispatcher, + selfUserId = selfUser, + selfClientIdProvider = currentClientIdProvider + ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserRepositoryArrangement.kt index e106b3290a6..6f1ce40c292 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserRepositoryArrangement.kt @@ -19,6 +19,7 @@ package com.wire.kalium.logic.util.arrangement.repository import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.mls.NameAndHandle import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.QualifiedID @@ -95,6 +96,12 @@ internal interface UserRepositoryArrangement { ) suspend fun withNameAndHandle(result: Either, userId: Matcher = AnyMatcher(valueOf())) + + suspend fun withIsClientMlsCapable( + result: Either, + userId: Matcher = AnyMatcher(valueOf()), + clientId: Matcher = AnyMatcher(valueOf()) + ) } @Suppress("INAPPLICABLE_JVM_NAME") @@ -233,4 +240,13 @@ internal open class UserRepositoryArrangementImpl : UserRepositoryArrangement { override suspend fun withNameAndHandle(result: Either, userId: Matcher) { coEvery { userRepository.getNameAndHandle(matches { userId.matches(it) }) }.returns(result) } + + override suspend fun withIsClientMlsCapable(result: Either, userId: Matcher, clientId: Matcher) { + coEvery { + userRepository.isClientMlsCapable( + userId = matches { userId.matches(it) }, + clientId = matches { clientId.matches(it) } + ) + }.returns(result) + } } diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq index 709b3613d56..03bbdd7fbb5 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq @@ -74,6 +74,9 @@ SELECT * FROM Client WHERE user_id = :user_id AND id = :client_id; deleteClientsOfUserExcept: DELETE FROM Client WHERE user_id = :user_id AND id NOT IN :exception_ids; +isClientMLSCapable: +SELECT is_mls_capable FROM Client WHERE user_id = :user_id AND id = :client_id; + tryMarkAsInvalid: UPDATE OR IGNORE Client SET is_valid = 0 WHERE user_id = :user_id AND id IN :clientId_List; diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt index 4b6d1a5e17b..7b5a9d87d18 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt @@ -97,4 +97,5 @@ interface ClientDAO { ): Map> suspend fun selectAllClients(): Map> + suspend fun isMLSCapable(userId: QualifiedIDEntity, clientId: String): Boolean? } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt index b30b2a8ebec..79935606535 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt @@ -155,6 +155,11 @@ internal class ClientDAOImpl internal constructor( .executeAsList() .groupBy { it.userId } + override suspend fun isMLSCapable(userId: QualifiedIDEntity, clientId: String): Boolean? = withContext(queriesContext) { + clientsQueries.isClientMLSCapable(userId, clientId) + .executeAsOneOrNull() + } + override suspend fun getClientsOfUserByQualifiedIDFlow(qualifiedID: QualifiedIDEntity): Flow> = clientsQueries.selectAllClientsByUserId(qualifiedID, mapper::fromClient) .asFlow() diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt index 7d352251f6e..4fa472f3ee1 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt @@ -459,6 +459,69 @@ class ClientDAOTest : BaseDatabaseTest() { } } + @Test + fun givenClientIsNotMlsCapable_whenCallingIsMlsCapable_thenReturnFalse() = runTest { + val user = user + val client: InsertClientParam = insertedClient.copy(isMLSCapable = false) + userDAO.upsertUser(user) + clientDAO.insertClient(client) + assertFalse { clientDAO.isMLSCapable(userId, clientId = client.id)!! } + } + + @Test + fun givenClientIsMlsCapable_whenCallingIsMlsCapable_thenReturnTrue() = runTest { + val user = user + val client: InsertClientParam = insertedClient.copy(isMLSCapable = true) + userDAO.upsertUser(user) + clientDAO.insertClient(client) + assertTrue { clientDAO.isMLSCapable(userId, clientId = client.id)!! } + } + + @Test + fun givenNotFound_whenCallingIsMlsCapableForUser_thenReturnNull() = runTest { + val user = user + userDAO.upsertUser(user) + assertNull(clientDAO.isMLSCapable(userId, clientId = client.id)) + } + + @Test + fun givenPersistedClient_whenUpsertingTheSameExactClient_thenItShouldIgnoreAndNotNotifyOtherQueries() = runTest { + // Given + userDAO.upsertUser(user) + clientDAO.insertClient(insertedClient) + + clientDAO.observeClient(user.id, insertedClient.id).test { + val initialValue = awaitItem() + assertEquals(insertedClient.toClient(), initialValue) + + // When + clientDAO.insertClient(insertedClient) // the same exact client is being saved again + + // Then + expectNoEvents() // other query should not be notified + } + } + + @Test + fun givenPersistedClient_whenUpsertingUpdatedClient_thenItShouldBeSavedAndOtherQueriesShouldBeUpdated() = runTest { + // Given + userDAO.upsertUser(user) + clientDAO.insertClient(insertedClient) + val updatedInsertedClient = insertedClient.copy(label = "new_label") + + clientDAO.observeClient(user.id, insertedClient.id).test { + val initialValue = awaitItem() + assertEquals(insertedClient.toClient(), initialValue) + + // When + clientDAO.insertClient(updatedInsertedClient) // updated client is being saved that should replace the old one + + // Then + val updatedValue = awaitItem() // other query should be notified + assertEquals(updatedInsertedClient.toClient(), updatedValue) + } + } + private companion object { val userId = QualifiedIDEntity("test", "domain") val user = newUserEntity(userId)