Skip to content

Commit

Permalink
Add support for OAuth login flow using authorization_grant with PKCE (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rli authored Apr 12, 2024
1 parent 937546d commit cf855e6
Show file tree
Hide file tree
Showing 29 changed files with 2,342 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,17 @@ dependencies {
}

configurations {
all {
// IDE provides netty
exclude("io.netty")
}

// Make sure we exclude stuff we either A) ships with IDE, B) we don't use to cut down on size
runtimeClasspath {
exclude(group = "org.slf4j")
exclude(group = "org.jetbrains.kotlin")
exclude(group = "org.jetbrains.kotlinx")

}
}

Expand Down
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/toolkit-generate-sdks.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ sourceSets {
}

java {
setSrcDirs(listOf(sdkGenerator.srcDir()))
setSrcDirs(listOf(sdkGenerator.srcDir(), "src"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ configurations {
}

all {
// IDE provides netty
exclude("io.netty")

if (name.startsWith("detekt")) {
return@all
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessToken
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonq.clients.AmazonQStreamingClient
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
Expand All @@ -47,7 +47,7 @@ open class AmazonQTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = AccessToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())

val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessToken
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.GenerateTaskAssistPlanResult
Expand Down Expand Up @@ -91,7 +91,7 @@ open class FeatureDevTestBase(
open fun setup() {
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))
val accessToken = AccessToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessToken
import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.codemodernizer.client.GumbyClient
import software.aws.toolkits.jetbrains.services.codemodernizer.model.CodeModernizerArtifact
Expand Down Expand Up @@ -218,7 +218,7 @@ open class CodeWhispererCodeModernizerTestBase(
project = projectRule.project
toolkitConnectionManager = spy(ToolkitConnectionManager.getInstance(project))

val accessToken = AccessToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val accessToken = DeviceAuthorizationGrantToken(aString(), aString(), aString(), aString(), Instant.MAX, Instant.now())
val provider = mock<BearerTokenProvider> {
doReturn(accessToken).whenever(it).refresh()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
</extensionPoints>

<extensions defaultExtensionNs="com.intellij">
<httpRequestHandler implementation="software.aws.toolkits.jetbrains.core.credentials.sso.pkce.ToolkitOAuthCallbackHandler"/>

<applicationService serviceInterface="software.aws.toolkits.jetbrains.settings.AwsSettings"
serviceImplementation="software.aws.toolkits.jetbrains.settings.DefaultAwsSettings"
testServiceImplementation="software.aws.toolkits.jetbrains.settings.MockAwsSettings"/>
Expand Down Expand Up @@ -52,6 +54,8 @@
<projectService serviceInterface="software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager"
serviceImplementation="software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager"/>

<registryKey key="aws.dev.pkceAuth" description="True if new authorization requests should be using the PKCE grant flow"
defaultValue="false" restartRequired="false"/>
<registryKey key="aws.telemetry.endpoint" description="Endpoint to use for publishing AWS client-side telemetry"
defaultValue="https://client-telemetry.us-east-1.amazonaws.com" restartRequired="true"/>
<registryKey key="aws.telemetry.identityPool" description="Cognito identity pool to use for publishing AWS client-side telemetry"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

package software.aws.toolkits.jetbrains.core.credentials.sso

import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonSubTypes
import com.fasterxml.jackson.annotation.JsonTypeInfo
import com.intellij.collaboration.auth.credentials.Credentials
import software.amazon.awssdk.auth.token.credentials.SdkToken
import software.amazon.awssdk.services.sso.SsoClient
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
Expand All @@ -15,29 +19,76 @@ import java.util.Optional
/**
* Access token returned from [SsoOidcClient.createToken] used to retrieve AWS Credentials from [SsoClient.getRoleCredentials].
*/
data class AccessToken(
val startUrl: String,
val region: String,
@JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION)
@JsonSubTypes(value = [JsonSubTypes.Type(DeviceAuthorizationGrantToken::class), JsonSubTypes.Type(PKCEAuthorizationGrantToken::class) ])
sealed interface AccessToken : SdkToken, Credentials {
val region: String

@SensitiveField
val accessToken: String,
override val accessToken: String

@SensitiveField
@JsonInclude(JsonInclude.Include.NON_NULL)
val refreshToken: String? = null,
val expiresAt: Instant,
val createdAt: Instant = Instant.EPOCH
) : SdkToken {
@get:JsonInclude(JsonInclude.Include.NON_NULL)
val refreshToken: String?

val expiresAt: Instant
val createdAt: Instant

override fun token() = accessToken

override fun expirationTime() = Optional.of(expiresAt)

@get:JsonIgnore
val ssoUrl: String
}

data class DeviceAuthorizationGrantToken(
val startUrl: String,
override val region: String,
override val accessToken: String,
override val refreshToken: String? = null,
override val expiresAt: Instant,
override val createdAt: Instant = Instant.EPOCH
) : AccessToken {
override val ssoUrl: String
get() = startUrl

override fun toString() = redactedString(this)
}

data class PKCEAuthorizationGrantToken(
val issuerUrl: String,
override val region: String,
override val accessToken: String,
override val refreshToken: String,
override val expiresAt: Instant,
override val createdAt: Instant
) : AccessToken {
override val ssoUrl: String
get() = issuerUrl

override fun toString() = redactedString(this)
}

// we really don't need to differentitate since they refresh the same way, but to save some mental cycles,
// treat them as independent so we don't need to worry about intermingling the token/registration combos
@JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION)
@JsonSubTypes(value = [JsonSubTypes.Type(DeviceGrantAccessTokenCacheKey::class), JsonSubTypes.Type(PKCEAccessTokenCacheKey::class) ])
sealed interface AccessTokenCacheKey {
val scopes: List<String>
}

// diverging from SDK/CLI impl here since they do: sha1sum(sessionName ?: startUrl)
// which isn't good enough for us
// only used in scoped case
data class AccessTokenCacheKey(
data class DeviceGrantAccessTokenCacheKey(
val connectionId: String,
val startUrl: String,
val scopes: List<String>
)
override val scopes: List<String>
) : AccessTokenCacheKey

data class PKCEAccessTokenCacheKey(
val issuerUrl: String,
val region: String,
override val scopes: List<String>
) : AccessTokenCacheKey
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import java.time.Instant
/**
* Returned by [SsoOidcClient.startDeviceAuthorization] that contains the required data to construct the user visible SSO login flow.
*/
@Deprecated("Device authorization grant flow is deprecated")
data class Authorization(
@SensitiveField
val deviceCode: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package software.aws.toolkits.jetbrains.core.credentials.sso

import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonSubTypes
import com.fasterxml.jackson.annotation.JsonTypeInfo
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.core.utils.SensitiveField
import software.aws.toolkits.core.utils.redactedString
Expand All @@ -14,22 +16,60 @@ import java.time.Instant
*
* It should be persisted for reuse through many authentication requests.
*/
data class ClientRegistration(
@SensitiveField
val clientId: String,
@JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, defaultImpl = DeviceAuthorizationClientRegistration::class)
@JsonSubTypes(value = [JsonSubTypes.Type(DeviceAuthorizationClientRegistration::class), JsonSubTypes.Type(PKCEClientRegistration::class) ])
sealed interface ClientRegistration {
val clientId: String

@SensitiveField
val clientSecret: String,
val expiresAt: Instant,
@JsonInclude(JsonInclude.Include.NON_EMPTY)
val scopes: List<String> = emptyList()
) {
val clientSecret: String

val expiresAt: Instant

@get:JsonInclude(JsonInclude.Include.NON_EMPTY)
val scopes: List<String>
}

data class DeviceAuthorizationClientRegistration(
override val clientId: String,
override val clientSecret: String,
override val expiresAt: Instant,
override val scopes: List<String> = emptyList(),
) : ClientRegistration {
override fun toString(): String = redactedString(this)
}

data class PKCEClientRegistration(
override val clientId: String,
override val clientSecret: String,
override val expiresAt: Instant,
override val scopes: List<String>,
// fields below are implied from the key, but trying reverse the key is annoying
val issuerUrl: String,
val region: String,
val clientType: String,
val grantTypes: List<String>,
val redirectUris: List<String>,
) : ClientRegistration {
override fun toString(): String = redactedString(this)
}

sealed interface ClientRegistrationCacheKey

// only applicable in scoped registration path
// based on internal development branch @da780a4,L2574-2586
data class ClientRegistrationCacheKey(
data class DeviceAuthorizationClientRegistrationCacheKey(
val startUrl: String,
val scopes: List<String>,
val region: String,
)
) : ClientRegistrationCacheKey

data class PKCEClientRegistrationCacheKey(
val issuerUrl: String,
val region: String,
val scopes: List<String>,
// assume clientType, grantTypes, redirectUris are static, but throw them in just in case
val clientType: String,
val grantTypes: List<String>,
val redirectUris: List<String>
) : ClientRegistrationCacheKey
Loading

0 comments on commit cf855e6

Please sign in to comment.