From aab9f85dfc15b5dc70173bdb9ae825be018fda90 Mon Sep 17 00:00:00 2001 From: Nathaniel Fischer Date: Wed, 13 May 2020 16:29:56 -0700 Subject: [PATCH] Change variances so that JwtValidator[Any] will work (#31) on any unregistered claims. --- .gitignore | 5 +- build.sc | 2 +- jose/src/black/door/jose/jws/Jws.scala | 2 +- jose/src/black/door/jose/jwt/Claims.scala | 2 +- jose/src/black/door/jose/jwt/Jwt.scala | 81 +++++++++++-------- .../black/door/jose/jwt/JwtValidator.scala | 5 +- jose/src/black/door/jose/jwt/package.scala | 3 +- jose/src/black/door/jose/package.scala | 2 +- 8 files changed, 60 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index 2a6c276..b14262d 100644 --- a/.gitignore +++ b/.gitignore @@ -96,4 +96,7 @@ fabric.properties # Android studio 3.1+ serialized cache file .idea/caches/build_file_checksums.ser -.idea \ No newline at end of file +.idea + +.bloop +.metals \ No newline at end of file diff --git a/build.sc b/build.sc index 8ab6ece..240983c 100644 --- a/build.sc +++ b/build.sc @@ -6,7 +6,7 @@ import scalalib._ val devInfo = Developer("kag0", "Nathan Fischer", "https://github.com/kag0", Some("blackdoor"), Some("https://github.com/blackdoor")) val `2.12` = "2.12.10" -val `2.13` = "2.13.0" +val `2.13` = "2.13.2" trait BaseModule extends CrossScalaModule { def scalacOptions = Seq("-Xfatal-warnings", "-feature", "-unchecked", "-deprecation") diff --git a/jose/src/black/door/jose/jws/Jws.scala b/jose/src/black/door/jose/jws/Jws.scala index 93e3ce0..b74eb06 100644 --- a/jose/src/black/door/jose/jws/Jws.scala +++ b/jose/src/black/door/jose/jws/Jws.scala @@ -13,7 +13,7 @@ import scala.collection.immutable.Seq import scala.concurrent.{ExecutionContext, Future} import scala.util.Try -trait Jws[A] { +trait Jws[+A] { def header: JwsHeader def payload: A diff --git a/jose/src/black/door/jose/jwt/Claims.scala b/jose/src/black/door/jose/jwt/Claims.scala index 0f88a93..1624d60 100644 --- a/jose/src/black/door/jose/jwt/Claims.scala +++ b/jose/src/black/door/jose/jwt/Claims.scala @@ -2,7 +2,7 @@ package black.door.jose.jwt import java.time.Instant -case class Claims[UnregisteredClaims]( +case class Claims[+UnregisteredClaims]( iss: Option[String] = None, sub: Option[String] = None, aud: Option[String] = None, diff --git a/jose/src/black/door/jose/jwt/Jwt.scala b/jose/src/black/door/jose/jwt/Jwt.scala index c14b3b9..0dc8211 100644 --- a/jose/src/black/door/jose/jwt/Jwt.scala +++ b/jose/src/black/door/jose/jwt/Jwt.scala @@ -15,15 +15,27 @@ import scala.collection.immutable.Seq import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} -case class Jwt[UnregisteredClaims](header: JwsHeader, claims: Claims[UnregisteredClaims]) extends Jws[Claims[UnregisteredClaims]] { +case class Jwt[+UnregisteredClaims](header: JwsHeader, claims: Claims[UnregisteredClaims]) + extends Jws[Claims[UnregisteredClaims]] { def payload = claims } object Jwt { + @throws[KeyException] - def sign[PC](claims: Claims[PC], key: Jwk, algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all) - (implicit headerSerializer: Mapper[JwsHeader, Array[Byte]], payloadSerializer: Mapper[Claims[PC], Array[Byte]]) = { - val alg = key.alg.getOrElse(throw new KeyException("Jwk must have a defined alg to use Jwt.sign. Alternatively, create a Jwt with an explicit JwsHeader.")) + def sign[PC]( + claims: Claims[PC], + key: Jwk, + algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all + )( + implicit headerSerializer: Mapper[JwsHeader, Array[Byte]], + payloadSerializer: Mapper[Claims[PC], Array[Byte]] + ) = { + val alg = key.alg.getOrElse( + throw new KeyException( + "Jwk must have a defined alg to use Jwt.sign. Alternatively, create a Jwt with an explicit JwsHeader." + ) + ) Jwt(JwsHeader(alg, typ = Some("JWT"), kid = key.kid), claims).sign(key, algorithms) } @@ -42,45 +54,48 @@ object Jwt { * @return */ def validate[C]( - compact: String, - keyResolver: KeyResolver[Claims[C]], - jwtValidator: JwtValidator[C] = JwtValidator.empty, - fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(), - algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all - ) - ( - implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]], - headerDeserializer: Mapper[Array[Byte], JwsHeader], - ec: ExecutionContext - ): Future[Either[String, Jwt[C]]] = { - EitherT(Jws.validate[Claims[C]](compact, keyResolver, algorithms)) - .flatMap { jws => - val jwt = Jwt(jws.header, jws.payload) - OptionT(jwtValidator.orElse(fallbackJwtValidator).apply(jwt)).toLeft(jwt) - }.value - } + compact: String, + keyResolver: KeyResolver[Claims[C]], + jwtValidator: JwtValidator[C] = JwtValidator.empty, + fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(), + algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all + )( + implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]], + headerDeserializer: Mapper[Array[Byte], JwsHeader], + ec: ExecutionContext + ): Future[Either[String, Jwt[C]]] = + EitherT(Jws.validate[Claims[C]](compact, keyResolver, algorithms)).flatMap { jws => + val jwt = Jwt(jws.header, jws.payload) + OptionT(jwtValidator.orElse(fallbackJwtValidator).apply(jwt)).toLeft(jwt) + }.value - private val sadSpasticLittleEc = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool) + private val sadSpasticLittleEc = + ExecutionContext.fromExecutorService(Executors.newCachedThreadPool) sealed trait TypedValidation[C] { def compact: String case class using( - keyResolver: KeyResolver[Claims[C]], - jwtValidator: JwtValidator[C] = JwtValidator.empty[C], - fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator[C](), - algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all - ) - ( - implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]], - headerDeserializer: Mapper[Array[Byte], JwsHeader] - ) { + keyResolver: KeyResolver[Claims[C]], + jwtValidator: JwtValidator[C] = JwtValidator.empty[C], + fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(), + algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all + )( + implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]], + headerDeserializer: Mapper[Array[Byte], JwsHeader] + ) { + def now = Await.result( - validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)(payloadDeserializer, headerDeserializer, sadSpasticLittleEc), + validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)( + payloadDeserializer, + headerDeserializer, + sadSpasticLittleEc + ), Duration(1, TimeUnit.SECONDS) ) - def async(implicit ec: ExecutionContext) = validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms) + def async(implicit ec: ExecutionContext) = + validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms) } } diff --git a/jose/src/black/door/jose/jwt/JwtValidator.scala b/jose/src/black/door/jose/jwt/JwtValidator.scala index 5ebbe3f..5eec190 100644 --- a/jose/src/black/door/jose/jwt/JwtValidator.scala +++ b/jose/src/black/door/jose/jwt/JwtValidator.scala @@ -3,6 +3,7 @@ package black.door.jose.jwt import java.time.{Clock, Instant} import scala.concurrent.{ExecutionContext, Future} +import scala.language.implicitConversions object JwtValidator { @@ -15,9 +16,9 @@ object JwtValidator { private def iatMessage(maybeIat: Option[Instant]) = maybeIat.map(iat => s"It was issued at $iat.").getOrElse("") - def defaultValidator[C](clock: Clock = Clock.systemDefaultZone): JwtValidator[C] = { + def defaultValidator(clock: Clock = Clock.systemDefaultZone) = { val now = Instant.now(clock) - JwtValidator.fromSync { + JwtValidator.fromSync[Any] { case Jwt(_, claims) if claims.exp.exists(_.isBefore(now)) => s"Token expired at ${claims.exp.get}.${iatMessage(claims.iat)} It is now $now." case Jwt(_, claims) if claims.nbf.exists(_.isAfter(now)) => diff --git a/jose/src/black/door/jose/jwt/package.scala b/jose/src/black/door/jose/jwt/package.scala index ebff51d..cfc44c2 100644 --- a/jose/src/black/door/jose/jwt/package.scala +++ b/jose/src/black/door/jose/jwt/package.scala @@ -3,6 +3,5 @@ package black.door.jose import scala.concurrent.Future package object jwt { - - type JwtValidator[UnregisteredClaims] = Jwt[UnregisteredClaims] => Future[Option[String]] + type JwtValidator[-UnregisteredClaims] = Jwt[UnregisteredClaims] => Future[Option[String]] } diff --git a/jose/src/black/door/jose/package.scala b/jose/src/black/door/jose/package.scala index 19fe294..ca72829 100644 --- a/jose/src/black/door/jose/package.scala +++ b/jose/src/black/door/jose/package.scala @@ -1,5 +1,5 @@ package black.door package object jose { - type Mapper[A, B] = A => Either[String, B] + type Mapper[-A, B] = A => Either[String, B] }