Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mls client init [WPB-15022] #3178

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ interface MLSFailure : CoreFailure {
data object StaleProposal : MLSFailure
data object StaleCommit : MLSFailure
data object InternalErrors : MLSFailure
data object Disabled : MLSFailure

data class Generic(internal val exception: Exception) : MLSFailure {
val rootCause: Throwable get() = exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral
import com.wire.kalium.logger.KaliumLogLevel
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository
Expand Down Expand Up @@ -130,6 +131,10 @@ class MLSClientProviderImpl(
}

override suspend fun getOrFetchMLSConfig(): Either<CoreFailure, SupportedCipherSuite> {
if (!userConfigRepository.isMLSEnabled().getOrElse(true)) {
kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.")
return MLSFailure.Disabled.left()
}
return userConfigRepository.getSupportedCipherSuite().flatMapLeft<CoreFailure, SupportedCipherSuite> {
featureConfigRepository.getFeatureConfigs().map {
it.mlsModel.supportedCipherSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,16 +432,19 @@ internal class ConversationDataSource internal constructor(
): Either<CoreFailure, Boolean> = wrapStorageRequest {
val isNewConversation = conversationDAO.getConversationById(conversation.id.toDao()) == null
if (isNewConversation) {
conversationDAO.insertConversation(
conversationMapper.fromApiModelToDaoModel(
conversation,
mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) },
selfTeamIdProvider().getOrNull(),
val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }
if (shouldPersistMLSConversation(mlsGroupState)) {
conversationDAO.insertConversation(
conversationMapper.fromApiModelToDaoModel(
conversation,
mlsGroupState = mlsGroupState?.getOrNull(),
selfTeamIdProvider().getOrNull(),
)
)
)
memberDAO.insertMembersWithQualifiedId(
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
)
memberDAO.insertMembersWithQualifiedId(
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
)
}
}
isNewConversation
}
Expand All @@ -453,17 +456,19 @@ internal class ConversationDataSource internal constructor(
invalidateMembers: Boolean
) = wrapStorageRequest {
val conversationEntities = conversations
.map { conversationResponse ->
conversationMapper.fromApiModelToDaoModel(
conversationResponse,
mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(
idMapper.fromGroupIDEntity(it),
originatedFromEvent
)
},
selfTeamIdProvider().getOrNull(),
)
.mapNotNull { conversationResponse ->
val mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent)
}
if (shouldPersistMLSConversation(mlsGroupState)) {
conversationMapper.fromApiModelToDaoModel(
conversationResponse,
mlsGroupState = mlsGroupState?.getOrNull(),
selfTeamIdProvider().getOrNull(),
)
} else {
null
}
}
conversationDAO.insertConversations(conversationEntities)
conversations.forEach { conversationsResponse ->
Expand All @@ -483,10 +488,11 @@ internal class ConversationDataSource internal constructor(
}
}

private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState =
hasEstablishedMLSGroup(groupId).fold({
throw IllegalStateException(it.toString()) // TODO find a more fitting exception?
}, { exists ->
private suspend fun mlsGroupState(
groupId: GroupID,
originatedFromEvent: Boolean = false
): Either<CoreFailure, ConversationEntity.GroupState> = hasEstablishedMLSGroup(groupId)
.map { exists ->
if (exists) {
ConversationEntity.GroupState.ESTABLISHED
} else {
Expand All @@ -496,7 +502,7 @@ internal class ConversationDataSource internal constructor(
ConversationEntity.GroupState.PENDING_JOIN
}
}
})
}

private suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean> =
mlsClientProvider.getMLSClient()
Expand All @@ -506,6 +512,10 @@ internal class ConversationDataSource internal constructor(
}
}

// if group state is not null and is left, then we don't want to persist the MLS conversation
private fun shouldPersistMLSConversation(groupState: Either<CoreFailure, ConversationEntity.GroupState>?): Boolean =
groupState?.fold({ true }, { false }) != true

@DelicateKaliumApi("This function does not get values from cache")
override suspend fun getProteusSelfConversationId(): Either<StorageFailure, ConversationId> =
wrapStorageRequest { conversationDAO.getSelfConversationId(ConversationEntity.Protocol.PROTEUS) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,17 +671,23 @@ internal class MLSConversationDataSource(
})

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }
.flatMap { conversationClientInfo ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
mlsClient.getDeviceIdentities(
conversationClientInfo.mlsGroupId,
listOf(
CryptoQualifiedClientId(
conversationClientInfo.clientId,
conversationClientInfo.userId.toModel().toCrypto()
)
)
).firstOrNull()
}
}
}
}

override suspend fun getUserIdentity(userId: UserId) =
wrapStorageRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,8 @@ class UserSessionScope internal constructor(
cachedClientIdClearer,
updateSupportedProtocolsAndResolveOneOnOnes,
registerMLSClientUseCase,
syncFeatureConfigsUseCase
syncFeatureConfigsUseCase,
userConfigRepository
)
}
val conversations: ConversationScope by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository
import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository
import com.wire.kalium.logic.data.client.ClientRepository
Expand Down Expand Up @@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
private val cachedClientIdClearer: CachedClientIdClearer,
private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase,
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase,
private val userConfigRepository: UserConfigRepository
) {
@OptIn(DelicateKaliumApi::class)
val register: RegisterClientUseCase
Expand Down Expand Up @@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val deregisterNativePushToken: DeregisterTokenUseCase
get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository)
val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider)
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository)
val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase
get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository)
val refillKeyPackages: RefillKeyPackagesUseCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package com.wire.kalium.logic.feature.client
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.isRight
import com.wire.kalium.util.DelicateKaliumApi

Expand All @@ -45,8 +45,8 @@ internal class IsAllowedToRegisterMLSClientUseCaseImpl(
) : IsAllowedToRegisterMLSClientUseCase {

override suspend operator fun invoke(): Boolean {
return featureSupport.isMLSSupported &&
mlsPublicKeysRepository.getKeys().isRight() &&
userConfigRepository.isMLSEnabled().fold({ false }, { isEnabled -> isEnabled })
return featureSupport.isMLSSupported
&& userConfigRepository.isMLSEnabled().getOrElse(false)
&& mlsPublicKeysRepository.getKeys().isRight()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse

/**
* This use case will return the current number of key packages.
Expand All @@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl(
private val keyPackageRepository: KeyPackageRepository,
private val currentClientIdProvider: CurrentClientIdProvider,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val userConfigRepository: UserConfigRepository
) : MLSKeyPackageCountUseCase {
override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult =
when (fromAPI) {
Expand All @@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl(
private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({
MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
}, { selfClient ->
keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold(
{
MLSKeyPackageCountResult.Failure.NetworkCallFailure(it)
}, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) })
if (userConfigRepository.isMLSEnabled().getOrElse(false)) {
keyPackageRepository.getAvailableKeyPackageCount(selfClient)
.fold(
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
{ MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
)
} else {
MLSKeyPackageCountResult.Failure.NotEnabled
}
})

private suspend fun validKeyPackagesCountFromMLSClient() =
Expand All @@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult {
sealed class Failure : MLSKeyPackageCountResult() {
class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure()
class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure()
data object NotEnabled : Failure()
data class Generic(val genericFailure: CoreFailure) : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler {
is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore
is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore
is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore
is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore
else -> MLSMessageFailureResolution.InformUser
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package com.wire.kalium.logic.data.client

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest
import com.wire.kalium.logic.data.featureConfig.MLSModel
Expand All @@ -32,12 +33,15 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository
import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage
import io.ktor.util.reflect.instanceOf
import io.mockative.Mock
import io.mockative.coVerify
import io.mockative.mock
import io.mockative.once
import io.mockative.verify
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
Expand All @@ -63,12 +67,16 @@ class MLSClientProviderTest {
val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right())
withGetMLSEnabledReturning(true.right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected.supportedCipherSuite, it)
}

verify { arrangement.userConfigRepository.isMLSEnabled() }
.wasInvoked(exactly = once)

coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
.wasInvoked(exactly = once)

Expand All @@ -88,12 +96,17 @@ class MLSClientProviderTest {

val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(expected.right())
withGetMLSEnabledReturning(true.right())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected, it)
}

verify { arrangement.userConfigRepository.isMLSEnabled() }
.wasInvoked(exactly = once)

coVerify {
arrangement.userConfigRepository.getSupportedCipherSuite()
}.wasInvoked(exactly = once)
Expand All @@ -103,6 +116,37 @@ class MLSClientProviderTest {
}.wasNotInvoked()
}

@Test
fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest {
// given
val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetMLSEnabledReturning(false.right())
withGetSupportedCipherSuitesReturning(
SupportedCipherSuite(
supported = listOf(
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384
),
default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
).right()
)
}

// when
val result = mlsClientProvider.getOrFetchMLSConfig()

// then
result.shouldFail {
it.instanceOf(CoreFailure.Unknown::class)
}

coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
.wasNotInvoked()

coVerify { arrangement.featureConfigRepository.getFeatureConfigs() }
.wasNotInvoked()
}

private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(),
FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() {

Expand Down
Loading
Loading