Skip to content

Commit

Permalink
[YS-70] feat: JWT 토큰에 사용자 권한(RoleType) 정보 포함 (#18)
Browse files Browse the repository at this point in the history
* refact: delete unused file

* refact: add roleName to Enum

* feat: add memberRole to jwt token

* test: add role validation test code

* test: move usecase's test code to application package

* style: delete unused import

* refact: refactor exception handling to be more granular for token validation

* style: rename JpaMemberRepository to MemberRepository for better clarity
  • Loading branch information
Ji-soo708 committed Jan 26, 2025
1 parent fd65cf9 commit f33a1df
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.dobby.backend.application.usecase

import com.dobby.backend.domain.gateway.MemberGateway
import com.dobby.backend.domain.gateway.TokenGateway

class GenerateTestToken(
private val tokenGateway: TokenGateway
private val tokenGateway: TokenGateway,
private val memberGateway: MemberGateway,
) : UseCase<GenerateTestToken.Input, GenerateTestToken.Output> {
data class Input(
val memberId: Long
Expand All @@ -16,9 +18,10 @@ class GenerateTestToken(

override fun execute(input: Input): Output {
val memberId = input.memberId
val member = memberGateway.getById(memberId)
return Output(
accessToken = tokenGateway.generateAccessToken(memberId),
refreshToken = tokenGateway.generateRefreshToken(memberId)
accessToken = tokenGateway.generateAccessToken(member),
refreshToken = tokenGateway.generateRefreshToken(member)
)
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.dobby.backend.application.usecase

import com.dobby.backend.domain.gateway.MemberGateway
import com.dobby.backend.domain.gateway.TokenGateway

class GenerateTokenWithRefreshToken(
private val tokenGateway: TokenGateway
private val tokenGateway: TokenGateway,
private val memberGateway: MemberGateway,
) : UseCase<GenerateTokenWithRefreshToken.Input, GenerateTokenWithRefreshToken.Output> {
data class Input(
val refreshToken: String,
Expand All @@ -17,9 +19,10 @@ class GenerateTokenWithRefreshToken(

override fun execute(input: Input): Output {
val memberId = tokenGateway.extractMemberIdFromRefreshToken(input.refreshToken).toLong()
val member = memberGateway.getById(memberId)
return Output(
accessToken = tokenGateway.generateAccessToken(memberId),
refreshToken = tokenGateway.generateRefreshToken(memberId),
accessToken = tokenGateway.generateAccessToken(member),
refreshToken = tokenGateway.generateRefreshToken(member),
memberId = memberId
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class AuthenticationTokenNotFoundException : AuthenticationException(ErrorCode.T
class AuthenticationTokenNotValidException : AuthenticationException(ErrorCode.TOKEN_NOT_VALID)
class AuthenticationTokenExpiredException : AuthenticationException(ErrorCode.TOKEN_EXPIRED)
class InvalidTokenTypeException : AuthenticationException(ErrorCode.INVALID_TOKEN_TYPE)
class InvalidTokenValueException : AuthenticationException(ErrorCode.INVALID_TOKEN_VALUE)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ enum class ErrorCode(
TOKEN_NOT_VALID("AU0002", "Authentication token is not valid.", HttpStatus.UNAUTHORIZED),
TOKEN_EXPIRED("AU0003", "Authentication token has expired.", HttpStatus.UNAUTHORIZED),
INVALID_TOKEN_TYPE("AU0004", "Invalid token type", HttpStatus.UNAUTHORIZED),
INVALID_TOKEN_VALUE("AU0005", "Invalid token value", HttpStatus.UNAUTHORIZED),

/**
* Authorization error codes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.dobby.backend.domain.gateway

import com.dobby.backend.domain.model.Member

interface TokenGateway {
fun generateAccessToken(memberId: Long): String
fun generateRefreshToken(memberId: Long): String
fun generateAccessToken(member: Member): String
fun generateRefreshToken(member: Member): String
fun extractMemberIdFromRefreshToken(token: String): String
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.dobby.backend.infrastructure.database.entity.enum

enum class RoleType {
RESEARCHER, PARTICIPANT
enum class RoleType(
val roleName: String
) {
RESEARCHER("RESEARCHER"),
PARTICIPANT("PARTICIPANT")
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ package com.dobby.backend.infrastructure.gateway
import com.dobby.backend.domain.gateway.MemberGateway
import com.dobby.backend.domain.model.Member
import com.dobby.backend.infrastructure.database.entity.MemberEntity
import com.dobby.backend.infrastructure.database.repository.MemberJpaRepository
import com.dobby.backend.infrastructure.database.repository.MemberRepository
import org.springframework.stereotype.Component

@Component
class MemberGatewayImpl(
private val jpaMemberRepository: MemberJpaRepository,
private val memberRepository: MemberRepository,
) : MemberGateway {
override fun getById(memberId: Long): Member {
return jpaMemberRepository
return memberRepository
.getReferenceById(memberId)
.let(MemberEntity::toDomain)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
package com.dobby.backend.infrastructure.gateway

import com.dobby.backend.domain.gateway.TokenGateway
import com.dobby.backend.domain.model.Member
import com.dobby.backend.infrastructure.token.JwtTokenProvider
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.stereotype.Component

@Component
class TokenGatewayImpl(
private val tokenProvider: JwtTokenProvider,
) : TokenGateway {
override fun generateAccessToken(memberId: Long): String {
override fun generateAccessToken(member: Member): String {
val authorities = listOf(SimpleGrantedAuthority(member.role?.roleName))
val authentication = UsernamePasswordAuthenticationToken(
memberId,
member.memberId,
null,
authorities
)
return tokenProvider.generateAccessToken(authentication)
}

override fun generateRefreshToken(memberId: Long): String {
override fun generateRefreshToken(member: Member): String {
val authentication = UsernamePasswordAuthenticationToken(
memberId,
null,
member.memberId,
null
)
return tokenProvider.generateRefreshToken(authentication)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.jsonwebtoken.Jwts
import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.Authentication
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.stereotype.Component
import java.util.*
import javax.crypto.SecretKey
Expand Down Expand Up @@ -45,12 +46,16 @@ class JwtTokenProvider(
authentication: Authentication,
expirationDate: Date
): String {
val authorities = authentication.authorities.joinToString(",") {
it.authority
}

return Jwts.builder()
.header().add(TOKEN_TYPE_HEADER_KEY, tokenType)
.and()
.claims()
.add(MEMBER_ID_CLAIM_KEY, authentication.name)
.add(AUTHORITIES_CLAIM_KEY, authorities)
.and()
.expiration(expirationDate)
.encryptWith(signKey, Jwts.ENC.A128CBC_HS256)
Expand All @@ -64,7 +69,12 @@ class JwtTokenProvider(
if (tokenType != ACCESS_TOKEN_TYPE_VALUE) throw InvalidTokenTypeException()

val memberId = claims.payload[MEMBER_ID_CLAIM_KEY] as? String ?: throw MemberNotFoundException()
return UsernamePasswordAuthenticationToken(memberId, accessToken, emptyList())
val authorities = claims.payload[AUTHORITIES_CLAIM_KEY]?.toString()
?.split(",")
?.map { SimpleGrantedAuthority(it) }
?: emptyList()

return UsernamePasswordAuthenticationToken(memberId, accessToken, authorities)
} catch (e: ExpiredJwtException) {
throw AuthenticationTokenExpiredException()
} catch (e: JwtException) {
Expand All @@ -73,11 +83,17 @@ class JwtTokenProvider(
}

fun getMemberIdFromRefreshToken(refreshToken: String): String {
val claims = jwtParser.parseEncryptedClaims(refreshToken)
val tokenType = claims.header[TOKEN_TYPE_HEADER_KEY] ?: throw RuntimeException()
if (tokenType != REFRESH_TOKEN_TYPE_VALUE) throw RuntimeException()
return try {
val claims = jwtParser.parseEncryptedClaims(refreshToken)
val tokenType = claims.header[TOKEN_TYPE_HEADER_KEY]
if (tokenType != REFRESH_TOKEN_TYPE_VALUE) {
throw InvalidTokenTypeException()
}

return claims.payload[MEMBER_ID_CLAIM_KEY] as? String ?: throw RuntimeException()
claims.payload[MEMBER_ID_CLAIM_KEY] as? String ?: throw InvalidTokenValueException()
} catch (e: Exception) {
throw InvalidTokenValueException()
}
}

private fun generateAccessTokenExpiration() = Date(System.currentTimeMillis() + tokenProperties.expiration.access * 1000)
Expand All @@ -86,6 +102,7 @@ class JwtTokenProvider(

companion object {
const val MEMBER_ID_CLAIM_KEY = "member_id"
const val AUTHORITIES_CLAIM_KEY = "authorities"
const val TOKEN_TYPE_HEADER_KEY = "token_type"
const val ACCESS_TOKEN_TYPE_VALUE = "access_token"
const val REFRESH_TOKEN_TYPE_VALUE = "refresh_token"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.dobby.backend.application.usecase

import com.dobby.backend.domain.gateway.MemberGateway
import io.kotest.core.spec.style.BehaviorSpec
import com.dobby.backend.domain.gateway.TokenGateway
import com.dobby.backend.domain.model.Member
import com.dobby.backend.infrastructure.database.entity.enum.MemberStatus
import com.dobby.backend.infrastructure.database.entity.enum.ProviderType
import com.dobby.backend.infrastructure.database.entity.enum.RoleType
import io.kotest.matchers.shouldBe
import io.mockk.every
import io.mockk.mockk
import java.time.LocalDate

class GenerateTestTokenTest: BehaviorSpec({
val tokenGateway = mockk<TokenGateway>()
val memberGateway = mockk<MemberGateway>()
val generateTestToken = GenerateTestToken(tokenGateway, memberGateway)

given("memberId가 주어졌을 때") {
val member = Member(memberId = 1, oauthEmail = "[email protected]", contactEmail = "[email protected]",
provider = ProviderType.NAVER, role = RoleType.PARTICIPANT, name = "dobby",
birthDate = LocalDate.of(2000, 7, 8), status = MemberStatus.ACTIVE)
val accessToken = "testAccessToken"
val refreshToken = "testRefreshToken"

every { tokenGateway.generateAccessToken(member) } returns accessToken
every { tokenGateway.generateRefreshToken(member) } returns refreshToken
every { memberGateway.getById(1) } returns member

`when`("execute가 호출되면") {
val input = GenerateTestToken.Input(member.memberId)
val result = generateTestToken.execute(input)

then("생성된 accessToken과 refreshToken이 반환되어야 한다") {
result.accessToken shouldBe accessToken
result.refreshToken shouldBe refreshToken
}
}
}
})
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
package com.dobby.backend.domain.usecase
package com.dobby.backend.application.usecase

import com.dobby.backend.application.usecase.GenerateTokenWithRefreshToken
import com.dobby.backend.domain.gateway.MemberGateway
import com.dobby.backend.domain.gateway.TokenGateway
import com.dobby.backend.domain.model.Member
import com.dobby.backend.infrastructure.database.entity.enum.MemberStatus
import com.dobby.backend.infrastructure.database.entity.enum.ProviderType
import com.dobby.backend.infrastructure.database.entity.enum.RoleType
import io.kotest.core.spec.style.BehaviorSpec
import io.kotest.matchers.shouldBe
import io.mockk.every
import io.mockk.mockk
import java.time.LocalDate

class GenerateTokenWithRefreshTokenTest : BehaviorSpec({
val tokenGateway = mockk<TokenGateway>()
val generateTokenWithRefreshToken = GenerateTokenWithRefreshToken(tokenGateway)
val memberGateway = mockk<MemberGateway>()
val generateTokenWithRefreshToken = GenerateTokenWithRefreshToken(tokenGateway, memberGateway)

given("유효한 리프레시 토큰이 주어졌을 때") {
val validRefreshToken = "validRefreshToken"
val memberId = 123L
val member = Member(memberId = 1, oauthEmail = "[email protected]", contactEmail = "[email protected]",
provider = ProviderType.NAVER, role = RoleType.PARTICIPANT, name = "dobby",
birthDate = LocalDate.of(2000, 7, 8), status = MemberStatus.ACTIVE)
val accessToken = "newAccessToken"
val newRefreshToken = "newRefreshToken"

every { tokenGateway.extractMemberIdFromRefreshToken(validRefreshToken) } returns memberId.toString()
every { tokenGateway.generateAccessToken(memberId) } returns accessToken
every { tokenGateway.generateRefreshToken(memberId) } returns newRefreshToken
every { tokenGateway.extractMemberIdFromRefreshToken(validRefreshToken) } returns member.memberId.toString()
every { tokenGateway.generateAccessToken(member) } returns accessToken
every { tokenGateway.generateRefreshToken(member) } returns newRefreshToken
every { memberGateway.getById(1) } returns member

`when`("execute가 호출되면") {
val input = GenerateTokenWithRefreshToken.Input(refreshToken = validRefreshToken)
Expand All @@ -28,7 +37,7 @@ class GenerateTokenWithRefreshTokenTest : BehaviorSpec({
then("accessToken과 refreshToken이 생성되고, memberId가 포함된다") {
result.accessToken shouldBe accessToken
result.refreshToken shouldBe newRefreshToken
result.memberId shouldBe memberId
result.memberId shouldBe member.memberId
}
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.kotest.matchers.shouldNotBe
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.test.context.ActiveProfiles
import java.time.LocalDate
import kotlin.test.assertFailsWith
Expand All @@ -27,7 +28,8 @@ class JwtTokenProviderTest : BehaviorSpec() {
val member = Member(memberId = 1, oauthEmail = "[email protected]", contactEmail = "[email protected]",
provider = ProviderType.NAVER, role = RoleType.PARTICIPANT, name = "dobby",
birthDate = LocalDate.of(2000, 7, 8), status = MemberStatus.ACTIVE)
val authentication = UsernamePasswordAuthenticationToken(member.memberId, null)
val authorities = listOf(SimpleGrantedAuthority(member.role?.name ?: "PARTICIPANT"))
val authentication = UsernamePasswordAuthenticationToken(member.memberId, null, authorities)

`when`("해당 인증 정보로 JWT 토큰을 생성하면") {
val jwtToken = jwtTokenProvider.generateAccessToken(authentication)
Expand All @@ -42,17 +44,24 @@ class JwtTokenProviderTest : BehaviorSpec() {
val member = Member(memberId = 1, oauthEmail = "[email protected]", contactEmail = "[email protected]",
provider = ProviderType.NAVER, role = RoleType.PARTICIPANT, name = "dobby",
birthDate = LocalDate.of(2000, 7, 8), status = MemberStatus.ACTIVE)
val authentication = UsernamePasswordAuthenticationToken(member.memberId, null)
val authorities = listOf(SimpleGrantedAuthority(member.role?.name ?: "PARTICIPANT"))
val authentication = UsernamePasswordAuthenticationToken(member.memberId, null, authorities)
val validToken = jwtTokenProvider.generateAccessToken(authentication)

`when`("해당 토큰을 파싱하면") {
val parsedAuthentication = jwtTokenProvider.parseAuthentication(validToken)
val extractedMemberId = parsedAuthentication.principal
val extractedAuthorities = parsedAuthentication.authorities

then("파싱된 멤버의 ID는 원래 멤버의 ID와 같아야 한다") {
extractedMemberId shouldNotBe null
extractedMemberId shouldBe member.memberId.toString()
}

then("파싱된 권한(role)은 원래 멤버의 역할과 같아야 한다") {
extractedAuthorities.size shouldBe 1
extractedAuthorities.first().authority shouldBe member.role?.name
}
}
}

Expand Down

0 comments on commit f33a1df

Please sign in to comment.