Skip to content

Commit

Permalink
Change variances so that JwtValidator[Any] will work (#31)
Browse files Browse the repository at this point in the history
on any unregistered claims.
  • Loading branch information
nrktkt authored May 13, 2020
1 parent faf70d9 commit aab9f85
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 42 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,7 @@ fabric.properties
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

.idea
.idea

.bloop
.metals
2 changes: 1 addition & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scalalib._
val devInfo = Developer("kag0", "Nathan Fischer", "https://github.com/kag0", Some("blackdoor"), Some("https://github.com/blackdoor"))

val `2.12` = "2.12.10"
val `2.13` = "2.13.0"
val `2.13` = "2.13.2"

trait BaseModule extends CrossScalaModule {
def scalacOptions = Seq("-Xfatal-warnings", "-feature", "-unchecked", "-deprecation")
Expand Down
2 changes: 1 addition & 1 deletion jose/src/black/door/jose/jws/Jws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import scala.collection.immutable.Seq
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Try

trait Jws[A] {
trait Jws[+A] {
def header: JwsHeader
def payload: A

Expand Down
2 changes: 1 addition & 1 deletion jose/src/black/door/jose/jwt/Claims.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package black.door.jose.jwt

import java.time.Instant

case class Claims[UnregisteredClaims](
case class Claims[+UnregisteredClaims](
iss: Option[String] = None,
sub: Option[String] = None,
aud: Option[String] = None,
Expand Down
81 changes: 48 additions & 33 deletions jose/src/black/door/jose/jwt/Jwt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,27 @@ import scala.collection.immutable.Seq
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

case class Jwt[UnregisteredClaims](header: JwsHeader, claims: Claims[UnregisteredClaims]) extends Jws[Claims[UnregisteredClaims]] {
case class Jwt[+UnregisteredClaims](header: JwsHeader, claims: Claims[UnregisteredClaims])
extends Jws[Claims[UnregisteredClaims]] {
def payload = claims
}

object Jwt {

@throws[KeyException]
def sign[PC](claims: Claims[PC], key: Jwk, algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all)
(implicit headerSerializer: Mapper[JwsHeader, Array[Byte]], payloadSerializer: Mapper[Claims[PC], Array[Byte]]) = {
val alg = key.alg.getOrElse(throw new KeyException("Jwk must have a defined alg to use Jwt.sign. Alternatively, create a Jwt with an explicit JwsHeader."))
def sign[PC](
claims: Claims[PC],
key: Jwk,
algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all
)(
implicit headerSerializer: Mapper[JwsHeader, Array[Byte]],
payloadSerializer: Mapper[Claims[PC], Array[Byte]]
) = {
val alg = key.alg.getOrElse(
throw new KeyException(
"Jwk must have a defined alg to use Jwt.sign. Alternatively, create a Jwt with an explicit JwsHeader."
)
)
Jwt(JwsHeader(alg, typ = Some("JWT"), kid = key.kid), claims).sign(key, algorithms)
}

Expand All @@ -42,45 +54,48 @@ object Jwt {
* @return
*/
def validate[C](
compact: String,
keyResolver: KeyResolver[Claims[C]],
jwtValidator: JwtValidator[C] = JwtValidator.empty,
fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(),
algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all
)
(
implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]],
headerDeserializer: Mapper[Array[Byte], JwsHeader],
ec: ExecutionContext
): Future[Either[String, Jwt[C]]] = {
EitherT(Jws.validate[Claims[C]](compact, keyResolver, algorithms))
.flatMap { jws =>
val jwt = Jwt(jws.header, jws.payload)
OptionT(jwtValidator.orElse(fallbackJwtValidator).apply(jwt)).toLeft(jwt)
}.value
}
compact: String,
keyResolver: KeyResolver[Claims[C]],
jwtValidator: JwtValidator[C] = JwtValidator.empty,
fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(),
algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all
)(
implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]],
headerDeserializer: Mapper[Array[Byte], JwsHeader],
ec: ExecutionContext
): Future[Either[String, Jwt[C]]] =
EitherT(Jws.validate[Claims[C]](compact, keyResolver, algorithms)).flatMap { jws =>
val jwt = Jwt(jws.header, jws.payload)
OptionT(jwtValidator.orElse(fallbackJwtValidator).apply(jwt)).toLeft(jwt)
}.value

private val sadSpasticLittleEc = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool)
private val sadSpasticLittleEc =
ExecutionContext.fromExecutorService(Executors.newCachedThreadPool)

sealed trait TypedValidation[C] {
def compact: String

case class using(
keyResolver: KeyResolver[Claims[C]],
jwtValidator: JwtValidator[C] = JwtValidator.empty[C],
fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator[C](),
algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all
)
(
implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]],
headerDeserializer: Mapper[Array[Byte], JwsHeader]
) {
keyResolver: KeyResolver[Claims[C]],
jwtValidator: JwtValidator[C] = JwtValidator.empty[C],
fallbackJwtValidator: JwtValidator[C] = JwtValidator.defaultValidator(),
algorithms: Seq[SignatureAlgorithm] = SignatureAlgorithms.all
)(
implicit payloadDeserializer: Mapper[Array[Byte], Claims[C]],
headerDeserializer: Mapper[Array[Byte], JwsHeader]
) {

def now = Await.result(
validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)(payloadDeserializer, headerDeserializer, sadSpasticLittleEc),
validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)(
payloadDeserializer,
headerDeserializer,
sadSpasticLittleEc
),
Duration(1, TimeUnit.SECONDS)
)

def async(implicit ec: ExecutionContext) = validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)
def async(implicit ec: ExecutionContext) =
validate(compact, keyResolver, jwtValidator, fallbackJwtValidator, algorithms)
}
}

Expand Down
5 changes: 3 additions & 2 deletions jose/src/black/door/jose/jwt/JwtValidator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package black.door.jose.jwt
import java.time.{Clock, Instant}

import scala.concurrent.{ExecutionContext, Future}
import scala.language.implicitConversions

object JwtValidator {

Expand All @@ -15,9 +16,9 @@ object JwtValidator {
private def iatMessage(maybeIat: Option[Instant]) =
maybeIat.map(iat => s"It was issued at $iat.").getOrElse("")

def defaultValidator[C](clock: Clock = Clock.systemDefaultZone): JwtValidator[C] = {
def defaultValidator(clock: Clock = Clock.systemDefaultZone) = {
val now = Instant.now(clock)
JwtValidator.fromSync {
JwtValidator.fromSync[Any] {
case Jwt(_, claims) if claims.exp.exists(_.isBefore(now)) =>
s"Token expired at ${claims.exp.get}.${iatMessage(claims.iat)} It is now $now."
case Jwt(_, claims) if claims.nbf.exists(_.isAfter(now)) =>
Expand Down
3 changes: 1 addition & 2 deletions jose/src/black/door/jose/jwt/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ package black.door.jose
import scala.concurrent.Future

package object jwt {

type JwtValidator[UnregisteredClaims] = Jwt[UnregisteredClaims] => Future[Option[String]]
type JwtValidator[-UnregisteredClaims] = Jwt[UnregisteredClaims] => Future[Option[String]]
}
2 changes: 1 addition & 1 deletion jose/src/black/door/jose/package.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package black.door

package object jose {
type Mapper[A, B] = A => Either[String, B]
type Mapper[-A, B] = A => Either[String, B]
}

0 comments on commit aab9f85

Please sign in to comment.