Skip to content

Commit

Permalink
Add JWT retryable exception type (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
artem-v authored Feb 5, 2025
1 parent c61af09 commit 2aa59d8
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package io.scalecube.security.environment;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;

public class IntegrationEnvironmentFixture
implements BeforeAllCallback, ExtensionContext.Store.CloseableResource, ParameterResolver {

private static final Map<Class<?>, Supplier<?>> PARAMETERS_TO_RESOLVE = new HashMap<>();

private static VaultEnvironment vaultEnvironment;

@Override
public void beforeAll(ExtensionContext context) {
context
.getRoot()
.getStore(Namespace.GLOBAL)
.getOrComputeIfAbsent(
this.getClass(),
key -> {
vaultEnvironment = VaultEnvironment.start();
return this;
});

PARAMETERS_TO_RESOLVE.put(VaultEnvironment.class, () -> vaultEnvironment);
}

@Override
public void close() {
if (vaultEnvironment != null) {
vaultEnvironment.close();
}
}

@Override
public boolean supportsParameter(
ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
Class<?> type = parameterContext.getParameter().getType();
return PARAMETERS_TO_RESOLVE.keySet().stream().anyMatch(type::isAssignableFrom);
}

@Override
public Object resolveParameter(
ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
Class<?> type = parameterContext.getParameter().getType();
return PARAMETERS_TO_RESOLVE.get(type).get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ public static Throwable getRootCause(Throwable throwable) {
return throwable;
}

public String newServiceToken() {
String keyName = createIdentityKey(); // oidc/key
String roleName = createIdentityRole(keyName); // oidc/role
String clientToken = login(); // onboard entity with policy
return generateIdentityToken(clientToken, roleName);
}

@Override
public void close() {
vault.stop();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package io.scalecube.security.tokens.jwt;

import static io.scalecube.security.environment.VaultEnvironment.getRootCause;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
Expand All @@ -9,34 +12,20 @@
import static org.mockito.Mockito.when;

import io.jsonwebtoken.Locator;
import io.scalecube.security.environment.IntegrationEnvironmentFixture;
import io.scalecube.security.environment.VaultEnvironment;
import java.security.Key;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

@ExtendWith(IntegrationEnvironmentFixture.class)
public class JsonwebtokenResolverTests {

private static VaultEnvironment vaultEnvironment;

@BeforeAll
static void beforeAll() {
vaultEnvironment = VaultEnvironment.start();
}

@AfterAll
static void afterAll() {
if (vaultEnvironment != null) {
vaultEnvironment.close();
}
}

@Test
void testResolveTokenSuccessfully() throws Exception {
final var token = generateToken();
void testResolveTokenSuccessfully(VaultEnvironment vaultEnvironment) throws Exception {
final var token = vaultEnvironment.newServiceToken();

final var jwtToken =
new JsonwebtokenResolver(
Expand All @@ -50,13 +39,13 @@ void testResolveTokenSuccessfully() throws Exception {
.get(3, TimeUnit.SECONDS);

assertNotNull(jwtToken, "jwtToken");
Assertions.assertTrue(jwtToken.header().size() > 0, "jwtToken.header: " + jwtToken.header());
Assertions.assertTrue(jwtToken.payload().size() > 0, "jwtToken.payload: " + jwtToken.payload());
assertTrue(jwtToken.header().size() > 0, "jwtToken.header: " + jwtToken.header());
assertTrue(jwtToken.payload().size() > 0, "jwtToken.payload: " + jwtToken.payload());
}

@Test
void testJwksKeyLocatorThrowsError() {
final var token = generateToken();
void testJwksKeyLocatorThrowsError(VaultEnvironment vaultEnvironment) {
final var token = vaultEnvironment.newServiceToken();

Locator<Key> keyLocator = mock(Locator.class);
when(keyLocator.locate(any())).thenThrow(new RuntimeException("Cannot get key"));
Expand All @@ -66,16 +55,25 @@ void testJwksKeyLocatorThrowsError() {
fail("Expected exception");
} catch (Exception e) {
final var ex = getRootCause(e);
assertNotNull(ex);
assertNotNull(ex.getMessage());
assertTrue(ex.getMessage().startsWith("Cannot get key"), "Exception: " + ex);
assertThat(ex, instanceOf(RuntimeException.class));
assertThat(ex.getMessage(), startsWith("Cannot get key"));
}
}

private static String generateToken() {
String keyName = vaultEnvironment.createIdentityKey(); // oidc/key
String roleName = vaultEnvironment.createIdentityRole(keyName); // oidc/role
String clientToken = vaultEnvironment.login(); // onboard entity with policy
return vaultEnvironment.generateIdentityToken(clientToken, roleName);
@Test
void testJwksKeyLocatorThrowsRetryableError(VaultEnvironment vaultEnvironment) {
final var token = vaultEnvironment.newServiceToken();

Locator<Key> keyLocator = mock(Locator.class);
when(keyLocator.locate(any())).thenThrow(new JwtUnavailableException("JWKS timeout"));

try {
new JsonwebtokenResolver(keyLocator).resolve(token).get(3, TimeUnit.SECONDS);
fail("Expected exception");
} catch (Exception e) {
final var ex = getRootCause(e);
assertThat(ex, instanceOf(JwtUnavailableException.class));
assertThat(ex.getMessage(), startsWith("JWKS timeout"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import static io.scalecube.security.environment.VaultEnvironment.getRootCause;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils.randomAlphabetic;

import io.scalecube.security.environment.IntegrationEnvironmentFixture;
import io.scalecube.security.environment.VaultEnvironment;
import io.scalecube.security.tokens.jwt.JsonwebtokenResolver;
import io.scalecube.security.tokens.jwt.JwksKeyLocator;
Expand All @@ -17,29 +20,15 @@
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

@ExtendWith(IntegrationEnvironmentFixture.class)
public class VaultServiceTokenTests {

private static VaultEnvironment vaultEnvironment;

@BeforeAll
static void beforeAll() {
vaultEnvironment = VaultEnvironment.start();
}

@AfterAll
static void afterAll() {
if (vaultEnvironment != null) {
vaultEnvironment.close();
}
}

@Test
void testGetServiceTokenUsingWrongCredentials() throws Exception {
void testGetServiceTokenUsingWrongCredentials(VaultEnvironment vaultEnvironment)
throws Exception {
final var serviceTokenSupplier =
new VaultServiceTokenSupplier.Builder()
.vaultAddress(vaultEnvironment.vaultAddr())
Expand All @@ -54,14 +43,12 @@ void testGetServiceTokenUsingWrongCredentials() throws Exception {
} catch (ExecutionException e) {
final var ex = getRootCause(e);
assertNotNull(ex);
assertNotNull(ex.getMessage());
assertTrue(
ex.getMessage().contains("Failed to get service token, status=403"), "Exception: " + ex);
assertThat(ex.getMessage(), startsWith("Failed to get service token, status=403"));
}
}

@Test
void testGetNonExistingServiceToken() throws Exception {
void testGetNonExistingServiceToken(VaultEnvironment vaultEnvironment) throws Exception {
final var nonExistingServiceRole = "non-existing-role-" + System.currentTimeMillis();

final var serviceTokenSupplier =
Expand All @@ -78,14 +65,12 @@ void testGetNonExistingServiceToken() throws Exception {
} catch (ExecutionException e) {
final var ex = getRootCause(e);
assertNotNull(ex);
assertNotNull(ex.getMessage());
assertTrue(
ex.getMessage().contains("Failed to get service token, status=400"), "Exception: " + ex);
assertThat(ex.getMessage(), startsWith("Failed to get service token, status=400"));
}
}

@Test
void testGetServiceTokenByWrongServiceRole() throws Exception {
void testGetServiceTokenByWrongServiceRole(VaultEnvironment vaultEnvironment) throws Exception {
final var now = System.currentTimeMillis();
final var serviceRole1 = "role1-" + now;
final var serviceRole2 = "role2-" + now;
Expand Down Expand Up @@ -122,14 +107,12 @@ void testGetServiceTokenByWrongServiceRole() throws Exception {
} catch (ExecutionException e) {
final var ex = getRootCause(e);
assertNotNull(ex);
assertNotNull(ex.getMessage());
assertTrue(
ex.getMessage().contains("Failed to get service token, status=400"), "Exception: " + ex);
assertThat(ex.getMessage(), startsWith("Failed to get service token, status=400"));
}
}

@Test
void testGetServiceTokenSuccessfully() throws Exception {
void testGetServiceTokenSuccessfully(VaultEnvironment vaultEnvironment) throws Exception {
final var now = System.currentTimeMillis();
final var serviceRole = "role-" + now;
final var tags = Map.of("type", "ops", "ns", "develop");
Expand Down Expand Up @@ -164,8 +147,8 @@ void testGetServiceTokenSuccessfully() throws Exception {
.get(3, TimeUnit.SECONDS);

assertNotNull(jwtToken, "jwtToken");
Assertions.assertTrue(jwtToken.header().size() > 0, "jwtToken.header: " + jwtToken.header());
Assertions.assertTrue(jwtToken.payload().size() > 0, "jwtToken.payload: " + jwtToken.payload());
assertTrue(jwtToken.header().size() > 0, "jwtToken.header: " + jwtToken.header());
assertTrue(jwtToken.payload().size() > 0, "jwtToken.payload: " + jwtToken.payload());
}

private static String toQualifiedName(String role, Map<String, String> tags) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandlers;
import java.net.http.HttpTimeoutException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.PublicKey;
Expand Down Expand Up @@ -55,13 +56,11 @@ protected Key locate(JwsHeader header) {
kid -> {
final var key = findKeyById(computeKeyList(), kid);
if (key == null) {
throw new RuntimeException("Cannot find key by kid: " + kid);
throw new JwtUnavailableException("Cannot find key by kid: " + kid);
}
return new CachedKey(key, System.currentTimeMillis() + keyTtl);
})
.key();
} catch (Exception ex) {
throw new JwtTokenException(ex);
} finally {
tryCleanup();
}
Expand All @@ -77,8 +76,13 @@ private JwkInfoList computeKeyList() {
.send(
HttpRequest.newBuilder(jwksUri).GET().timeout(requestTimeout).build(),
BodyHandlers.ofInputStream());
} catch (Exception e) {
throw new RuntimeException("Failed to retrive jwk keys", e);
} catch (HttpTimeoutException e) {
throw new JwtUnavailableException("Failed to retrive jwk keys", e);
} catch (IOException e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}

final var statusCode = httpResponse.statusCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import java.util.StringJoiner;

/**
* Generic exception type for JWT token resolution errors. Used as part {@link JwtTokenResolver}
* mechanism, and responsible to abstract token resolution problems.
*/
public class JwtTokenException extends RuntimeException {

public JwtTokenException(String message) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.scalecube.security.tokens.jwt;

/**
* Special JWT exception type indicating transient error during token resolution. For example such
* transient errors are:
*
* <ul>
* <li>Key Rotation: JWKS endpoints often implement key rotation policies where keys are
* periodically changed for security reasons. If the JWT was issued with a "kid" that
* corresponds to a key that has since been rotated out, that key won't be available in the
* JWKS anymore.
* <li>Network or Server Issues: if the JWKS URI is temporarily down, inaccessible, or
* experiencing issues, cleint might not be able to retrieve the keys, or the list of keys
* might be incomplete or outdated.
* </ul>
*/
public class JwtUnavailableException extends JwtTokenException {

public JwtUnavailableException(String message) {
super(message);
}

public JwtUnavailableException(String message, Throwable cause) {
super(message, cause);
}
}

0 comments on commit 2aa59d8

Please sign in to comment.