From 4d1dfd368532770b317f419f6729bb69aa9fbdd2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:39:42 +0200 Subject: [PATCH] refactor: unify access token refreshing logic [WPB-5038] (#2142) --- .../logic/data/auth/login/LoginRepository.kt | 12 +- .../register/RegisterAccountRepository.kt | 12 +- .../logic/data/session/SessionMapper.kt | 24 +-- .../logic/data/session/SessionRepository.kt | 10 +- .../logic/data/session/token/AccessToken.kt | 39 ++++ .../session/token/AccessTokenRepository.kt | 90 ++++++++ .../token/AccessTokenRepositoryFactory.kt | 39 ++++ .../kalium/logic/feature/UserSessionScope.kt | 37 +++- .../auth/AddAuthenticatedUserUseCase.kt | 16 +- .../kalium/logic/feature/auth/AuthSession.kt | 26 ++- .../kalium/logic/feature/auth/LoginUseCase.kt | 2 +- .../auth/sso/GetSSOLoginSessionUseCase.kt | 10 +- .../register/RegisterAccountUseCase.kt | 4 +- .../session/UpgradeCurrentSessionUseCase.kt | 21 +- .../session/token/AccessTokenRefresher.kt | 64 ++++++ .../token/AccessTokenRefresherFactory.kt | 47 +++++ .../logic/network/SessionManagerImpl.kt | 67 +++--- .../register/RegisterAccountRepositoryTest.kt | 23 +- .../logic/data/session/SessionMapperTest.kt | 16 +- .../auth/AddAuthenticatedUserUseCaseTest.kt | 30 ++- .../logic/feature/auth/LoginUseCaseTest.kt | 6 +- .../register/RegisterAccountUseCaseTest.kt | 4 +- .../session/token/AccessTokenRefresherTest.kt | 141 +++++++++++++ .../logic/network/SessionManagerTest.kt | 199 ++++++++++++++++++ .../kalium/monkeys/conversation/Monkey.kt | 2 +- .../api/base/authenticated/AccessTokenApi.kt | 5 +- .../kalium/network/api/base/model/Tokens.kt | 8 + .../AuthenticatedNetworkContainer.kt | 21 +- .../session/FailureToRefreshTokenException.kt | 30 +++ .../kalium/network/session/SessionManager.kt | 17 +- .../kotlin/com/wire/kalium/api/ApiTest.kt | 8 +- .../wire/kalium/api/TestSessionManagerV0.kt | 17 +- ...st.kt => SessionManagerIntegrationTest.kt} | 107 ++++++++-- .../kmmSettings/GlobalPrefProvider.kt | 3 +- .../kmmSettings/GlobalPrefProvider.kt | 3 +- .../persistence/client/AuthTokenStorage.kt | 28 ++- .../client/AuthTokenStorageTest.kt | 3 +- .../kmmSettings/GlobalPrefProvider.kt | 3 +- .../kotlin/action/LoginActions.kt | 4 +- 39 files changed, 1001 insertions(+), 197 deletions(-) create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessToken.kt create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepository.kt create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepositoryFactory.kt create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresher.kt create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherFactory.kt create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherTest.kt create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/network/SessionManagerTest.kt create mode 100644 network/src/commonMain/kotlin/com/wire/kalium/network/session/FailureToRefreshTokenException.kt rename network/src/commonTest/kotlin/com/wire/kalium/api/common/{SessionManagerTest.kt => SessionManagerIntegrationTest.kt} (66%) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/auth/login/LoginRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/auth/login/LoginRepository.kt index 3141584e2eb..1a86ee247e3 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/auth/login/LoginRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/auth/login/LoginRepository.kt @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.session.SessionMapper import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.wrapApiRequest @@ -36,14 +36,14 @@ internal interface LoginRepository { label: String?, shouldPersistClient: Boolean, secondFactorVerificationCode: String? = null, - ): Either> + ): Either> suspend fun loginWithHandle( handle: String, password: String, label: String?, shouldPersistClient: Boolean - ): Either> + ): Either> } internal class LoginRepositoryImpl internal constructor( @@ -58,7 +58,7 @@ internal class LoginRepositoryImpl internal constructor( label: String?, shouldPersistClient: Boolean, secondFactorVerificationCode: String?, - ): Either> = + ): Either> = login( LoginApi.LoginParam.LoginWithEmail(email, password, label, secondFactorVerificationCode), shouldPersistClient @@ -69,7 +69,7 @@ internal class LoginRepositoryImpl internal constructor( password: String, label: String?, shouldPersistClient: Boolean, - ): Either> = + ): Either> = login( LoginApi.LoginParam.LoginWithHandle(handle, password, label), shouldPersistClient @@ -78,7 +78,7 @@ internal class LoginRepositoryImpl internal constructor( private suspend fun login( loginParam: LoginApi.LoginParam, persistClient: Boolean - ): Either> = wrapApiRequest { + ): Either> = wrapApiRequest { loginApi.login(param = loginParam, persist = persistClient) }.map { Pair(sessionMapper.fromSessionDTO(it.first), idMapper.toSsoId(it.second.ssoID)) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepository.kt index 014b3e7223c..07fe1266217 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepository.kt @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.session.SessionMapper import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.wrapApiRequest @@ -42,7 +42,7 @@ internal interface RegisterAccountRepository { name: String, password: String, cookieLabel: String? - ): Either> + ): Either> @Suppress("LongParameterList") suspend fun registerTeamWithEmail( @@ -53,7 +53,7 @@ internal interface RegisterAccountRepository { teamName: String, teamIcon: String, cookieLabel: String? - ): Either> + ): Either> } internal class RegisterAccountDataSource internal constructor( @@ -76,7 +76,7 @@ internal class RegisterAccountDataSource internal constructor( name: String, password: String, cookieLabel: String? - ): Either> = + ): Either> = register( RegisterApi.RegisterParam.PersonalAccount( email = email, @@ -95,7 +95,7 @@ internal class RegisterAccountDataSource internal constructor( teamName: String, teamIcon: String, cookieLabel: String? - ): Either> = + ): Either> = register( RegisterApi.RegisterParam.TeamAccount( email = email, @@ -115,7 +115,7 @@ internal class RegisterAccountDataSource internal constructor( private suspend fun activateUser(param: RegisterApi.ActivationParam): Either = wrapApiRequest { registerApi.activate(param) } - private suspend fun register(param: RegisterApi.RegisterParam): Either> = + private suspend fun register(param: RegisterApi.RegisterParam): Either> = wrapApiRequest { registerApi.register(param) }.map { Pair(idMapper.toSsoId(it.first.ssoID), sessionMapper.fromSessionDTO(it.second)) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionMapper.kt index 7ceae8098c2..3a73941fd10 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionMapper.kt @@ -26,7 +26,7 @@ import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.logout.LogoutReason import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.feature.auth.AccountInfo -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.feature.auth.PersistentWebSocketStatus import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO import com.wire.kalium.network.api.base.model.SessionDTO @@ -39,13 +39,13 @@ import com.wire.kalium.persistence.model.LogoutReason as LogoutReasonEntity @Suppress("TooManyFunctions") interface SessionMapper { - fun toSessionDTO(authSession: AuthTokens): SessionDTO + fun toSessionDTO(authSession: AccountTokens): SessionDTO fun fromEntityToSessionDTO(authTokenEntity: AuthTokenEntity): SessionDTO - fun fromSessionDTO(sessionDTO: SessionDTO): AuthTokens + fun fromSessionDTO(sessionDTO: SessionDTO): AccountTokens fun fromAccountInfoEntity(accountInfoEntity: AccountInfoEntity): AccountInfo fun toLogoutReasonEntity(reason: LogoutReason): LogoutReasonEntity fun toSsoIdEntity(ssoId: SsoId?): SsoIdEntity? - fun toAuthTokensEntity(authSession: AuthTokens): AuthTokenEntity + fun toAuthTokensEntity(authSession: AccountTokens): AuthTokenEntity fun fromSsoIdEntity(ssoIdEntity: SsoIdEntity?): SsoId? fun toLogoutReason(reason: LogoutReasonEntity): LogoutReason fun fromEntityToProxyCredentialsDTO(proxyCredentialsEntity: ProxyCredentialsEntity): ProxyCredentialsDTO @@ -62,12 +62,12 @@ internal class SessionMapperImpl( private val idMapper: IdMapper ) : SessionMapper { - override fun toSessionDTO(authSession: AuthTokens): SessionDTO = with(authSession) { + override fun toSessionDTO(authSession: AccountTokens): SessionDTO = with(authSession) { SessionDTO( userId = userId.toApi(), tokenType = tokenType, - accessToken = accessToken, - refreshToken = refreshToken, + accessToken = accessToken.value, + refreshToken = refreshToken.value, cookieLabel = cookieLabel ) } @@ -82,8 +82,8 @@ internal class SessionMapperImpl( ) } - override fun fromSessionDTO(sessionDTO: SessionDTO): AuthTokens = with(sessionDTO) { - AuthTokens( + override fun fromSessionDTO(sessionDTO: SessionDTO): AccountTokens = with(sessionDTO) { + AccountTokens( userId = userId.toModel(), accessToken = accessToken, refreshToken = refreshToken, @@ -112,11 +112,11 @@ internal class SessionMapperImpl( override fun toSsoIdEntity(ssoId: SsoId?): SsoIdEntity? = ssoId?.let { SsoIdEntity(scimExternalId = it.scimExternalId, subject = it.subject, tenant = it.tenant) } - override fun toAuthTokensEntity(authSession: AuthTokens): AuthTokenEntity = with(authSession) { + override fun toAuthTokensEntity(authSession: AccountTokens): AuthTokenEntity = with(authSession) { AuthTokenEntity( userId = userId.toDao(), - accessToken = accessToken, - refreshToken = refreshToken, + accessToken = accessToken.value, + refreshToken = refreshToken.value, tokenType = tokenType, cookieLabel = cookieLabel ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt index 602de05a06a..d15fd1e8b38 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt @@ -31,7 +31,7 @@ import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.feature.auth.Account import com.wire.kalium.logic.feature.auth.AccountInfo -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.feature.auth.PersistentWebSocketStatus import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.functional.Either @@ -54,7 +54,7 @@ interface SessionRepository { suspend fun storeSession( serverConfigId: String, ssoId: SsoId?, - authTokens: AuthTokens, + accountTokens: AccountTokens, proxyCredentials: ProxyCredentials? ): Either @@ -93,12 +93,12 @@ internal class SessionDataSource( override suspend fun storeSession( serverConfigId: String, ssoId: SsoId?, - authTokens: AuthTokens, + accountTokens: AccountTokens, proxyCredentials: ProxyCredentials? ): Either = wrapStorageRequest { accountsDAO.insertOrReplace( - authTokens.userId.toDao(), + accountTokens.userId.toDao(), sessionMapper.toSsoIdEntity(ssoId), serverConfigId, isPersistentWebSocketEnabled = kaliumConfigs.isWebSocketEnabledByDefault @@ -106,7 +106,7 @@ internal class SessionDataSource( }.flatMap { wrapStorageRequest { authTokenStorage.addOrReplace( - sessionMapper.toAuthTokensEntity(authTokens), + sessionMapper.toAuthTokensEntity(accountTokens), proxyCredentials?.let { sessionMapper.fromModelToProxyCredentialsEntity(it) } ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessToken.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessToken.kt new file mode 100644 index 00000000000..fd13035757b --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessToken.kt @@ -0,0 +1,39 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.session.token + +import kotlin.jvm.JvmInline + +internal data class AccessTokenRefreshResult( + val accessToken: AccessToken, + val refreshToken: RefreshToken +) + +/** + * Represents an access token, which is used for authentication and authorization purposes. + * + * @property value The value of the access token. + * @property tokenType The type of the access token. _e.g._ "Bearer" + */ +data class AccessToken( + val value: String, + val tokenType: String +) + +@JvmInline +value class RefreshToken(val value: String) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepository.kt new file mode 100644 index 00000000000..2fe7f28d789 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepository.kt @@ -0,0 +1,90 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.session.token + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.id.toDao +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.wrapApiRequest +import com.wire.kalium.logic.wrapStorageRequest +import com.wire.kalium.network.api.base.authenticated.AccessTokenApi +import com.wire.kalium.persistence.client.AuthTokenStorage + +internal interface AccessTokenRepository { + /** + * Retrieves a new access token using the provided refresh token and client ID. + * + * If provided, the new token will be associated with this client ID. + * If the client is remotely removed by the user, the tokens will be invalidated. + * Future refreshes will keep the previously associated client ID. + * _i.e._ after the first refresh, the client ID doesn't need to be provided anymore. + * + * @param refreshToken The refresh token to use for obtaining a new access token. + * @param clientId The optional client ID. + * @return Either a [CoreFailure] or the new access token. + */ + suspend fun getNewAccessToken( + refreshToken: String, + clientId: String? = null + ): Either + + /** + * Persists the access token and refresh token in the repository. + * + * @param accessToken The access token to persist. + * @param refreshToken The refresh token to persist. + * @return Either a [CoreFailure] if the operation fails, or [Unit] if the tokens are successfully persisted. + */ + suspend fun persistTokens( + accessToken: AccessToken, + refreshToken: RefreshToken + ): Either +} + +internal class AccessTokenRepositoryImpl( + private val userId: UserId, + private val accessTokenApi: AccessTokenApi, + private val authTokenStorage: AuthTokenStorage, +) : AccessTokenRepository { + override suspend fun getNewAccessToken( + refreshToken: String, + clientId: String? + ): Either = wrapApiRequest { + accessTokenApi.getToken(refreshToken, clientId) + }.map { (accessTokenDTO, newRefreshToken) -> + val token = AccessToken(accessTokenDTO.value, accessTokenDTO.tokenType) + val resolvedRefreshToken = newRefreshToken?.value ?: refreshToken + AccessTokenRefreshResult(token, RefreshToken(resolvedRefreshToken)) + } + + override suspend fun persistTokens( + accessToken: AccessToken, + refreshToken: RefreshToken + ): Either = wrapStorageRequest { + authTokenStorage.updateToken( + userId.toDao(), + accessToken.value, + accessToken.tokenType, + refreshToken.value + ) + }.map { } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepositoryFactory.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepositoryFactory.kt new file mode 100644 index 00000000000..ccf39b0d273 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/token/AccessTokenRepositoryFactory.kt @@ -0,0 +1,39 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.session.token + +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.network.api.base.authenticated.AccessTokenApi +import com.wire.kalium.persistence.client.AuthTokenStorage + +/** + * Interface for creating an [AccessTokenRepository] instance. + * Allows intaking a dynamic [AccessTokenApi] for its construction. + */ +internal interface AccessTokenRepositoryFactory { + fun create(tokenApi: AccessTokenApi): AccessTokenRepository +} + +internal class AccessTokenRepositoryFactoryImpl( + private val userId: UserId, + private val tokenStorage: AuthTokenStorage +) : AccessTokenRepositoryFactory { + override fun create(tokenApi: AccessTokenApi): AccessTokenRepository { + return AccessTokenRepositoryImpl(userId, tokenApi, tokenStorage) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index e63bb2fd2f1..fc7c239841a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -113,6 +113,8 @@ import com.wire.kalium.logic.data.publicuser.UserSearchApiWrapper import com.wire.kalium.logic.data.publicuser.UserSearchApiWrapperImpl import com.wire.kalium.logic.data.service.ServiceDataSource import com.wire.kalium.logic.data.service.ServiceRepository +import com.wire.kalium.logic.data.session.token.AccessTokenRepository +import com.wire.kalium.logic.data.session.token.AccessTokenRepositoryImpl import com.wire.kalium.logic.data.sync.InMemoryIncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.SlowSyncRepository @@ -244,6 +246,10 @@ import com.wire.kalium.logic.feature.service.ServiceScope import com.wire.kalium.logic.feature.session.GetProxyCredentialsUseCase import com.wire.kalium.logic.feature.session.GetProxyCredentialsUseCaseImpl import com.wire.kalium.logic.feature.session.UpgradeCurrentSessionUseCaseImpl +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresher +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresherFactory +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresherFactoryImpl +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresherImpl import com.wire.kalium.logic.feature.team.SyncSelfTeamUseCase import com.wire.kalium.logic.feature.team.SyncSelfTeamUseCaseImpl import com.wire.kalium.logic.feature.team.TeamScope @@ -484,8 +490,28 @@ class UserSessionScope internal constructor( private val selfTeamId = SelfTeamIdProvider { teamId() } + private val accessTokenRepository: AccessTokenRepository + get() = AccessTokenRepositoryImpl( + userId = userId, + accessTokenApi = authenticatedNetworkContainer.accessTokenApi, + authTokenStorage = globalPreferences.authTokenStorage + ) + + private val accessTokenRefresherFactory: AccessTokenRefresherFactory + get() = AccessTokenRefresherFactoryImpl( + userId = userId, + tokenStorage = globalPreferences.authTokenStorage + ) + + private val accessTokenRefresher: AccessTokenRefresher + get() = AccessTokenRefresherImpl( + userId = userId, + repository = accessTokenRepository + ) + private val sessionManager: SessionManager = SessionManagerImpl( sessionRepository = globalScope.sessionRepository, + accessTokenRefresherFactory = accessTokenRefresherFactory, userId = userId, tokenStorage = globalPreferences.authTokenStorage, logout = { logoutReason -> logout(logoutReason) } @@ -997,12 +1023,11 @@ class UserSessionScope internal constructor( } private val upgradeCurrentSessionUseCase - get() = - UpgradeCurrentSessionUseCaseImpl( - authenticatedNetworkContainer, - authenticatedNetworkContainer.accessTokenApi, - sessionManager - ) + get() = UpgradeCurrentSessionUseCaseImpl( + authenticatedNetworkContainer, + accessTokenRefresher, + sessionManager + ) @Suppress("MagicNumber") private val apiMigrations = listOf( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCase.kt index f8c120518ea..ef23072c578 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCase.kt @@ -46,7 +46,7 @@ class AddAuthenticatedUserUseCase internal constructor( suspend operator fun invoke( serverConfigId: String, ssoId: SsoId?, - authTokens: AuthTokens, + authTokens: AccountTokens, proxyCredentials: ProxyCredentials?, replace: Boolean = false ): Result = sessionRepository.doesValidSessionExist(authTokens.userId).fold( @@ -63,28 +63,28 @@ class AddAuthenticatedUserUseCase internal constructor( private suspend fun storeUser( serverConfigId: String, ssoId: SsoId?, - authTokens: AuthTokens, + accountTokens: AccountTokens, proxyCredentials: ProxyCredentials? ): Result = - sessionRepository.storeSession(serverConfigId, ssoId, authTokens, proxyCredentials) + sessionRepository.storeSession(serverConfigId, ssoId, accountTokens, proxyCredentials) .onSuccess { - sessionRepository.updateCurrentSession(authTokens.userId) + sessionRepository.updateCurrentSession(accountTokens.userId) }.fold( { Result.Failure.Generic(it) }, - { Result.Success(authTokens.userId) } + { Result.Success(accountTokens.userId) } ) // In case of the new session have a different server configurations the new session should not be added private suspend fun onUserExist( newServerConfigId: String, ssoId: SsoId?, - newAuthTokens: AuthTokens, + newAccountTokens: AccountTokens, proxyCredentials: ProxyCredentials?, replace: Boolean ): Result = when (replace) { true -> { - sessionRepository.fullAccountInfo(newAuthTokens.userId).fold( + sessionRepository.fullAccountInfo(newAccountTokens.userId).fold( { Result.Failure.Generic(it) }, { oldSession -> val newServerConfig = @@ -93,7 +93,7 @@ class AddAuthenticatedUserUseCase internal constructor( storeUser( serverConfigId = newServerConfigId, ssoId = ssoId, - authTokens = newAuthTokens, + accountTokens = newAccountTokens, proxyCredentials = proxyCredentials ) } else Result.Failure.UserAlreadyExists diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AuthSession.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AuthSession.kt index cf59af4bf65..c502971ddb8 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AuthSession.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/AuthSession.kt @@ -20,6 +20,8 @@ package com.wire.kalium.logic.feature.auth import com.wire.kalium.logic.configuration.server.ServerConfig import com.wire.kalium.logic.data.logout.LogoutReason +import com.wire.kalium.logic.data.session.token.AccessToken +import com.wire.kalium.logic.data.session.token.RefreshToken import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.data.user.UserId import kotlin.contracts.ExperimentalContracts @@ -55,10 +57,24 @@ data class Account( val ssoId: SsoId? ) -data class AuthTokens( +/** + * Holds information about the user ID, and the associated user id. + */ +data class AccountTokens( val userId: UserId, - val accessToken: String, - val refreshToken: String, - val tokenType: String, + val accessToken: AccessToken, + val refreshToken: RefreshToken, val cookieLabel: String? -) +) { + constructor( + userId: UserId, + accessToken: String, + refreshToken: String, + tokenType: String, + cookieLabel: String? + ) : this(userId, AccessToken(accessToken, tokenType), RefreshToken(refreshToken), cookieLabel) + + val tokenType: String + get() = accessToken.tokenType + +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt index 2188c65ab89..4fe55e73032 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt @@ -37,7 +37,7 @@ import com.wire.kalium.network.exceptions.isInvalidCredentials sealed class AuthenticationResult { data class Success( - val authData: AuthTokens, + val authData: AccountTokens, val ssoID: SsoId?, val serverConfigId: String, val proxyCredentials: ProxyCredentials? diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/sso/GetSSOLoginSessionUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/sso/GetSSOLoginSessionUseCase.kt index b9714763fa2..d0488f0d772 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/sso/GetSSOLoginSessionUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/sso/GetSSOLoginSessionUseCase.kt @@ -26,13 +26,17 @@ import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.session.SessionMapper import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.fold import com.wire.kalium.network.exceptions.KaliumException import io.ktor.http.HttpStatusCode sealed class SSOLoginSessionResult { - data class Success(val authTokens: AuthTokens, val ssoId: SsoId?, val proxyCredentials: ProxyCredentials?) : SSOLoginSessionResult() + data class Success( + val accountTokens: AccountTokens, + val ssoId: SsoId?, + val proxyCredentials: ProxyCredentials? + ) : SSOLoginSessionResult() sealed class Failure : SSOLoginSessionResult() { object InvalidCookie : Failure() @@ -67,7 +71,7 @@ internal class GetSSOLoginSessionUseCaseImpl( SSOLoginSessionResult.Failure.Generic(it) }, { SSOLoginSessionResult.Success( - authTokens = sessionMapper.fromSessionDTO(it.sessionDTO), + accountTokens = sessionMapper.fromSessionDTO(it.sessionDTO), ssoId = idMapper.toSsoId(it.userDTO.ssoID), proxyCredentials = proxyCredentials ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCase.kt index 7226178974a..bceb8a7096e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCase.kt @@ -25,7 +25,7 @@ import com.wire.kalium.logic.configuration.server.ServerConfig import com.wire.kalium.logic.data.auth.login.ProxyCredentials import com.wire.kalium.logic.data.register.RegisterAccountRepository import com.wire.kalium.logic.data.user.SsoId -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.map import com.wire.kalium.network.exceptions.KaliumException @@ -138,7 +138,7 @@ class RegisterAccountUseCase internal constructor( sealed class RegisterResult { data class Success( - val authData: AuthTokens, + val authData: AccountTokens, val ssoID: SsoId?, val serverConfigId: String, val proxyCredentials: ProxyCredentials? diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/UpgradeCurrentSessionUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/UpgradeCurrentSessionUseCase.kt index bf3c4d942ad..c16c88730d7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/UpgradeCurrentSessionUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/UpgradeCurrentSessionUseCase.kt @@ -20,12 +20,11 @@ package com.wire.kalium.logic.feature.session import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresher import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.map -import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.logic.wrapStorageRequest -import com.wire.kalium.network.api.base.authenticated.AccessTokenApi import com.wire.kalium.network.networkContainer.AuthenticatedNetworkContainer import com.wire.kalium.network.session.SessionManager @@ -36,20 +35,16 @@ interface UpgradeCurrentSessionUseCase { suspend operator fun invoke(clientId: ClientId): Either } -class UpgradeCurrentSessionUseCaseImpl( +internal class UpgradeCurrentSessionUseCaseImpl( private val authenticatedNetworkContainer: AuthenticatedNetworkContainer, - private val accessTokenApi: AccessTokenApi, + private val accessTokenRefresher: AccessTokenRefresher, private val sessionManager: SessionManager ) : UpgradeCurrentSessionUseCase { override suspend operator fun invoke(clientId: ClientId): Either = wrapStorageRequest { sessionManager.session()?.refreshToken } - .flatMap { refreshToken -> - wrapApiRequest { - accessTokenApi.getToken(refreshToken, clientId.value) - }.flatMap { - wrapStorageRequest { sessionManager.updateLoginSession(it.first, it.second) } - }.map { - authenticatedNetworkContainer.clearCachedToken() - } + .flatMap { currentRefreshToken -> + accessTokenRefresher.refreshTokenAndPersistSession(currentRefreshToken, clientId.value) + }.map { + authenticatedNetworkContainer.clearCachedToken() } - } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresher.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresher.kt new file mode 100644 index 00000000000..14257f5f408 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresher.kt @@ -0,0 +1,64 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.session.token + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.session.token.AccessTokenRepository +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.auth.AccountTokens +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.map + +internal interface AccessTokenRefresher { + /** + * Refreshes the access token using the provided refresh token and persists the session in the repository. + * + * @param currentRefreshToken The refresh token to use for obtaining a new access token. + * @param clientId The optional client ID associated with the new token. + * @return Either a [CoreFailure] if the operation fails, or the [AccountTokens] with the new access token and refresh token. + */ + suspend fun refreshTokenAndPersistSession( + currentRefreshToken: String, + clientId: String? = null, + ): Either +} + +internal class AccessTokenRefresherImpl( + private val userId: UserId, + private val repository: AccessTokenRepository, +) : AccessTokenRefresher { + override suspend fun refreshTokenAndPersistSession( + currentRefreshToken: String, + clientId: String? + ): Either { + return repository.getNewAccessToken( + refreshToken = currentRefreshToken, + clientId = clientId + ).flatMap { result -> + repository.persistTokens(result.accessToken, result.refreshToken).map { + AccountTokens( + userId = userId, + accessToken = result.accessToken, + refreshToken = result.refreshToken, + cookieLabel = null + ) + } + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherFactory.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherFactory.kt new file mode 100644 index 00000000000..82fe32dee35 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherFactory.kt @@ -0,0 +1,47 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.session.token + +import com.wire.kalium.logic.data.session.token.AccessTokenRepositoryImpl +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.network.api.base.authenticated.AccessTokenApi +import com.wire.kalium.persistence.client.AuthTokenStorage + +/** + * Represents a factory for creating instances of [AccessTokenRefresher]. + * Allows taking a dynamic [AccessTokenApi] for its construction. + */ +internal interface AccessTokenRefresherFactory { + fun create(accessTokenApi: AccessTokenApi): AccessTokenRefresher +} + +internal class AccessTokenRefresherFactoryImpl( + private val userId: UserId, + private val tokenStorage: AuthTokenStorage +) : AccessTokenRefresherFactory { + override fun create(accessTokenApi: AccessTokenApi): AccessTokenRefresher { + return AccessTokenRefresherImpl( + userId = userId, + repository = AccessTokenRepositoryImpl( + userId = userId, + accessTokenApi = accessTokenApi, + authTokenStorage = tokenStorage + ) + ) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt index 06e4c367df5..63695c34e60 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt @@ -22,26 +22,26 @@ import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.configuration.server.ServerConfigMapper import com.wire.kalium.logic.data.id.QualifiedID +import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.logout.LogoutReason import com.wire.kalium.logic.data.session.SessionMapper import com.wire.kalium.logic.data.session.SessionRepository import com.wire.kalium.logic.di.MapperProvider +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresherFactory import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.nullableFold import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger -import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.logic.wrapStorageRequest import com.wire.kalium.network.api.base.authenticated.AccessTokenApi -import com.wire.kalium.network.api.base.model.AccessTokenDTO import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO -import com.wire.kalium.network.api.base.model.RefreshTokenDTO import com.wire.kalium.network.api.base.model.SessionDTO import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isUnknownClient +import com.wire.kalium.network.session.FailureToRefreshTokenException import com.wire.kalium.network.session.SessionManager import com.wire.kalium.network.tools.ServerConfigDTO import com.wire.kalium.persistence.client.AuthTokenStorage @@ -55,6 +55,7 @@ import kotlin.coroutines.CoroutineContext @Suppress("LongParameterList") class SessionManagerImpl internal constructor( private val sessionRepository: SessionRepository, + private val accessTokenRefresherFactory: AccessTokenRefresherFactory, private val userId: QualifiedID, private val tokenStorage: AuthTokenStorage, private val logout: suspend (LogoutReason) -> Unit, @@ -88,42 +89,38 @@ class SessionManagerImpl internal constructor( .onSuccess { serverConfig = it } .fold({ error("use serverConfig is missing or an error while reading local storage") }, { it }) serverConfig!! - }!! - - override suspend fun updateLoginSession(newAccessTokenDTO: AccessTokenDTO, newRefreshTokenDTO: RefreshTokenDTO?): SessionDTO? = - wrapStorageRequest { - tokenStorage.updateToken( - userId = userId.toDao(), - accessToken = newAccessTokenDTO.value, - tokenType = newAccessTokenDTO.tokenType, - refreshToken = newRefreshTokenDTO?.value - ) - }.map { - sessionMapper.fromEntityToSessionDTO(it) - }.onSuccess { - session = it - }.nullableFold({ - null - }, { - it - }) + } - override suspend fun updateToken(accessTokenApi: AccessTokenApi, oldAccessToken: String, oldRefreshToken: String): SessionDTO? { + override suspend fun updateToken( + accessTokenApi: AccessTokenApi, + oldAccessToken: String, + oldRefreshToken: String + ): SessionDTO { + val refresher = accessTokenRefresherFactory.create(accessTokenApi) return withContext(coroutineContext) { - wrapApiRequest { accessTokenApi.getToken(oldRefreshToken) }.nullableFold({ - when (it) { - is NetworkFailure.NoNetworkConnection -> null - is NetworkFailure.ProxyError -> null - is NetworkFailure.FederatedBackendFailure -> null - is NetworkFailure.FeatureNotSupported -> null - is NetworkFailure.ServerMiscommunication -> { - onServerMissCommunication(it) - null - } + refresher.refreshTokenAndPersistSession(oldRefreshToken).onFailure { + if (it is NetworkFailure.ServerMiscommunication) { + onServerMissCommunication(it) } + }.map { refreshResult -> + SessionDTO( + userId = userId.toApi(), + tokenType = refreshResult.accessToken.tokenType, + accessToken = refreshResult.accessToken.value, + refreshToken = refreshResult.refreshToken.value, + cookieLabel = refreshResult.cookieLabel + ) + }.fold({ + val message = "Failure during auth token refresh. " + + "A network request is failing because of this. " + + "Future requests should reattempt to refresh the token. Failure='$it'" + kaliumLogger.w(message) + throw FailureToRefreshTokenException(message) }, { - updateLoginSession(it.first, it.second) - }) + it + }).also { + session = it + } } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt index 68f4e71c1c8..766cc79007c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt @@ -23,12 +23,11 @@ import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.session.SessionMapper import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestNetworkException import com.wire.kalium.network.api.base.model.SelfUserDTO import com.wire.kalium.network.api.base.model.SessionDTO -import com.wire.kalium.network.api.base.model.UserDTO import com.wire.kalium.network.api.base.unauthenticated.register.RegisterApi import com.wire.kalium.network.utils.NetworkResponse import io.mockative.Mock @@ -138,8 +137,8 @@ class RegisterAccountRepositoryTest { this?.let { SsoId(scimExternalId = it.scimExternalId, subject = it.subject, tenant = it.tenant) } } val cookieLabel = "COOKIE_LABEL" - val authTokens = with(SESSION) { - AuthTokens( + val accountTokens = with(SESSION) { + AccountTokens( userId = UserId(userId.value, userId.domain), accessToken = accessToken, refreshToken = refreshToken, @@ -147,7 +146,7 @@ class RegisterAccountRepositoryTest { cookieLabel = cookieLabel ) } - val expected = Pair(ssoId, authTokens) + val expected = Pair(ssoId, accountTokens) given(registerApi).coroutine { register( @@ -161,7 +160,7 @@ class RegisterAccountRepositoryTest { ) }.then { NetworkResponse.Success(Pair(TEST_USER, SESSION), mapOf(), 200) } given(idMapper).invocation { toSsoId(TEST_USER.ssoID) }.then { ssoId } - given(sessionMapper).invocation { fromSessionDTO(SESSION) }.then { authTokens } + given(sessionMapper).invocation { fromSessionDTO(SESSION) }.then { accountTokens } val actual = registerAccountRepository.registerPersonalAccountWithEmail( email = email, @@ -171,7 +170,7 @@ class RegisterAccountRepositoryTest { cookieLabel = cookieLabel ) - assertIs>>(actual) + assertIs>>(actual) assertEquals(expected, actual.value) verify(registerApi).coroutine { @@ -203,9 +202,9 @@ class RegisterAccountRepositoryTest { val ssoId = with(TEST_USER.ssoID) { this?.let { SsoId(scimExternalId = it.scimExternalId, subject = it.subject, tenant = it.tenant) } } - val authTokens = + val accountTokens = with(SESSION) { - AuthTokens( + AccountTokens( userId = UserId(userId.value, userId.domain), accessToken = accessToken, refreshToken = refreshToken, @@ -213,7 +212,7 @@ class RegisterAccountRepositoryTest { cookieLabel = cookieLabel ) } - val expected = Pair(ssoId, authTokens) + val expected = Pair(ssoId, accountTokens) given(registerApi).coroutine { register( @@ -231,7 +230,7 @@ class RegisterAccountRepositoryTest { given(idMapper).invocation { toSsoId(TEST_USER.ssoID) }.then { ssoId } given(sessionMapper) .invocation { fromSessionDTO(SESSION) } - .then { authTokens } + .then { accountTokens } val actual = registerAccountRepository.registerTeamWithEmail( email = email, @@ -243,7 +242,7 @@ class RegisterAccountRepositoryTest { cookieLabel = cookieLabel ) - assertIs>>(actual) + assertIs>>(actual) assertEquals(expected, actual.value) verify(registerApi).coroutine { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/session/SessionMapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/session/SessionMapperTest.kt index a06ede49d0c..170294c5205 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/session/SessionMapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/session/SessionMapperTest.kt @@ -21,7 +21,7 @@ package com.wire.kalium.logic.data.session import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.network.api.base.model.SessionDTO import com.wire.kalium.persistence.client.AuthTokenEntity import com.wire.kalium.persistence.dao.UserIDEntity @@ -49,15 +49,15 @@ class SessionMapperTest { @Test fun givenAnAuthTokens_whenMappingToSessionCredentials_thenValuesAreMappedCorrectly() { - val authSession: AuthTokens = TEST_AUTH_TOKENS + val authSession: AccountTokens = TEST_AUTH_TOKENS val acuteValue: SessionDTO = with(authSession) { SessionDTO( UserIdDTO(userId.value, userId.domain), tokenType, - accessToken, - refreshToken, + accessToken.value, + refreshToken.value, cookieLabel ) } @@ -68,7 +68,7 @@ class SessionMapperTest { @Test fun givenAnAuthTokens_whenMappingToPersistenceAuthTokens_thenValuesAreMappedCorrectly() { - val authSession: AuthTokens = TEST_AUTH_TOKENS + val authSession: AccountTokens = TEST_AUTH_TOKENS given(idMapper).invocation { idMapper.toSsoIdEntity(TEST_SSO_ID) }.then { TEST_SSO_ID_ENTITY } @@ -76,8 +76,8 @@ class SessionMapperTest { AuthTokenEntity( userId = UserIDEntity(userId.value, userId.domain), tokenType = tokenType, - accessToken = accessToken, - refreshToken = refreshToken, + accessToken = accessToken.value, + refreshToken = refreshToken.value, cookieLabel = cookieLabel ) } @@ -89,7 +89,7 @@ class SessionMapperTest { private companion object { val userId = UserId("user_id", "user.domain.io") - val TEST_AUTH_TOKENS = AuthTokens( + val TEST_AUTH_TOKENS = AccountTokens( userId = userId, tokenType = "Bearer", accessToken = "access_token", diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCaseTest.kt index bf33a73760a..1cc41898f65 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/AddAuthenticatedUserUseCaseTest.kt @@ -23,6 +23,8 @@ import com.wire.kalium.logic.configuration.server.ServerConfig import com.wire.kalium.logic.configuration.server.ServerConfigRepository import com.wire.kalium.logic.data.auth.login.ProxyCredentials import com.wire.kalium.logic.data.session.SessionRepository +import com.wire.kalium.logic.data.session.token.AccessToken +import com.wire.kalium.logic.data.session.token.RefreshToken import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either @@ -33,12 +35,10 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertIs -@OptIn(ExperimentalCoroutinesApi::class) class AddAuthenticatedUserUseCaseTest { @Test @@ -100,10 +100,16 @@ class AddAuthenticatedUserUseCaseTest { @Test fun givenUserWithAlreadyStoredSession_whenInvokedWithReplaceAndServerConfigAreTheSame_thenSuccessReturned() = runTest { - val oldSession = TEST_AUTH_TOKENS.copy(accessToken = "oldAccessToken", refreshToken = "oldRefreshToken") + val oldSession = TEST_AUTH_TOKENS.copy( + accessToken = AccessToken("oldAccessToken", TEST_AUTH_TOKENS.tokenType), + refreshToken = RefreshToken("oldRefreshToken") + ) val oldSessionFullInfo = Account(AccountInfo.Valid(oldSession.userId), TEST_SERVER_CONFIG, TEST_SSO_ID) - val newSession = TEST_AUTH_TOKENS.copy(accessToken = "newAccessToken", refreshToken = "newRefreshToken") + val newSession = TEST_AUTH_TOKENS.copy( + accessToken = AccessToken("newAccessToken", TEST_AUTH_TOKENS.tokenType), + refreshToken = RefreshToken("newRefreshToken") + ) val proxyCredentials = PROXY_CREDENTIALS @@ -146,10 +152,16 @@ class AddAuthenticatedUserUseCaseTest { @Test fun givenUserWithAlreadyStoredSessionWithDifferentServerConfig_whenInvokedWithReplace_thenUserAlreadyExistsReturned() = runTest { - val oldSession = TEST_AUTH_TOKENS.copy(accessToken = "oldAccessToken", refreshToken = "oldRefreshToken") + val oldSession = TEST_AUTH_TOKENS.copy( + accessToken = AccessToken("oldAccessToken", TEST_AUTH_TOKENS.tokenType), + refreshToken = RefreshToken("oldRefreshToken") + ) val oldSessionServer = newServerConfig(id = 11) - val newSession = TEST_AUTH_TOKENS.copy(accessToken = "newAccessToken", refreshToken = "newRefreshToken") + val newSession = TEST_AUTH_TOKENS.copy( + accessToken = AccessToken("newAccessToken", TEST_AUTH_TOKENS.tokenType), + refreshToken = RefreshToken("newRefreshToken") + ) val newSessionServer = newServerConfig(id = 22) val proxyCredentials = PROXY_CREDENTIALS @@ -191,7 +203,7 @@ class AddAuthenticatedUserUseCaseTest { private companion object { val TEST_USERID = UserId("user_id", "domain.de") val TEST_SERVER_CONFIG: ServerConfig = newServerConfig(1) - val TEST_AUTH_TOKENS = AuthTokens( + val TEST_AUTH_TOKENS = AccountTokens( TEST_USERID, "access-token", "refresh-token", @@ -246,11 +258,11 @@ class AddAuthenticatedUserUseCaseTest { suspend fun withStoreSessionResult( serverConfigId: String, ssoId: SsoId?, - authTokens: AuthTokens, + accountTokens: AccountTokens, proxyCredentials: ProxyCredentials?, result: Either ) = apply { - given(sessionRepository).coroutine { storeSession(serverConfigId, ssoId, authTokens, proxyCredentials) }.then { result } + given(sessionRepository).coroutine { storeSession(serverConfigId, ssoId, accountTokens, proxyCredentials) }.then { result } } suspend fun withUpdateCurrentSessionResult( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCaseTest.kt index 406fffe38ba..1cedeb6cd84 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCaseTest.kt @@ -499,14 +499,14 @@ class LoginUseCaseTest { .thenReturn(handleValidationResult) } - fun withLoginUsingEmailResulting(result: Either>) = apply { + fun withLoginUsingEmailResulting(result: Either>) = apply { given(loginRepository) .suspendFunction(loginRepository::loginWithEmail) .whenInvokedWith(any(), any(), any(), any(), anything()) .thenReturn(result) } - fun withLoginUsingHandleResulting(result: Either>) = apply { + fun withLoginUsingHandleResulting(result: Either>) = apply { given(loginRepository) .suspendFunction(loginRepository::loginWithHandle) .whenInvokedWith(any(), any(), any(), any()) @@ -534,7 +534,7 @@ class LoginUseCaseTest { // TODO: Remove random value from tests val TEST_PERSIST_CLIENT = Random.nextBoolean() val TEST_SERVER_CONFIG: ServerConfig = newTestServer(1) - val TEST_AUTH_TOKENS = AuthTokens( + val TEST_AUTH_TOKENS = AccountTokens( userId = UserId("user_id", "domain.de"), accessToken = "access_token", refreshToken = "refresh_token", diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt index 284d378ccea..a971645a6e9 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt @@ -27,7 +27,7 @@ import com.wire.kalium.logic.data.user.SelfUser import com.wire.kalium.logic.data.user.SsoId import com.wire.kalium.logic.data.user.UserAvailabilityStatus import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestNetworkException import com.wire.kalium.logic.util.stubs.newServerConfig @@ -217,7 +217,7 @@ class RegisterAccountUseCaseTest { availabilityStatus = UserAvailabilityStatus.NONE, supportedProtocols = null ) - val TEST_AUTH_TOKENS = AuthTokens( + val TEST_AUTH_TOKENS = AccountTokens( accessToken = "access_token", refreshToken = "refresh_token", tokenType = "token_type", diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherTest.kt new file mode 100644 index 00000000000..1d41bb6de83 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/session/token/AccessTokenRefresherTest.kt @@ -0,0 +1,141 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.session.token + +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.session.token.AccessToken +import com.wire.kalium.logic.data.session.token.AccessTokenRefreshResult +import com.wire.kalium.logic.data.session.token.AccessTokenRepository +import com.wire.kalium.logic.data.session.token.RefreshToken +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.Mock +import io.mockative.anything +import io.mockative.eq +import io.mockative.given +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.TestResult +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class AccessTokenRefresherTest { + + @Test + fun givenRefreshFails_whenRefreshing_thenShouldPropagateFailureAndNotPersist(): TestResult = runTest { + val failure = NetworkFailure.NoNetworkConnection(null) + val (arrangement, accessTokenRefresher) = arrange { + withRefreshTokenReturning(Either.Left(failure)) + } + + accessTokenRefresher.refreshTokenAndPersistSession("egal") + .shouldFail { + assertEquals(failure, it) + } + verify(arrangement.repository) + .suspendFunction(arrangement.repository::persistTokens) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenPersistFails_whenRefreshing_thenShouldPropagateFailure(): TestResult = runTest { + val failure = StorageFailure.DataNotFound + val (_, accessTokenRefresher) = arrange { + withRefreshTokenReturning( + Either.Right( + AccessTokenRefreshResult( + AccessToken("access", "refresh"), + RefreshToken("hey") + ) + ) + ) + withPersistReturning(Either.Left(failure)) + } + + accessTokenRefresher.refreshTokenAndPersistSession("egal") + .shouldFail { + assertEquals(failure, it) + } + } + + @Test + fun givenSuccessfulRefresh_whenRefreshing_thenShouldPersistResultCorrectly(): TestResult = runTest { + val (arrangement, accessTokenRefresher) = arrange { + withRefreshTokenReturning(Either.Right(TEST_REFRESH_RESULT)) + withPersistReturning(Either.Right(Unit)) + } + + accessTokenRefresher.refreshTokenAndPersistSession("egal") + verify(arrangement.repository) + .suspendFunction(arrangement.repository::persistTokens) + .with(eq(TEST_REFRESH_RESULT.accessToken), eq(TEST_REFRESH_RESULT.refreshToken)) + .wasInvoked(exactly = once) + } + + @Test + fun givenEverythingSucceeds_whenRefreshing_thenShouldPropagateSuccess(): TestResult = runTest { + val (_, accessTokenRefresher) = arrange { + withRefreshTokenReturning(Either.Right(TEST_REFRESH_RESULT)) + withPersistReturning(Either.Right(Unit)) + } + + accessTokenRefresher.refreshTokenAndPersistSession("egal").shouldSucceed { } + } + + private class Arrangement(private val configure: Arrangement.() -> Unit) { + + val userId = TestUser.USER_ID + + @Mock + val repository = mock(AccessTokenRepository::class) + + fun arrange(): Pair = run { + configure() + this@Arrangement to AccessTokenRefresherImpl(userId, repository) + } + + fun withRefreshTokenReturning(result: Either) { + given(repository) + .suspendFunction(repository::getNewAccessToken) + .whenInvokedWith(anything(), anything()) + .thenReturn(result) + } + + fun withPersistReturning(result: Either) { + given(repository) + .suspendFunction(repository::persistTokens) + .whenInvokedWith(anything(), anything()) + .thenReturn(result) + } + } + + private companion object { + fun arrange(configure: Arrangement.() -> Unit) = Arrangement(configure).arrange() + + val TEST_REFRESH_RESULT = AccessTokenRefreshResult( + AccessToken("access", "refresh"), + RefreshToken("hey") + ) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/network/SessionManagerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/network/SessionManagerTest.kt new file mode 100644 index 00000000000..50e95a03f86 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/network/SessionManagerTest.kt @@ -0,0 +1,199 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.network + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.configuration.server.ServerConfigMapper +import com.wire.kalium.logic.data.logout.LogoutReason +import com.wire.kalium.logic.data.session.SessionRepository +import com.wire.kalium.logic.di.MapperProvider +import com.wire.kalium.logic.feature.auth.AccountTokens +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresher +import com.wire.kalium.logic.feature.session.token.AccessTokenRefresherFactory +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.network.api.base.authenticated.AccessTokenApi +import com.wire.kalium.network.api.base.model.QualifiedID +import com.wire.kalium.network.api.base.model.SessionDTO +import com.wire.kalium.network.session.FailureToRefreshTokenException +import com.wire.kalium.network.session.SessionManager +import com.wire.kalium.persistence.client.AuthTokenEntity +import com.wire.kalium.persistence.client.AuthTokenStorage +import com.wire.kalium.persistence.dao.UserIDEntity +import io.mockative.Mock +import io.mockative.any +import io.mockative.anything +import io.mockative.given +import io.mockative.mock +import kotlinx.coroutines.test.runTest +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull + +class SessionManagerTest { + + @Test + fun givenFailureOnRefresh_whenRefreshingToken_thenShouldThrowException() = runTest { + val failure = NetworkFailure.NoNetworkConnection(null) + val (arrangement, sessionManager) = arrange { + withTokenRefresherResult(Either.Left(failure)) + } + + assertFailsWith { + sessionManager.updateToken(arrangement.accessTokenApi, "egal", "egal") + } + } + + @Test + fun givenInitialSession_whenFetchingSession_thenSessionShouldBeReturnedProperly() = runTest { + val expectedData = AuthTokenEntity( + userId = UserIDEntity("potato", "potahto"), + accessToken = "aToken", + refreshToken = "rToken", + tokenType = "tType", + cookieLabel = null + ) + val (_, sessionManager) = arrange { + withCurrentTokenResult(expectedData) + } + + val result = sessionManager.session() + assertNotNull(result) + assertEquals(expectedData.userId.value, result.userId.value) + assertEquals(expectedData.userId.domain, result.userId.domain) + assertEquals(expectedData.accessToken, result.accessToken) + assertEquals(expectedData.refreshToken, result.refreshToken) + assertEquals(expectedData.tokenType, result.tokenType) + assertEquals(expectedData.cookieLabel, result.cookieLabel) + } + + @Test + fun givenInitialSessionIsUpdated_whenFetchingSession_thenSessionShouldBeUpdatedProperly() = runTest { + val expectedData = TEST_ACCOUNT_TOKENS + val (arrangement, sessionManager) = arrange { + withCurrentTokenResult( + AuthTokenEntity( + userId = UserIDEntity("potato", "potahto"), + accessToken = "aToken", + refreshToken = "rToken", + tokenType = "tType", + cookieLabel = "cLabel" + ) + ) + withTokenRefresherResult(Either.Right(expectedData)) + } + + sessionManager.session() + sessionManager.updateToken(arrangement.accessTokenApi, "egal", "egal") + val result = sessionManager.session() + assertNotNull(result) + assertEquals(expectedData.userId.value, result.userId.value) + assertEquals(expectedData.userId.domain, result.userId.domain) + assertEquals(expectedData.accessToken.value, result.accessToken) + assertEquals(expectedData.refreshToken.value, result.refreshToken) + assertEquals(expectedData.tokenType, result.tokenType) + assertEquals(expectedData.cookieLabel, result.cookieLabel) + } + + @Test + fun givenTokenWasUpdated_whenGettingSession_thenItShouldBeUpdatedAsWell() = runTest { + val (arrangement, sessionManager) = arrange { + withTokenRefresherResult(Either.Right(TEST_ACCOUNT_TOKENS)) + } + + sessionManager.updateToken(arrangement.accessTokenApi, "egal", "egal") + + assertEquals(TEST_SESSION_DTO, sessionManager.session()) + } + + private class Arrangement(private val configure: Arrangement.() -> Unit) { + + @Mock + private val sessionRepository = mock(SessionRepository::class) + + // Unused, but necessary when updating tokens + @Mock + val accessTokenApi = mock(AccessTokenApi::class) + + @Mock + private val accessTokenRefresher = mock(AccessTokenRefresher::class) + private val accessTokenRefresherFactory = object : AccessTokenRefresherFactory { + override fun create(accessTokenApi: AccessTokenApi): AccessTokenRefresher { + return accessTokenRefresher + } + } + private val userId = TestUser.USER_ID + + @Mock + private val tokenStorage = mock(AuthTokenStorage::class) + + private val logout = { _: LogoutReason -> } + + @Mock + private val serverConfigMapper = mock(ServerConfigMapper::class) + + private val sessionMapper = MapperProvider.sessionMapper() + + fun arrange(): Pair = run { + configure() + this@Arrangement to SessionManagerImpl( + sessionRepository = sessionRepository, + accessTokenRefresherFactory = accessTokenRefresherFactory, + userId = userId, + tokenStorage = tokenStorage, + logout = logout, + serverConfigMapper = serverConfigMapper, + sessionMapper = sessionMapper, + coroutineContext = EmptyCoroutineContext + ) + } + + fun withTokenRefresherResult(result: Either) = apply { + given(accessTokenRefresher).suspendFunction(accessTokenRefresher::refreshTokenAndPersistSession) + .whenInvokedWith(anything(), anything()).thenReturn(result) + } + + fun withCurrentTokenResult(result: AuthTokenEntity) = apply { + given(tokenStorage) + .function(tokenStorage::getToken) + .whenInvokedWith(any()) + .thenReturn(result) + } + } + + private companion object { + fun arrange(configure: Arrangement.() -> Unit) = Arrangement(configure).arrange() + val TEST_ACCOUNT_TOKENS = AccountTokens( + userId = TestUser.USER_ID, + accessToken = "access-token", + refreshToken = "refresh-token", + tokenType = "type", + cookieLabel = "cookie-label" + ) + val TEST_SESSION_DTO = SessionDTO( + userId = QualifiedID(TEST_ACCOUNT_TOKENS.userId.value, TEST_ACCOUNT_TOKENS.userId.domain), + tokenType = TEST_ACCOUNT_TOKENS.accessToken.tokenType, + accessToken = TEST_ACCOUNT_TOKENS.accessToken.value, + refreshToken = TEST_ACCOUNT_TOKENS.refreshToken.value, + cookieLabel = TEST_ACCOUNT_TOKENS.cookieLabel + ) + } +} diff --git a/monkeys/src/main/kotlin/com/wire/kalium/monkeys/conversation/Monkey.kt b/monkeys/src/main/kotlin/com/wire/kalium/monkeys/conversation/Monkey.kt index cca0705c531..1ed14db8057 100644 --- a/monkeys/src/main/kotlin/com/wire/kalium/monkeys/conversation/Monkey.kt +++ b/monkeys/src/main/kotlin/com/wire/kalium/monkeys/conversation/Monkey.kt @@ -82,7 +82,7 @@ class Monkey(val user: UserData) { val storeResult = addAuthenticatedAccount( serverConfigId = loginResult.serverConfigId, ssoId = loginResult.ssoID, - authTokens = loginResult.authData, + accountTokens = loginResult.authData, proxyCredentials = loginResult.proxyCredentials, replace = true ) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/AccessTokenApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/AccessTokenApi.kt index 8a699c2f371..35addc23cee 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/AccessTokenApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/AccessTokenApi.kt @@ -23,5 +23,8 @@ import com.wire.kalium.network.api.base.model.RefreshTokenDTO import com.wire.kalium.network.utils.NetworkResponse interface AccessTokenApi { - suspend fun getToken(refreshToken: String, clientId: String? = null): NetworkResponse> + suspend fun getToken( + refreshToken: String, + clientId: String? = null + ): NetworkResponse> } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/Tokens.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/Tokens.kt index e249532dd07..8584fb9dafd 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/Tokens.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/Tokens.kt @@ -22,6 +22,14 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlin.jvm.JvmInline +/** + * Represents an access token received from an authentication server. + * + * @property userId The ID of the user associated with the access token. + * @property value The access token value. + * @property expiresIn The duration in seconds until the token expires. + * @property tokenType The type of the token. _e.g._ "Bearer" + */ @Serializable data class AccessTokenDTO( @SerialName("user") val userId: NonQualifiedUserId, diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/networkContainer/AuthenticatedNetworkContainer.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/networkContainer/AuthenticatedNetworkContainer.kt index b39e6a06500..57e497a4fab 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/networkContainer/AuthenticatedNetworkContainer.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/networkContainer/AuthenticatedNetworkContainer.kt @@ -48,6 +48,7 @@ import com.wire.kalium.network.api.v2.authenticated.networkContainer.Authenticat import com.wire.kalium.network.api.v3.authenticated.networkContainer.AuthenticatedNetworkContainerV3 import com.wire.kalium.network.api.v4.authenticated.networkContainer.AuthenticatedNetworkContainerV4 import com.wire.kalium.network.api.v5.authenticated.networkContainer.AuthenticatedNetworkContainerV5 +import com.wire.kalium.network.kaliumLogger import com.wire.kalium.network.session.CertificatePinning import com.wire.kalium.network.session.SessionManager import com.wire.kalium.network.tools.ServerConfigDTO @@ -202,14 +203,24 @@ internal class AuthenticatedHttpClientProviderImpl( private val loadToken: suspend () -> BearerTokens? = { val session = sessionManager.session() ?: error("missing user session") - BearerTokens(accessToken = session.accessToken, refreshToken = session.refreshToken) + BearerTokens(accessToken = "Invalid", refreshToken = session.refreshToken) } - private val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? = { - val newSession = sessionManager.updateToken(accessTokenApi(client), oldTokens!!.accessToken, oldTokens!!.refreshToken) - newSession?.let { - BearerTokens(accessToken = it.accessToken, refreshToken = it.refreshToken) + private val refreshToken: suspend RefreshTokensParams.() -> BearerTokens = { + val areOldTokensNull = oldTokens == null + kaliumLogger.i("Auth tokens are being refreshed") + if (areOldTokensNull) { + kaliumLogger.e("Old Auth tokens are null! Someone call the doctor! This should never happen") } + val newSession = sessionManager.updateToken( + accessTokenApi = accessTokenApi(client), + oldAccessToken = oldTokens!!.accessToken, + oldRefreshToken = oldTokens!!.refreshToken + ) + BearerTokens( + accessToken = newSession.accessToken, + refreshToken = newSession.refreshToken + ) } private val bearerAuthProvider: BearerAuthProvider = BearerAuthProvider(refreshToken, loadToken, { true }, null) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/session/FailureToRefreshTokenException.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/session/FailureToRefreshTokenException.kt new file mode 100644 index 00000000000..0a366237ab5 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/session/FailureToRefreshTokenException.kt @@ -0,0 +1,30 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.session + +/** + * Exception thrown when a token refresh fails. + * Could be caused by anything, including network errors, invalid credentials, etc. + * + * @param message The detail message. + * @param cause The cause of the exception. + */ +class FailureToRefreshTokenException( + message: String, + cause: Throwable? = null +) : RuntimeException(message, cause) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/session/SessionManager.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/session/SessionManager.kt index 708d397da6f..c90b0e3d30a 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/session/SessionManager.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/session/SessionManager.kt @@ -19,9 +19,7 @@ package com.wire.kalium.network.session import com.wire.kalium.network.api.base.authenticated.AccessTokenApi -import com.wire.kalium.network.api.base.model.AccessTokenDTO import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO -import com.wire.kalium.network.api.base.model.RefreshTokenDTO import com.wire.kalium.network.api.base.model.SessionDTO import com.wire.kalium.network.tools.ServerConfigDTO import io.ktor.client.HttpClient @@ -43,8 +41,19 @@ import kotlin.coroutines.CoroutineContext interface SessionManager { suspend fun session(): SessionDTO? fun serverConfig(): ServerConfigDTO - suspend fun updateToken(accessTokenApi: AccessTokenApi, oldAccessToken: String, oldRefreshToken: String): SessionDTO? - suspend fun updateLoginSession(newAccessTokenDTO: AccessTokenDTO, newRefreshTokenDTO: RefreshTokenDTO?): SessionDTO? + + /** + * Updates the access token and (possibly) the refresh token for the session. + * + * In case of failure to refresh the access token, an exception can be thrown. + * + * @param accessTokenApi The AccessTokenApi interface used to retrieve the new access token. + * @param oldAccessToken The old access token to be replaced. + * @param oldRefreshToken The old refresh token to be replaced. + * @return The updated SessionDTO object. + * @see FailureToRefreshTokenException + */ + suspend fun updateToken(accessTokenApi: AccessTokenApi, oldAccessToken: String, oldRefreshToken: String): SessionDTO fun proxyCredentials(): ProxyCredentialsDTO? } diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/ApiTest.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/ApiTest.kt index f102edf2827..32678b0e534 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/ApiTest.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/ApiTest.kt @@ -65,8 +65,12 @@ internal abstract class ApiTest { private val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? get() = { - val newSession = TEST_SESSION_MANAGER.updateToken(AccessTokenApiV0(client), oldTokens!!.accessToken, oldTokens!!.refreshToken) - newSession?.let { + val newSession = TEST_SESSION_MANAGER.updateToken( + accessTokenApi = AccessTokenApiV0(client), + oldAccessToken = oldTokens!!.accessToken, + oldRefreshToken = oldTokens!!.refreshToken + ) + newSession.let { BearerTokens(accessToken = it.accessToken, refreshToken = it.refreshToken) } } diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/TestSessionManagerV0.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/TestSessionManagerV0.kt index 6b092cf90dc..15b20e5970a 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/TestSessionManagerV0.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/TestSessionManagerV0.kt @@ -20,9 +20,7 @@ package com.wire.kalium.api import com.wire.kalium.api.json.model.testCredentials import com.wire.kalium.network.api.base.authenticated.AccessTokenApi -import com.wire.kalium.network.api.base.model.AccessTokenDTO import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO -import com.wire.kalium.network.api.base.model.RefreshTokenDTO import com.wire.kalium.network.api.base.model.SessionDTO import com.wire.kalium.network.session.SessionManager import com.wire.kalium.network.tools.ServerConfigDTO @@ -33,19 +31,14 @@ class TestSessionManagerV0 : SessionManager { override suspend fun session(): SessionDTO = session override fun serverConfig(): ServerConfigDTO = serverConfig - override suspend fun updateToken(accessTokenApi: AccessTokenApi, oldAccessToken: String, oldRefreshToken: String): SessionDTO? { + override suspend fun updateToken( + accessTokenApi: AccessTokenApi, + oldAccessToken: String, + oldRefreshToken: String + ): SessionDTO { TODO("Not yet implemented") } - override suspend fun updateLoginSession(newAccessTokenDTO: AccessTokenDTO, newRefreshTokenDTO: RefreshTokenDTO?) = - SessionDTO( - session.userId, - newAccessTokenDTO.tokenType, - newAccessTokenDTO.value, - newRefreshTokenDTO?.value ?: session.refreshToken, - session.cookieLabel - ) - override fun proxyCredentials(): ProxyCredentialsDTO? = ProxyCredentialsDTO("username", "password") diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerTest.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerIntegrationTest.kt similarity index 66% rename from network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerTest.kt rename to network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerIntegrationTest.kt index 93e3d53ced7..6c02c647c54 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerTest.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/common/SessionManagerIntegrationTest.kt @@ -23,17 +23,17 @@ import com.wire.kalium.api.TestNetworkStateObserver.Companion.DEFAULT_TEST_NETWO import com.wire.kalium.api.json.model.testCredentials import com.wire.kalium.network.AuthenticatedNetworkClient import com.wire.kalium.network.api.base.authenticated.AccessTokenApi -import com.wire.kalium.network.api.base.model.AccessTokenDTO import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO -import com.wire.kalium.network.api.base.model.RefreshTokenDTO import com.wire.kalium.network.api.base.model.SessionDTO import com.wire.kalium.network.api.v0.authenticated.AccessTokenApiV0 import com.wire.kalium.network.api.v0.authenticated.AssetApiV0 +import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.kaliumLogger import com.wire.kalium.network.networkContainer.KaliumUserAgentProvider import com.wire.kalium.network.session.SessionManager import com.wire.kalium.network.session.installAuth import com.wire.kalium.network.tools.ServerConfigDTO +import com.wire.kalium.network.utils.NetworkResponse import io.ktor.client.HttpClient import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respondError @@ -44,30 +44,39 @@ import io.ktor.client.plugins.auth.providers.RefreshTokensParams import io.ktor.client.request.get import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import okio.FileSystem import okio.Path.Companion.toPath import okio.fakefilesystem.FakeFileSystem import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertIs import kotlin.test.assertNull +import kotlin.test.assertTrue -@OptIn(ExperimentalCoroutinesApi::class) -class SessionManagerTest { +/** + * Tests how our [SessionManager] integrates with the internals of Ktor. + * For example, making sure that when we throw an exception during token refresh, + * Ktor will catch it, and we will be able to return a [NetworkResponse.Error] with the exception. + */ +class SessionManagerIntegrationTest { @Test fun givenClientWithAuth_whenServerReturns401_thenShouldTryAgainWithNewToken() = runTest { val sessionManager = createFakeSessionManager() val loadToken: suspend () -> BearerTokens? = { - val session = sessionManager.session() ?: error("missing user session") + val session = sessionManager.session() BearerTokens(accessToken = session.accessToken, refreshToken = session.refreshToken) } val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? = { - val newSession = sessionManager.updateToken(AccessTokenApiV0(client), oldTokens!!.accessToken, oldTokens!!.refreshToken) - newSession?.let { + val newSession = sessionManager.updateToken( + accessTokenApi = AccessTokenApiV0(client), + oldAccessToken = oldTokens!!.accessToken, + oldRefreshToken = oldTokens!!.refreshToken + ) + newSession.let { BearerTokens(accessToken = it.accessToken, refreshToken = it.refreshToken) } } @@ -76,7 +85,7 @@ class SessionManagerTest { var callCount = 0 var didFail = false - val mockEngine = MockEngine() { + val mockEngine = MockEngine { callCount++ // Fail only the first time, so the test can // proceed when sessionManager is called again @@ -104,20 +113,24 @@ class SessionManagerTest { val sessionManager = createFakeSessionManager() val loadToken: suspend () -> BearerTokens? = { - val session = sessionManager.session() ?: error("missing user session") + val session = sessionManager.session() BearerTokens(accessToken = session.accessToken, refreshToken = session.refreshToken) } val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? = { - val newSession = sessionManager.updateToken(AccessTokenApiV0(client), oldTokens!!.accessToken, oldTokens!!.refreshToken) - newSession?.let { + val newSession = sessionManager.updateToken( + accessTokenApi = AccessTokenApiV0(client), + oldAccessToken = oldTokens!!.accessToken, + oldRefreshToken = oldTokens!!.refreshToken + ) + newSession.let { BearerTokens(accessToken = it.accessToken, refreshToken = it.refreshToken) } } val bearerAuthProvider = BearerAuthProvider(refreshToken, loadToken, { true }, null) - val mockEngine = MockEngine() { + val mockEngine = MockEngine { respondOk() } @@ -139,13 +152,17 @@ class SessionManagerTest { val sessionManager = createFakeSessionManager() val loadToken: suspend () -> BearerTokens? = { - val session = sessionManager.session() ?: error("missing user session") + val session = sessionManager.session() BearerTokens(accessToken = session.accessToken, refreshToken = session.refreshToken) } val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? = { - val newSession = sessionManager.updateToken(AccessTokenApiV0(client), oldTokens!!.accessToken, oldTokens!!.refreshToken) - newSession?.let { + val newSession = sessionManager.updateToken( + AccessTokenApiV0(client), + oldTokens!!.accessToken, + oldTokens!!.refreshToken + ) + newSession.let { BearerTokens(accessToken = it.accessToken, refreshToken = it.refreshToken) } } @@ -154,7 +171,7 @@ class SessionManagerTest { var callCount = 0 var didFail = false - val mockEngine = MockEngine() { + val mockEngine = MockEngine { callCount++ // Fail only the first time, so the test can // proceed when sessionManager is called again @@ -169,7 +186,12 @@ class SessionManagerTest { } val client = AuthenticatedNetworkClient( - DEFAULT_TEST_NETWORK_STATE_OBSERVER, mockEngine, sessionManager.serverConfig(), bearerAuthProvider, kaliumLogger, false + DEFAULT_TEST_NETWORK_STATE_OBSERVER, + mockEngine, + sessionManager.serverConfig(), + bearerAuthProvider, + kaliumLogger, + false ) val assetApi = AssetApiV0(client) val kaliumFileSystem: FileSystem = FakeFileSystem() @@ -180,6 +202,50 @@ class SessionManagerTest { assertEquals(2, callCount) } + @Test + fun givenRefreshTokenThrows_whenServerSignalTokenRefreshIsNeeded_thenShouldReturnFailure() = runTest { + KaliumUserAgentProvider.setUserAgent("KaliumTest") + val sessionManager = createFakeSessionManager() + + val loadToken: suspend () -> BearerTokens? = { + val session = sessionManager.session() + BearerTokens(accessToken = session.accessToken, refreshToken = session.refreshToken) + } + + val expectedCause = Exception("Refresh token failed") + var isThrowing = false + val refreshToken: suspend RefreshTokensParams.() -> BearerTokens? = { + isThrowing = true + throw expectedCause + } + + val bearerAuthProvider = BearerAuthProvider(refreshToken, loadToken, { true }, null) + + val mockEngine = MockEngine { + respondError(status = HttpStatusCode.Unauthorized) + } + + val client = AuthenticatedNetworkClient( + DEFAULT_TEST_NETWORK_STATE_OBSERVER, + mockEngine, + sessionManager.serverConfig(), + bearerAuthProvider, + kaliumLogger, + false + ) + val assetApi = AssetApiV0(client) + val kaliumFileSystem: FileSystem = FakeFileSystem() + val tempPath = "some-dummy-path".toPath() + val tempOutputSink = kaliumFileSystem.sink(tempPath) + + val result = assetApi.downloadAsset("asset_id", "asset_domain", null, tempFileSink = tempOutputSink) + assertIs(result) + val exception = result.kException + assertIs(exception) + assertEquals(expectedCause.message, exception.cause.message) + assertTrue(isThrowing) + } + private companion object { const val UPDATED_ACCESS_TOKEN = "new access token" } @@ -191,10 +257,7 @@ class SessionManagerTest { accessTokenApi: AccessTokenApi, oldAccessToken: String, oldRefreshToken: String - ): SessionDTO? = testCredentials.copy(accessToken = UPDATED_ACCESS_TOKEN) - - override suspend fun updateLoginSession(newAccessTokeDTO: AccessTokenDTO, newRefreshTokenDTO: RefreshTokenDTO?): SessionDTO? = - testCredentials + ): SessionDTO = testCredentials.copy(accessToken = UPDATED_ACCESS_TOKEN) override fun proxyCredentials(): ProxyCredentialsDTO? = ProxyCredentialsDTO("username", "password") } diff --git a/persistence/src/androidMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt b/persistence/src/androidMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt index a63a3fa6be1..428f7f91b82 100644 --- a/persistence/src/androidMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt +++ b/persistence/src/androidMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt @@ -20,6 +20,7 @@ package com.wire.kalium.persistence.kmmSettings import android.content.Context import com.wire.kalium.persistence.client.AuthTokenStorage +import com.wire.kalium.persistence.client.AuthTokenStorageImpl import com.wire.kalium.persistence.client.TokenStorage import com.wire.kalium.persistence.client.TokenStorageImpl import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage @@ -32,7 +33,7 @@ actual class GlobalPrefProvider(context: Context, shouldEncryptData: Boolean = t ) actual val authTokenStorage: AuthTokenStorage - get() = AuthTokenStorage(encryptedSettingsHolder) + get() = AuthTokenStorageImpl(encryptedSettingsHolder) actual val passphraseStorage: PassphraseStorage get() = PassphraseStorageImpl(encryptedSettingsHolder) actual val tokenStorage: TokenStorage diff --git a/persistence/src/appleMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt b/persistence/src/appleMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt index 88d765b8921..023a88beae5 100644 --- a/persistence/src/appleMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt +++ b/persistence/src/appleMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt @@ -19,6 +19,7 @@ package com.wire.kalium.persistence.kmmSettings import com.wire.kalium.persistence.client.AuthTokenStorage +import com.wire.kalium.persistence.client.AuthTokenStorageImpl import com.wire.kalium.persistence.client.TokenStorage import com.wire.kalium.persistence.client.TokenStorageImpl import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage @@ -37,7 +38,7 @@ actual class GlobalPrefProvider( ) actual val authTokenStorage: AuthTokenStorage - get() = AuthTokenStorage(kaliumPref) + get() = AuthTokenStorageImpl(kaliumPref) actual val passphraseStorage: PassphraseStorage get() = PassphraseStorageImpl(kaliumPref) actual val tokenStorage: TokenStorage diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/client/AuthTokenStorage.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/client/AuthTokenStorage.kt index cd4eb6fdce7..ebebbe1c946 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/client/AuthTokenStorage.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/client/AuthTokenStorage.kt @@ -38,10 +38,24 @@ data class ProxyCredentialsEntity( @SerialName("password") val password: String, ) -class AuthTokenStorage internal constructor( +interface AuthTokenStorage { + fun addOrReplace(authTokenEntity: AuthTokenEntity, proxyCredentialsEntity: ProxyCredentialsEntity?) + fun updateToken( + userId: UserIDEntity, + accessToken: String, + tokenType: String, + refreshToken: String?, + ): AuthTokenEntity + + fun getToken(userId: UserIDEntity): AuthTokenEntity? + fun deleteToken(userId: UserIDEntity) + fun proxyCredentials(userId: UserIDEntity): ProxyCredentialsEntity? +} + +internal class AuthTokenStorageImpl( private val kaliumPreferences: KaliumPreferences -) { - fun addOrReplace(authTokenEntity: AuthTokenEntity, proxyCredentialsEntity: ProxyCredentialsEntity?) { +) : AuthTokenStorage { + override fun addOrReplace(authTokenEntity: AuthTokenEntity, proxyCredentialsEntity: ProxyCredentialsEntity?) { kaliumPreferences.putSerializable( tokenKey(authTokenEntity.userId), authTokenEntity, @@ -53,7 +67,7 @@ class AuthTokenStorage internal constructor( } } - fun updateToken( + override fun updateToken( userId: UserIDEntity, accessToken: String, tokenType: String, @@ -74,18 +88,18 @@ class AuthTokenStorage internal constructor( } // TODO: make suspendable - fun getToken(userId: UserIDEntity): AuthTokenEntity? = + override fun getToken(userId: UserIDEntity): AuthTokenEntity? = kaliumPreferences.getSerializable( tokenKey(userId), AuthTokenEntity.serializer() ) - fun deleteToken(userId: UserIDEntity) { + override fun deleteToken(userId: UserIDEntity) { kaliumPreferences.remove(tokenKey(userId)) kaliumPreferences.remove(proxyCredentialsKey(userId)) } - fun proxyCredentials(userId: UserIDEntity): ProxyCredentialsEntity? = + override fun proxyCredentials(userId: UserIDEntity): ProxyCredentialsEntity? = kaliumPreferences.getSerializable(proxyCredentialsKey(userId), ProxyCredentialsEntity.serializer()) private fun tokenKey(userId: UserIDEntity): String { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/client/AuthTokenStorageTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/client/AuthTokenStorageTest.kt index ffe1b699eaf..b3bde143db5 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/client/AuthTokenStorageTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/client/AuthTokenStorageTest.kt @@ -37,7 +37,7 @@ class AuthTokenStorageTest { @BeforeTest fun setup() { mockSettings.clear() - authTokenStorage = AuthTokenStorage(kaliumPreferences) + authTokenStorage = AuthTokenStorageImpl(kaliumPreferences) } @Test @@ -68,7 +68,6 @@ class AuthTokenStorageTest { assertFails { authTokenStorage.updateToken(expected.userId, "access_token", "token_type", "refresh_token") - } } diff --git a/persistence/src/jvmMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt b/persistence/src/jvmMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt index 88d765b8921..023a88beae5 100644 --- a/persistence/src/jvmMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt +++ b/persistence/src/jvmMain/kotlin/com/wire/kalium/persistence/kmmSettings/GlobalPrefProvider.kt @@ -19,6 +19,7 @@ package com.wire.kalium.persistence.kmmSettings import com.wire.kalium.persistence.client.AuthTokenStorage +import com.wire.kalium.persistence.client.AuthTokenStorageImpl import com.wire.kalium.persistence.client.TokenStorage import com.wire.kalium.persistence.client.TokenStorageImpl import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage @@ -37,7 +38,7 @@ actual class GlobalPrefProvider( ) actual val authTokenStorage: AuthTokenStorage - get() = AuthTokenStorage(kaliumPref) + get() = AuthTokenStorageImpl(kaliumPref) actual val passphraseStorage: PassphraseStorage get() = PassphraseStorageImpl(kaliumPref) actual val tokenStorage: TokenStorage diff --git a/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt b/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt index f7073f8443e..1bffb799c83 100644 --- a/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt +++ b/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt @@ -19,7 +19,7 @@ package action import com.wire.kalium.logic.CoreLogic import com.wire.kalium.logic.feature.auth.AddAuthenticatedUserUseCase -import com.wire.kalium.logic.feature.auth.AuthTokens +import com.wire.kalium.logic.feature.auth.AccountTokens import com.wire.kalium.logic.feature.auth.AuthenticationResult import com.wire.kalium.logic.feature.auth.AuthenticationScope import com.wire.kalium.network.api.base.model.SelfUserDTO @@ -37,7 +37,7 @@ object LoginActions { password: String, coreLogic: CoreLogic, authScope: AuthenticationScope, - ): AuthTokens { + ): AccountTokens { val loginResult = authScope.login(email, password, true) if (loginResult !is AuthenticationResult.Success) { error("User creds didn't work ($email, $password)")