Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: MLS groups are not usable if I have a proteus client registered to my account but have MLS on current client [WPB-15192] #3197

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.wire.kalium.logger.obfuscateDomain
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.MemberMapper
import com.wire.kalium.logic.data.conversation.Recipient
import com.wire.kalium.logic.data.conversation.mls.NameAndHandle
Expand Down Expand Up @@ -167,6 +168,7 @@ interface UserRepository {
suspend fun getNameAndHandle(userId: UserId): Either<StorageFailure, NameAndHandle>
suspend fun migrateUserToTeam(teamName: String): Either<CoreFailure, CreateUserTeam>
suspend fun updateTeamId(userId: UserId, teamId: TeamId): Either<StorageFailure, Unit>
suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either<StorageFailure, Boolean>
}

@Suppress("LongParameterList", "TooManyFunctions")
Expand Down Expand Up @@ -668,6 +670,10 @@ internal class UserDataSource internal constructor(
userDAO.updateTeamId(userId.toDao(), teamId.value)
}

override suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either<StorageFailure, Boolean> = wrapStorageRequest {
clientDAO.isMLSCapable(userId.toDao(), clientId.value)
}

companion object {
internal const val SELF_USER_ID_KEY = "selfUserID"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ class ConversationScope internal constructor(
get() = ObserveIsSelfUserMemberUseCaseImpl(conversationRepository, selfUserId)

val observeConversationInteractionAvailabilityUseCase: ObserveConversationInteractionAvailabilityUseCase
get() = ObserveConversationInteractionAvailabilityUseCase(conversationRepository, userRepository)
get() = ObserveConversationInteractionAvailabilityUseCase(
conversationRepository,
selfUserId = selfUserId,
selfClientIdProvider = currentClientIdProvider,
userRepository = userRepository
)

val deleteTeamConversation: DeleteTeamConversationUseCase
get() = DeleteTeamConversationUseCaseImpl(selfTeamIdProvider, teamRepository, conversationRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.InteractionAvailability
import com.wire.kalium.logic.data.conversation.interactionAvailability
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext

Expand All @@ -48,6 +51,8 @@ import kotlinx.coroutines.withContext
class ObserveConversationInteractionAvailabilityUseCase internal constructor(
private val conversationRepository: ConversationRepository,
private val userRepository: UserRepository,
private val selfUserId: UserId,
private val selfClientIdProvider: CurrentClientIdProvider,
private val dispatcher: KaliumDispatcher = KaliumDispatcherImpl,
) {

Expand All @@ -56,13 +61,21 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor(
* @return an [IsInteractionAvailableResult] containing Success or Failure cases
*/
suspend operator fun invoke(conversationId: ConversationId): Flow<IsInteractionAvailableResult> = withContext(dispatcher.io) {
conversationRepository.observeConversationDetailsById(conversationId).combine(
userRepository.observeSelfUser()
) { conversation, selfUser ->
conversation to selfUser
}.map { (eitherConversation, selfUser) ->

val isSelfClientMlsCapable = selfClientIdProvider().flatMap {
userRepository.isClientMlsCapable(selfUserId, it)
}.getOrElse {
return@withContext flow { IsInteractionAvailableResult.Failure(it) }
}

kaliumLogger.withTextTag("ObserveConversationInteractionAvailabilityUseCase").d("isSelfClientMlsCapable $isSelfClientMlsCapable")

conversationRepository.observeConversationDetailsById(conversationId).map { eitherConversation ->
eitherConversation.fold({ failure -> IsInteractionAvailableResult.Failure(failure) }, { conversationDetails ->
val isProtocolSupported = doesUserSupportConversationProtocol(conversationDetails, selfUser)
val isProtocolSupported = doesUserSupportConversationProtocol(
conversationDetails = conversationDetails,
isSelfClientMlsCapable = isSelfClientMlsCapable
)
if (!isProtocolSupported) { // short-circuit to Unsupported Protocol if it's the case
return@fold IsInteractionAvailableResult.Success(InteractionAvailability.UNSUPPORTED_PROTOCOL)
}
Expand All @@ -74,19 +87,12 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor(

private fun doesUserSupportConversationProtocol(
conversationDetails: ConversationDetails,
selfUser: SelfUser
): Boolean {
val protocolInfo = conversationDetails.conversation.protocol
val acceptableProtocols = when (protocolInfo) {
is Conversation.ProtocolInfo.MLS -> setOf(SupportedProtocol.MLS)
// Messages in mixed conversations are sent through Proteus
is Conversation.ProtocolInfo.Mixed -> setOf(SupportedProtocol.PROTEUS)
Conversation.ProtocolInfo.Proteus -> setOf(SupportedProtocol.PROTEUS)
}
val isProtocolSupported = selfUser.supportedProtocols?.any { supported ->
acceptableProtocols.contains(supported)
} ?: false
return isProtocolSupported
isSelfClientMlsCapable: Boolean
): Boolean = when (conversationDetails.conversation.protocol) {
is Conversation.ProtocolInfo.MLS -> isSelfClientMlsCapable
// Messages in mixed conversations are sent through Proteus
is Conversation.ProtocolInfo.Mixed,
Conversation.ProtocolInfo.Proteus -> true
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,21 @@ package com.wire.kalium.logic.feature.conversation

import app.cash.turbine.test
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.InteractionAvailability
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestConversationDetails
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.test_util.testKaliumDispatcher
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement
Expand All @@ -41,6 +46,7 @@ import io.mockative.once
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
Expand All @@ -52,6 +58,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withSelfUserBeingMemberOfConversation(isMember = true)
}
Expand All @@ -76,6 +83,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val (arrangement, observeConversationInteractionAvailability) = arrange {
dispatcher = testKaliumDispatcher
withSelfUserBeingMemberOfConversation(isMember = false)
withIsClientMlsCapable(false.right())
}

observeConversationInteractionAvailability(conversationId).test {
Expand All @@ -96,6 +104,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withGroupConversationError()
}
Expand All @@ -118,6 +127,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withBlockedUserConversation()
}
Expand All @@ -132,14 +142,14 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {

awaitComplete()
}

}

@Test
fun givenOtherUserIsDeleted_whenInvokingInteractionForConversation_thenInteractionShouldBeDisabled() = runTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withDeletedUserConversation()
}
Expand All @@ -156,11 +166,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
}
}

@Ignore // is this really a case that a client does not support Proteus
@Test
fun givenProteusConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = Conversation.ProtocolInfo.Proteus,
userSupportedProtocols = setOf(SupportedProtocol.MLS),
isMlsCapable = true.right(),
expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL
)
}
Expand All @@ -169,7 +180,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMLSConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL
)
}
Expand All @@ -178,7 +189,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMixedConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.ENABLED
)
}
Expand All @@ -187,7 +198,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMixedConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.ENABLED
)
}
Expand All @@ -196,35 +207,35 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMLSConversationAndUserSupportsMLS_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.MLS),
expectedResult = InteractionAvailability.ENABLED
expectedResult = InteractionAvailability.ENABLED,
isMlsCapable = true.right()
)
}

@Test
fun givenProteusConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.PROTEUS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
expectedResult = InteractionAvailability.ENABLED
expectedResult = InteractionAvailability.ENABLED,
isMlsCapable = false.right()
)
}

private suspend fun CoroutineScope.testProtocolSupport(
conversationProtocolInfo: Conversation.ProtocolInfo,
userSupportedProtocols: Set<SupportedProtocol>,
isMlsCapable: Either<StorageFailure, Boolean>,
expectedResult: InteractionAvailability
) {
val convId = TestConversationDetails.CONVERSATION_GROUP.conversation.id
val (_, observeConversationInteractionAvailabilityUseCase) = arrange {
withIsClientMlsCapable(isMlsCapable)
dispatcher = testKaliumDispatcher
val proteusGroupDetails = TestConversationDetails.CONVERSATION_GROUP.copy(
conversation = TestConversationDetails.CONVERSATION_GROUP.conversation.copy(
protocol = conversationProtocolInfo
)
)
withObserveConversationDetailsByIdReturning(Either.Right(proteusGroupDetails))
withObservingSelfUserReturning(flowOf(TestUser.SELF.copy(supportedProtocols = userSupportedProtocols)))
}

observeConversationInteractionAvailabilityUseCase(convId).test {
Expand All @@ -241,6 +252,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val (_, observeConversationInteractionAvailability) = arrange {
dispatcher = testKaliumDispatcher
withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.ENABLED)
withIsClientMlsCapable(false.right())
}
observeConversationInteractionAvailability(conversationId).test {
val interactionResult = awaitItem()
Expand All @@ -253,6 +265,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenConversationLegalHoldIsDegraded_whenInvokingInteractionForConversation_thenInteractionShouldBeLegalHold() = runTest {
val conversationId = TestConversation.ID
val (_, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.DEGRADED)
}
Expand All @@ -266,10 +279,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
private class Arrangement(
private val configure: suspend Arrangement.() -> Unit
) : UserRepositoryArrangement by UserRepositoryArrangementImpl(),
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() {
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(),
CurrentClientIdProviderArrangement by CurrentClientIdProviderArrangementImpl() {

var dispatcher: KaliumDispatcher = TestKaliumDispatcher

val selfUser = UserId("self_value", "self_domain")
suspend fun withSelfUserBeingMemberOfConversation(isMember: Boolean) = apply {
withObserveConversationDetailsByIdReturning(
Either.Right(TestConversationDetails.CONVERSATION_GROUP.copy(isSelfUserMember = isMember))
Expand Down Expand Up @@ -315,17 +330,15 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
}

suspend fun arrange(): Pair<Arrangement, ObserveConversationInteractionAvailabilityUseCase> = run {
withObservingSelfUserReturning(
flowOf(
TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))
)
)
withCurrentClientIdSuccess(ClientId("client_id"))
configure()
this@Arrangement to ObserveConversationInteractionAvailabilityUseCase(
conversationRepository = conversationRepository,
userRepository = userRepository,
dispatcher = dispatcher
)
dispatcher = dispatcher,
selfUserId = selfUser,
selfClientIdProvider = currentClientIdProvider
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.wire.kalium.logic.util.arrangement.repository

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.mls.NameAndHandle
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.QualifiedID
Expand Down Expand Up @@ -95,6 +96,12 @@ internal interface UserRepositoryArrangement {
)

suspend fun withNameAndHandle(result: Either<StorageFailure, NameAndHandle>, userId: Matcher<UserId> = AnyMatcher(valueOf()))

suspend fun withIsClientMlsCapable(
result: Either<StorageFailure, Boolean>,
userId: Matcher<UserId> = AnyMatcher(valueOf()),
clientId: Matcher<ClientId> = AnyMatcher(valueOf())
)
}

@Suppress("INAPPLICABLE_JVM_NAME")
Expand Down Expand Up @@ -233,4 +240,13 @@ internal open class UserRepositoryArrangementImpl : UserRepositoryArrangement {
override suspend fun withNameAndHandle(result: Either<StorageFailure, NameAndHandle>, userId: Matcher<UserId>) {
coEvery { userRepository.getNameAndHandle(matches { userId.matches(it) }) }.returns(result)
}

override suspend fun withIsClientMlsCapable(result: Either<StorageFailure, Boolean>, userId: Matcher<UserId>, clientId: Matcher<ClientId>) {
coEvery {
userRepository.isClientMlsCapable(
userId = matches { userId.matches(it) },
clientId = matches { clientId.matches(it) }
)
}.returns(result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ SELECT * FROM Client WHERE user_id = :user_id AND id = :client_id;
deleteClientsOfUserExcept:
DELETE FROM Client WHERE user_id = :user_id AND id NOT IN :exception_ids;

isClientMLSCapable:
SELECT is_mls_capable FROM Client WHERE user_id = :user_id AND id = :client_id;

tryMarkAsInvalid:
UPDATE OR IGNORE Client SET is_valid = 0 WHERE user_id = :user_id AND id IN :clientId_List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ interface ClientDAO {
): Map<QualifiedIDEntity, List<Client>>

suspend fun selectAllClients(): Map<QualifiedIDEntity, List<Client>>
suspend fun isMLSCapable(userId: QualifiedIDEntity, clientId: String): Boolean?
}
Loading
Loading