diff --git a/cover.html b/cover.html
deleted file mode 100644
index 316b917..0000000
--- a/cover.html
+++ /dev/null
@@ -1,268 +0,0 @@
-
-
-
-
-
-
-
package jose
-
-import "encoding/base64"
-
-// Encoder is satisfied if the type can marshal itself into a valid
-// structure for a JWS.
-type Encoder interface {
- // Base64 implies T -> JSON -> RawURLEncodingBase64
- Base64() ([]byte, error)
-}
-
-// Base64Decode decodes a base64-encoded byte slice.
-func Base64Decode(b []byte) ([]byte, error) {
- buf := make([]byte, base64.RawURLEncoding.DecodedLen(len(b)))
- n, err := base64.RawURLEncoding.Decode(buf, b)
- return buf[:n], err
-}
-
-// Base64Encode encodes a byte slice.
-func Base64Encode(b []byte) []byte {
- buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b)))
- base64.RawURLEncoding.Encode(buf, b)
- return buf
-}
-
-// EncodeEscape base64-encodes a byte slice but escapes it for JSON.
-// It'll return the format: `"base64"`
-func EncodeEscape(b []byte) []byte {
- buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))+2)
- buf[0] = '"'
- base64.RawURLEncoding.Encode(buf[1:], b)
- buf[len(buf)-1] = '"'
- return buf
-}
-
-// DecodeEscaped decodes a base64-encoded byte slice straight from a JSON
-// structure. It assumes it's in the format: `"base64"`, but can handle
-// cases where it's not.
-func DecodeEscaped(b []byte) ([]byte, error) {
- if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
- b = b[1 : len(b)-1]
- }
- return Base64Decode(b)
-}
-
-
-
package jose
-
-import "encoding/json"
-
-// Header implements a JOSE Header with the addition of some helper
-// methods, similar to net/url.Values.
-type Header map[string]interface{}
-
-// Get retrieves the value corresponding with key from the Header.
-func (h Header) Get(key string) interface{} {
- if h == nil {
- return nil
- }
- return h[key]
-}
-
-// Set sets Claims[key] = val. It'll overwrite without warning.
-func (h Header) Set(key string, val interface{}) {
- h[key] = val
-}
-
-// Del removes the value that corresponds with key from the Header.
-func (h Header) Del(key string) {
- delete(h, key)
-}
-
-// Has returns true if a value for the given key exists inside the Header.
-func (h Header) Has(key string) bool {
- _, ok := h[key]
- return ok
-}
-
-// MarshalJSON implements json.Marshaler for Header.
-func (h Header) MarshalJSON() ([]byte, error) {
- if h == nil || len(h) == 0 {
- return nil, nil
- }
- b, err := json.Marshal(map[string]interface{}(h))
- if err != nil {
- return nil, err
- }
- return EncodeEscape(b), nil
-}
-
-// Base64 implements the Encoder interface.
-func (h Header) Base64() ([]byte, error) {
- return h.MarshalJSON()
-}
-
-// UnmarshalJSON implements json.Unmarshaler for Header.
-func (h *Header) UnmarshalJSON(b []byte) error {
- if b == nil {
- return nil
- }
-
- b, err := DecodeEscaped(b)
- if err != nil {
- return err
- }
-
- // Since json.Unmarshal calls UnmarshalJSON,
- // calling json.Unmarshal on *p would be infinitely recursive
- // A temp variable is needed because &map[string]interface{}(*p) is
- // invalid Go.
-
- tmp := map[string]interface{}(*h)
- if err = json.Unmarshal(b, &tmp); err != nil {
- return err
- }
- *h = Header(tmp)
- return nil
-}
-
-// Protected Headers are base64-encoded after they're marshaled into
-// JSON.
-type Protected Header
-
-// Get retrieves the value corresponding with key from the Protected Header.
-func (p Protected) Get(key string) interface{} {
- if p == nil {
- return nil
- }
- return p[key]
-}
-
-// Set sets Protected[key] = val. It'll overwrite without warning.
-func (p Protected) Set(key string, val interface{}) {
- p[key] = val
-}
-
-// Del removes the value that corresponds with key from the Protected Header.
-func (p Protected) Del(key string) {
- delete(p, key)
-}
-
-// Has returns true if a value for the given key exists inside the Protected
-// Header.
-func (p Protected) Has(key string) bool {
- _, ok := p[key]
- return ok
-}
-
-// MarshalJSON implements json.Marshaler for Protected.
-func (p Protected) MarshalJSON() ([]byte, error) {
- b, err := json.Marshal(map[string]interface{}(p))
- if err != nil {
- return nil, err
- }
- return EncodeEscape(b), nil
-}
-
-// Base64 implements the Encoder interface.
-func (p Protected) Base64() ([]byte, error) {
- b, err := json.Marshal(map[string]interface{}(p))
- if err != nil {
- return nil, err
- }
- return Base64Encode(b), nil
-}
-
-// UnmarshalJSON implements json.Unmarshaler for Protected.
-func (p *Protected) UnmarshalJSON(b []byte) error {
- var h Header
- h.UnmarshalJSON(b)
- *p = Protected(h)
- return nil
-}
-
-var (
- _ json.Marshaler = (Protected)(nil)
- _ json.Unmarshaler = (*Protected)(nil)
-)
-
-
-
-
-
-
diff --git a/crypto/ecdsa.go b/crypto/ecdsa.go
index c518c9d..3ef12ba 100644
--- a/crypto/ecdsa.go
+++ b/crypto/ecdsa.go
@@ -101,7 +101,9 @@ func (m *SigningMethodECDSA) sum(b []byte) []byte {
}
// Hasher implements the Hasher method from SigningMethod.
-func (m *SigningMethodECDSA) Hasher() crypto.Hash { return m.Hash }
+func (m *SigningMethodECDSA) Hasher() crypto.Hash {
+ return m.Hash
+}
// MarshalJSON is in case somebody decides to place SigningMethodECDSA
// inside the Header, presumably because they (wrongly) decided it was a good
diff --git a/crypto/none.go b/crypto/none.go
index 2b27af4..db3d139 100644
--- a/crypto/none.go
+++ b/crypto/none.go
@@ -7,24 +7,28 @@ import (
"io"
)
-func init() { crypto.RegisterHash(crypto.Hash(0), h) }
+func init() {
+ crypto.RegisterHash(crypto.Hash(0), h)
+}
// h is passed to crypto.RegisterHash.
-func h() hash.Hash { return &f{Writer: nil} }
+func h() hash.Hash {
+ return &f{Writer: nil}
+}
type f struct{ io.Writer }
// Sum helps implement the hash.Hash interface.
-func (f *f) Sum(b []byte) []byte { return nil }
+func (_ *f) Sum(b []byte) []byte { return nil }
// Reset helps implement the hash.Hash interface.
-func (f *f) Reset() {}
+func (_ *f) Reset() {}
// Size helps implement the hash.Hash interface.
-func (f *f) Size() int { return -1 }
+func (_ *f) Size() int { return -1 }
// BlockSize helps implement the hash.Hash interface.
-func (f *f) BlockSize() int { return -1 }
+func (_ *f) BlockSize() int { return -1 }
// Unsecured is the default "none" algorithm.
var Unsecured = &SigningMethodNone{
@@ -40,20 +44,24 @@ type SigningMethodNone struct {
}
// Verify helps implement the SigningMethod interface.
-func (m *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error {
+func (_ *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error {
return nil
}
// Sign helps implement the SigningMethod interface.
-func (m *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) {
+func (_ *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) {
return nil, nil
}
// Alg helps implement the SigningMethod interface.
-func (m *SigningMethodNone) Alg() string { return m.Name }
+func (m *SigningMethodNone) Alg() string {
+ return m.Name
+}
// Hasher helps implement the SigningMethod interface.
-func (m *SigningMethodNone) Hasher() crypto.Hash { return m.Hash }
+func (m *SigningMethodNone) Hasher() crypto.Hash {
+ return m.Hash
+}
// MarshalJSON implements json.Marshaler.
// See SigningMethodECDSA.MarshalJSON() for information.
diff --git a/crypto/rsa_utils.go b/crypto/rsa_utils.go
index 350c394..43aeff3 100644
--- a/crypto/rsa_utils.go
+++ b/crypto/rsa_utils.go
@@ -9,8 +9,9 @@ import (
// Errors specific to rsa_utils.
var (
- ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
- ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key")
+ ErrKeyMustBePEMEncoded = errors.New("invalid key: Key must be PEM encoded PKCS1 or PKCS8 private key")
+ ErrNotRSAPrivateKey = errors.New("key is not a valid RSA private key")
+ ErrNotRSAPublicKey = errors.New("key is not a valid RSA public key")
)
// ParseRSAPrivateKeyFromPEM parses a PEM encoded PKCS1 or PKCS8 private key.
@@ -62,7 +63,7 @@ func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
var pkey *rsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
- return nil, ErrNotRSAPrivateKey
+ return nil, ErrNotRSAPublicKey
}
return pkey, nil
diff --git a/header.go b/header.go
index 4bf64c5..4499a76 100644
--- a/header.go
+++ b/header.go
@@ -32,7 +32,7 @@ func (h Header) Has(key string) bool {
// MarshalJSON implements json.Marshaler for Header.
func (h Header) MarshalJSON() ([]byte, error) {
- if h == nil || len(h) == 0 {
+ if len(h) == 0 {
return nil, nil
}
b, err := json.Marshal(map[string]interface{}(h))
@@ -52,23 +52,11 @@ func (h *Header) UnmarshalJSON(b []byte) error {
if b == nil {
return nil
}
-
b, err := DecodeEscaped(b)
if err != nil {
return err
}
-
- // Since json.Unmarshal calls UnmarshalJSON,
- // calling json.Unmarshal on *p would be infinitely recursive
- // A temp variable is needed because &map[string]interface{}(*p) is
- // invalid Go.
-
- tmp := map[string]interface{}(*h)
- if err = json.Unmarshal(b, &tmp); err != nil {
- return err
- }
- *h = tmp
- return nil
+ return json.Unmarshal(b, (*map[string]interface{})(h))
}
// Protected Headers are base64-encoded after they're marshaled into
@@ -131,4 +119,6 @@ func (p *Protected) UnmarshalJSON(b []byte) error {
var (
_ json.Marshaler = (Protected)(nil)
_ json.Unmarshaler = (*Protected)(nil)
+ _ json.Marshaler = (Header)(nil)
+ _ json.Unmarshaler = (*Header)(nil)
)
diff --git a/jws/claims.go b/jws/claims.go
index 068caa8..4cc616c 100644
--- a/jws/claims.go
+++ b/jws/claims.go
@@ -2,6 +2,7 @@ package jws
import (
"encoding/json"
+ "time"
"github.com/SermoDigital/jose"
"github.com/SermoDigital/jose/jwt"
@@ -84,19 +85,19 @@ func (c Claims) Audience() ([]string, bool) {
// Expiration retrieves claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) Expiration() (float64, bool) {
+func (c Claims) Expiration() (time.Time, bool) {
return jwt.Claims(c).Expiration()
}
// NotBefore retrieves claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) NotBefore() (float64, bool) {
+func (c Claims) NotBefore() (time.Time, bool) {
return jwt.Claims(c).NotBefore()
}
// IssuedAt retrieves claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) IssuedAt() (float64, bool) {
+func (c Claims) IssuedAt() (time.Time, bool) {
return jwt.Claims(c).IssuedAt()
}
@@ -161,19 +162,19 @@ func (c Claims) SetAudience(audience ...string) {
// SetExpiration sets claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) SetExpiration(expiration float64) {
+func (c Claims) SetExpiration(expiration time.Time) {
jwt.Claims(c).SetExpiration(expiration)
}
// SetNotBefore sets claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) SetNotBefore(notBefore float64) {
+func (c Claims) SetNotBefore(notBefore time.Time) {
jwt.Claims(c).SetNotBefore(notBefore)
}
// SetIssuedAt sets claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) SetIssuedAt(issuedAt float64) {
+func (c Claims) SetIssuedAt(issuedAt time.Time) {
jwt.Claims(c).SetIssuedAt(issuedAt)
}
diff --git a/jws/jws.go b/jws/jws.go
index a0b929d..29ae67c 100644
--- a/jws/jws.go
+++ b/jws/jws.go
@@ -16,17 +16,21 @@ type JWS interface {
Payload() interface{}
// SetPayload sets the payload with the given value.
- SetPayload(interface{})
+ SetPayload(p interface{})
// Protected returns the JWS' Protected Header.
+ Protected() jose.Protected
+
+ // ProtectedAt returns the JWS' Protected Header.
// i represents the index of the Protected Header.
- // Left empty, it defaults to 0.
- Protected(...int) jose.Protected
+ ProtectedAt(i int) jose.Protected
// Header returns the JWS' unprotected Header.
- // i represents the index of the Protected Header.
- // Left empty, it defaults to 0.
- Header(...int) jose.Header
+ Header() jose.Header
+
+ // HeaderAt returns the JWS' unprotected Header.
+ // i represents the index of the unprotected Header.
+ HeaderAt(i int) jose.Header
// Verify validates the current JWS' signature as-is. Refer to
// ValidateMulti for more information.
@@ -74,29 +78,36 @@ type jws struct {
}
// Payload returns the jws' payload.
-func (j *jws) Payload() interface{} { return j.payload.v }
+func (j *jws) Payload() interface{} {
+ return j.payload.v
+}
// SetPayload sets the jws' raw, unexported payload.
-func (j *jws) SetPayload(val interface{}) { j.payload.v = val }
+func (j *jws) SetPayload(val interface{}) {
+ j.payload.v = val
+}
+
+// Protected returns the JWS' Protected Header.
+func (j *jws) Protected() jose.Protected {
+ return j.sb[0].protected
+}
// Protected returns the JWS' Protected Header.
// i represents the index of the Protected Header.
// Left empty, it defaults to 0.
-func (j *jws) Protected(i ...int) jose.Protected {
- if len(i) == 0 {
- return j.sb[0].protected
- }
- return j.sb[i[0]].protected
+func (j *jws) ProtectedAt(i int) jose.Protected {
+ return j.sb[i].protected
}
// Header returns the JWS' unprotected Header.
-// i represents the index of the Protected Header.
-// Left empty, it defaults to 0.
-func (j *jws) Header(i ...int) jose.Header {
- if len(i) == 0 {
- return j.sb[0].unprotected
- }
- return j.sb[i[0]].unprotected
+func (j *jws) Header() jose.Header {
+ return j.sb[0].unprotected
+}
+
+// HeaderAt returns the JWS' unprotected Header.
+// |i| is the index of the unprotected Header.
+func (j *jws) HeaderAt(i int) jose.Header {
+ return j.sb[i].unprotected
}
// sigHead represents the 'signatures' member of the jws' "general"
@@ -121,10 +132,7 @@ func (s *sigHead) unmarshal() error {
if err := s.protected.UnmarshalJSON(s.Protected); err != nil {
return err
}
- if err := s.unprotected.UnmarshalJSON(s.Unprotected); err != nil {
- return err
- }
- return nil
+ return s.unprotected.UnmarshalJSON(s.Unprotected)
}
// New creates a JWS with the provided crypto.SigningMethods.
@@ -155,7 +163,6 @@ func (s *sigHead) assignMethod(p jose.Protected) error {
if sm == nil {
return ErrNoAlgorithm
}
-
s.method = sm
return nil
}
@@ -236,10 +243,10 @@ func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) {
if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil {
return nil, err
}
-
- g.clean = true
}
+ g.clean = len(g.Signatures) != 0
+
return &jws{
payload: &p,
plcache: g.Payload,
@@ -390,11 +397,12 @@ const (
Compact
)
-var parseJumpTable = [^Format(0)]func([]byte, ...json.Unmarshaler) (JWS, error){
- Unknown: Parse,
- Flat: ParseFlat,
- General: ParseGeneral,
- Compact: ParseCompact,
+var parseJumpTable = [...]func([]byte, ...json.Unmarshaler) (JWS, error){
+ Unknown: Parse,
+ Flat: ParseFlat,
+ General: ParseGeneral,
+ Compact: ParseCompact,
+ 1<<8 - 1: Parse, // Max uint8.
}
func init() {
diff --git a/jws/jws_serialize.go b/jws/jws_serialize.go
index 9cb53af..923fdc2 100644
--- a/jws/jws_serialize.go
+++ b/jws/jws_serialize.go
@@ -100,36 +100,28 @@ func (j *jws) sign(keys ...interface{}) error {
}
// cache marshals the payload, but only if it's changed since the last cache.
-func (j *jws) cache() error {
+func (j *jws) cache() (err error) {
if !j.clean {
- var err error
j.plcache, err = j.payload.Base64()
j.clean = err == nil
- return err
}
- return nil
+ return err
}
// cache marshals the protected and unprotected headers, but only if
// they've changed since their last cache.
-func (s *sigHead) cache() error {
+func (s *sigHead) cache() (err error) {
if !s.clean {
- var err error
-
s.Protected, err = s.protected.Base64()
if err != nil {
- goto err_return
+ return err
}
-
s.Unprotected, err = s.unprotected.Base64()
if err != nil {
- goto err_return
+ return err
}
-
- err_return:
- s.clean = err == nil
- return err
}
+ s.clean = true
return nil
}
diff --git a/jws/jws_validate.go b/jws/jws_validate.go
index 1948880..d064113 100644
--- a/jws/jws_validate.go
+++ b/jws/jws_validate.go
@@ -27,40 +27,34 @@ func (j *jws) VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod,
return j.VerifyMulti(keys, methods, o)
}
-// IsMultiError returns true if the given error is type MultiError.
+// IsMultiError returns true if the given error is type *MultiError.
func IsMultiError(err error) bool {
- _, ok := err.(MultiError)
+ _, ok := err.(*MultiError)
return ok
}
// MultiError is a slice of errors.
type MultiError []error
-func (m MultiError) sanityCheck() error {
- if m == nil {
- return nil
- }
- return m
-}
-
// Errors implements the error interface.
-func (m MultiError) Error() string {
- s, n := "", 0
- for _, e := range m {
- if e != nil {
+func (m *MultiError) Error() string {
+ var s string
+ var n int
+ for _, err := range *m {
+ if err != nil {
if n == 0 {
- s = e.Error()
+ s = err.Error()
}
n++
}
}
switch n {
case 0:
- return "(0 errors)"
+ return ""
case 1:
return s
case 2:
- return s + " (and 1 other error)"
+ return s + " and 1 other error"
}
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
}
@@ -101,7 +95,7 @@ func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o
var o2 SigningOpts
if o == nil {
- o = &SigningOpts{}
+ o = new(SigningOpts)
}
var m MultiError
@@ -112,15 +106,20 @@ func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o
} else {
o2.Inc()
if o.Needs(i) {
+ o.ptr++
o2.Append(i)
}
}
}
- if err := o.Validate(&o2); err != nil {
- return err
+ err := o.Validate(&o2)
+ if err != nil {
+ m = append(m, err)
+ }
+ if len(m) == 0 {
+ return nil
}
- return m.sanityCheck()
+ return &m
}
// SigningOpts is a struct which holds options for validating
@@ -148,30 +147,24 @@ type SigningOpts struct {
_ struct{}
}
-// Append appends x to s's Indices member.
+// Append appends x to s' Indices member.
func (s *SigningOpts) Append(x int) {
s.Indices = append(s.Indices, x)
}
-// Needs returns true if x resides inside s's Indices member
-// for the given index. If true, it increments s's internal
-// index. It's used to match two SigningOpts Indices members.
+// Needs returns true if x resides inside s' Indices member
+// for the given index. It's used to match two SigningOpts Indices members.
func (s *SigningOpts) Needs(x int) bool {
- if s.ptr < len(s.Indices) &&
- s.Indices[s.ptr] == x {
- s.ptr++
- return true
- }
- return false
+ return s.ptr < len(s.Indices) && s.Indices[s.ptr] == x
}
-// Inc increments s's Number member by one.
+// Inc increments s' Number member by one.
func (s *SigningOpts) Inc() { s.Number++ }
// Validate returns any errors found while validating the
-// provided SigningOpts. The receiver validates the parameter `have`.
+// provided SigningOpts. The receiver validates |have|.
// It'll return an error if the passed SigningOpts' Number member is less
-// than s's or if the passed SigningOpts' Indices slice isn't equal to s's.
+// than s' or if the passed SigningOpts' Indices slice isn't equal to s'.
func (s *SigningOpts) Validate(have *SigningOpts) error {
if have.Number < s.Number ||
(s.Indices != nil &&
@@ -182,10 +175,7 @@ func (s *SigningOpts) Validate(have *SigningOpts) error {
}
func eq(a, b []int) bool {
- if a == nil && b == nil {
- return true
- }
- if a == nil || b == nil || len(a) != len(b) {
+ if len(a) != len(b) {
return false
}
for i := range a {
diff --git a/jws/jwt.go b/jws/jwt.go
index 67f18f7..29b75b8 100644
--- a/jws/jwt.go
+++ b/jws/jwt.go
@@ -10,7 +10,10 @@ import (
// NewJWT creates a new JWT with the given claims.
func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT {
- j := New(claims, method).(*jws)
+ j, ok := New(claims, method).(*jws)
+ if !ok {
+ panic("jws.NewJWT: runtime panic: New(...).(*jws) != true")
+ }
j.sb[0].protected.Set("typ", "JWT")
j.isJWT = true
return j
@@ -64,7 +67,9 @@ func ParseJWT(encoded []byte) (jwt.JWT, error) {
}
// IsJWT returns true if the JWS is a JWT.
-func (j *jws) IsJWT() bool { return j.isJWT }
+func (j *jws) IsJWT() bool {
+ return j.isJWT
+}
func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Validator) error {
if j.isJWT {
@@ -80,7 +85,7 @@ func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Valida
if err := v1.Validate(j); err != nil {
return err
}
- return jwt.Claims(c).Validate(float64(time.Now().Unix()), v1.EXP, v1.NBF)
+ return jwt.Claims(c).Validate(time.Now(), v1.EXP, v1.NBF)
}
}
return ErrIsNotJWT
@@ -96,9 +101,8 @@ func Conv(fn func(Claims) error) jwt.ValidateFunc {
}
}
-// NewValidator returns a pointer to a jwt.Validator structure containing
-// the info to be used in the validation of a JWT.
-func NewValidator(c Claims, exp, nbf float64, fn func(Claims) error) *jwt.Validator {
+// NewValidator returns a jwt.Validator.
+func NewValidator(c Claims, exp, nbf time.Duration, fn func(Claims) error) *jwt.Validator {
return &jwt.Validator{
Expected: jwt.Claims(c),
EXP: exp,
diff --git a/jws/jwt_test.go b/jws/jwt_test.go
index a8acb10..5ea4a9b 100644
--- a/jws/jwt_test.go
+++ b/jws/jwt_test.go
@@ -61,11 +61,18 @@ func TestJWTValidator(t *testing.T) {
t.Error(err)
}
- d := float64(time.Now().Add(1 * time.Hour).Unix())
+ d := time.Hour
fn := func(c Claims) error {
+
+ scopes, ok := c.Get("scopes").([]interface{})
+
+ if !ok {
+ return errors.New("Unexpected scopes type. Expected string")
+ }
+
if c.Get("name") != "Eric" &&
c.Get("admin") != true &&
- c.Get("scopes").([]string)[0] != "user.account.info" {
+ scopes[0] != "user.account.info" {
return errors.New("invalid")
}
return nil
diff --git a/jws/payload.go b/jws/payload.go
index 34ba6d8..58bfd06 100644
--- a/jws/payload.go
+++ b/jws/payload.go
@@ -37,13 +37,11 @@ func (p *payload) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
-
if p.u != nil {
err := p.u.UnmarshalJSON(b2)
p.v = p.u
return err
}
-
return json.Unmarshal(b2, &p.v)
}
diff --git a/jws/signing_methods.go b/jws/signing_methods.go
index 1b6665f..525806f 100644
--- a/jws/signing_methods.go
+++ b/jws/signing_methods.go
@@ -7,7 +7,7 @@ import (
)
var (
- mu = &sync.RWMutex{}
+ mu sync.RWMutex
signingMethods = map[string]crypto.SigningMethod{
crypto.SigningMethodES256.Alg(): crypto.SigningMethodES256,
@@ -33,7 +33,8 @@ var (
// RegisterSigningMethod registers the crypto.SigningMethod in the global map.
// This is typically done inside the caller's init function.
func RegisterSigningMethod(sm crypto.SigningMethod) {
- if GetSigningMethod(sm.Alg()) != nil {
+ alg := sm.Alg()
+ if GetSigningMethod(alg) != nil {
panic("jose/jws: cannot duplicate signing methods")
}
@@ -42,7 +43,7 @@ func RegisterSigningMethod(sm crypto.SigningMethod) {
}
mu.Lock()
- signingMethods[sm.Alg()] = sm
+ signingMethods[alg] = sm
mu.Unlock()
}
@@ -54,8 +55,9 @@ func RemoveSigningMethod(sm crypto.SigningMethod) {
}
// GetSigningMethod retrieves a crypto.SigningMethod from the global map.
-func GetSigningMethod(alg string) crypto.SigningMethod {
+func GetSigningMethod(alg string) (method crypto.SigningMethod) {
mu.RLock()
- defer mu.RUnlock()
- return signingMethods[alg]
+ method = signingMethods[alg]
+ mu.RUnlock()
+ return method
}
diff --git a/jwt/claims.go b/jwt/claims.go
index cc135d8..ae41e7d 100644
--- a/jwt/claims.go
+++ b/jwt/claims.go
@@ -2,6 +2,8 @@ package jwt
import (
"encoding/json"
+ "reflect"
+ "time"
"github.com/SermoDigital/jose"
)
@@ -12,25 +14,21 @@ type Claims map[string]interface{}
// Validate validates the Claims per the claims found in
// https://tools.ietf.org/html/rfc7519#section-4.1
-func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error {
+func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error {
if exp, ok := c.Expiration(); ok {
- if !within(exp, expLeeway, now) {
+ if now.After(exp.Add(expLeeway)) {
return ErrTokenIsExpired
}
}
if nbf, ok := c.NotBefore(); ok {
- if !within(nbf, nbfLeeway, now) {
+ if !now.After(nbf.Add(-nbfLeeway)) {
return ErrTokenNotYetValid
}
}
return nil
}
-func within(val, delta, max float64) bool {
- return val > max+delta || val > max-delta
-}
-
// Get retrieves the value corresponding with key from the Claims.
func (c Claims) Get(key string) interface{} {
if c == nil {
@@ -63,7 +61,7 @@ func (c Claims) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]interface{}(c))
}
-// Base64 implements the Encoder interface.
+// Base64 implements the jose.Encoder interface.
func (c Claims) Base64() ([]byte, error) {
b, err := c.MarshalJSON()
if err != nil {
@@ -86,7 +84,7 @@ func (c *Claims) UnmarshalJSON(b []byte) error {
// Since json.Unmarshal calls UnmarshalJSON,
// calling json.Unmarshal on *p would be infinitely recursive
// A temp variable is needed because &map[string]interface{}(*p) is
- // invalid Go.
+ // invalid Go. (Address of unaddressable object and all that...)
tmp := map[string]interface{}(*c)
if err = json.Unmarshal(b, &tmp); err != nil {
@@ -113,6 +111,8 @@ func (c Claims) Subject() (string, bool) {
// Audience retrieves claim "aud" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.3
func (c Claims) Audience() ([]string, bool) {
+ // Audience claim must be stringy. That is, it may be one string
+ // or multiple strings but it should not be anything else. E.g. an int.
switch t := c.Get("aud").(type) {
case string:
return []string{t}, true
@@ -144,23 +144,20 @@ func stringify(a ...interface{}) ([]string, bool) {
// Expiration retrieves claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) Expiration() (float64, bool) {
- v, ok := c.Get("exp").(float64)
- return v, ok
+func (c Claims) Expiration() (time.Time, bool) {
+ return c.GetTime("exp")
}
// NotBefore retrieves claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) NotBefore() (float64, bool) {
- v, ok := c.Get("nbf").(float64)
- return v, ok
+func (c Claims) NotBefore() (time.Time, bool) {
+ return c.GetTime("nbf")
}
// IssuedAt retrieves claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) IssuedAt() (float64, bool) {
- v, ok := c.Get("iat").(float64)
- return v, ok
+func (c Claims) IssuedAt() (time.Time, bool) {
+ return c.GetTime("iat")
}
// JWTID retrieves claim "jti" per its type in
@@ -215,20 +212,20 @@ func (c Claims) SetAudience(audience ...string) {
// SetExpiration sets claim "exp" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (c Claims) SetExpiration(expiration float64) {
- c.Set("exp", expiration)
+func (c Claims) SetExpiration(expiration time.Time) {
+ c.SetTime("exp", expiration)
}
// SetNotBefore sets claim "nbf" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (c Claims) SetNotBefore(notBefore float64) {
- c.Set("nbf", notBefore)
+func (c Claims) SetNotBefore(notBefore time.Time) {
+ c.SetTime("nbf", notBefore)
}
// SetIssuedAt sets claim "iat" per its type in
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (c Claims) SetIssuedAt(issuedAt float64) {
- c.Set("iat", issuedAt)
+func (c Claims) SetIssuedAt(issuedAt time.Time) {
+ c.SetTime("iat", issuedAt)
}
// SetJWTID sets claim "jti" per its type in
@@ -237,6 +234,41 @@ func (c Claims) SetJWTID(uniqueID string) {
c.Set("jti", uniqueID)
}
+// zero pre-allocs the zero-time value
+var zero = time.Time{}
+
+// GetTime returns a UNIX time for the given key.
+//
+// It converts an int, int32, int64, uint, uint32, uint64 or float64 value
+// into a UNIX time (epoch seconds). float32 does not have sufficient
+// precision to store a UNIX time.
+//
+// Numeric values parsed from JSON will always be stored as float64 since
+// Claims is a map[string]interface{}. However, internally the values may be
+// stored directly in the claims map as different types.
+func (c Claims) GetTime(key string) (time.Time, bool) {
+ x := c.Get(key)
+ if x == nil {
+ return zero, false
+ }
+ v := reflect.ValueOf(x)
+ switch v.Kind() {
+ case reflect.Int, reflect.Int32, reflect.Int64:
+ return time.Unix(v.Int(), 0), true
+ case reflect.Uint, reflect.Uint32, reflect.Uint64:
+ return time.Unix(int64(v.Uint()), 0), true
+ case reflect.Float64:
+ return time.Unix(int64(v.Float()), 0), true
+ default:
+ return zero, false
+ }
+}
+
+// SetTime stores a UNIX time for the given key.
+func (c Claims) SetTime(key string, t time.Time) {
+ c.Set(key, t.Unix())
+}
+
var (
_ json.Marshaler = (Claims)(nil)
_ json.Unmarshaler = (*Claims)(nil)
diff --git a/jwt/claims_test.go b/jwt/claims_test.go
index c5edd70..b653785 100644
--- a/jwt/claims_test.go
+++ b/jwt/claims_test.go
@@ -2,9 +2,11 @@ package jwt_test
import (
"testing"
+ "time"
"github.com/SermoDigital/jose/crypto"
"github.com/SermoDigital/jose/jws"
+ "github.com/SermoDigital/jose/jwt"
)
func TestMultipleAudienceBug_AfterMarshal(t *testing.T) {
@@ -83,3 +85,115 @@ func TestSingleAudienceFix_AfterMarshal(t *testing.T) {
t.Logf("aud Value: %s", aud)
t.Logf("aud Type : %T", aud)
}
+
+func TestValidate(t *testing.T) {
+ now := time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)
+ before, after := now.Add(-time.Minute), now.Add(time.Minute)
+ leeway := 10 * time.Second
+
+ exp := func(t time.Time) jwt.Claims {
+ return jwt.Claims{"exp": t.Unix()}
+ }
+ nbf := func(t time.Time) jwt.Claims {
+ return jwt.Claims{"nbf": t.Unix()}
+ }
+
+ var tests = []struct {
+ desc string
+ c jwt.Claims
+ now time.Time
+ expLeeway time.Duration
+ nbfLeeway time.Duration
+ err error
+ }{
+ // test for nbf < now <= exp
+ {desc: "exp == nil && nbf == nil", c: jwt.Claims{}, now: now, err: nil},
+
+ {desc: "now > exp", now: now, c: exp(before), err: jwt.ErrTokenIsExpired},
+ {desc: "now = exp", now: now, c: exp(now), err: nil},
+ {desc: "now < exp", now: now, c: exp(after), err: nil},
+
+ {desc: "nbf < now", c: nbf(before), now: now, err: nil},
+ {desc: "nbf = now", c: nbf(now), now: now, err: jwt.ErrTokenNotYetValid},
+ {desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid},
+
+ // test for nbf-x < now <= exp+y
+ {desc: "now < exp+x", now: now.Add(leeway - time.Second), expLeeway: leeway, c: exp(now), err: nil},
+ {desc: "now = exp+x", now: now.Add(leeway), expLeeway: leeway, c: exp(now), err: nil},
+ {desc: "now > exp+x", now: now.Add(leeway + time.Second), expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired},
+
+ {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway + time.Second), err: nil},
+ {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway), err: jwt.ErrTokenNotYetValid},
+ {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway - time.Second), err: jwt.ErrTokenNotYetValid},
+ }
+
+ for i, tt := range tests {
+ if got, want := tt.c.Validate(tt.now, tt.expLeeway, tt.nbfLeeway), tt.err; got != want {
+ t.Errorf("%d - %q: got %v want %v", i, tt.desc, got, want)
+ }
+ }
+}
+
+func TestGetAndSetTime(t *testing.T) {
+ now := time.Now()
+ nowUnix := now.Unix()
+ c := jwt.Claims{
+ "int": int(nowUnix),
+ "int32": int32(nowUnix),
+ "int64": int64(nowUnix),
+ "uint": uint(nowUnix),
+ "uint32": uint32(nowUnix),
+ "uint64": uint64(nowUnix),
+ "float64": float64(nowUnix),
+ }
+ c.SetTime("setTime", now)
+ for k := range c {
+ v, ok := c.GetTime(k)
+ if got, want := v, time.Unix(nowUnix, 0); !ok || !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", k, got, want)
+ }
+ }
+}
+
+// TestTimeValuesThroughJSON verifies that the time values
+// that are set via the Set{IssuedAt,NotBefore,Expiration}()
+// methods can actually be parsed back
+func TestTimeValuesThroughJSON(t *testing.T) {
+ now := time.Unix(time.Now().Unix(), 0)
+
+ c := jws.Claims{}
+ c.SetIssuedAt(now)
+ c.SetNotBefore(now)
+ c.SetExpiration(now)
+
+ // serialize to JWT
+ tok := jws.NewJWT(c, crypto.SigningMethodHS256)
+ b, err := tok.Serialize([]byte("key"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // parse the JWT again
+ tok2, err := jws.ParseJWT(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2 := tok2.Claims()
+
+ iat, ok1 := c2.IssuedAt()
+ nbf, ok2 := c2.NotBefore()
+ exp, ok3 := c2.Expiration()
+ if !ok1 || !ok2 || !ok3 {
+ t.Fatal("got false want true")
+ }
+
+ if got, want := iat, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "iat", got, want)
+ }
+ if got, want := nbf, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "nbf", got, want)
+ }
+ if got, want := exp, now; !got.Equal(want) {
+ t.Errorf("%s: got %v want %v", "exp", got, want)
+ }
+}
diff --git a/jwt/eq.go b/jwt/eq.go
index a7a37a9..3113269 100644
--- a/jwt/eq.go
+++ b/jwt/eq.go
@@ -1,52 +1,47 @@
package jwt
-import "reflect"
+func verifyPrincipals(pcpls, auds []string) bool {
+ // "Each principal intended to process the JWT MUST
+ // identify itself with a value in the audience claim."
+ // - https://tools.ietf.org/html/rfc7519#section-4.1.3
-// eq returns true if the two types are either strings
-// or comparable slices.
-func eq(a, b interface{}) bool {
- t1 := reflect.TypeOf(a)
- t2 := reflect.TypeOf(b)
-
- if t1.Kind() == t2.Kind() {
- switch t1.Kind() {
- case reflect.Slice:
- return eqSlice(a, b)
- case reflect.String:
- return reflect.ValueOf(a).String() ==
- reflect.ValueOf(b).String()
+ found := -1
+ for i, p := range pcpls {
+ for _, v := range auds {
+ if p == v {
+ found++
+ break
+ }
+ }
+ if found != i {
+ return false
}
}
- return false
+ return true
}
-// eqSlice returns true if the two interfaces are both slices
-// and are equal. For example: https://play.golang.org/p/5VLMwNE3i-
-func eqSlice(a, b interface{}) bool {
- if a == nil || b == nil {
- return false
+// ValidAudience returns true iff:
+// - a and b are strings and a == b
+// - a is string, b is []string and a is in b
+// - a is []string, b is []string and all of a is in b
+// - a is []string, b is string and len(a) == 1 and a[0] == b
+func ValidAudience(a, b interface{}) bool {
+ s1, ok := a.(string)
+ if ok {
+ if s2, ok := b.(string); ok {
+ return s1 == s2
+ }
+ a2, ok := b.([]string)
+ return ok && verifyPrincipals([]string{s1}, a2)
}
- v1 := reflect.ValueOf(a)
- v2 := reflect.ValueOf(b)
-
- if v1.Kind() != reflect.Slice ||
- v2.Kind() != reflect.Slice {
+ a1, ok := a.([]string)
+ if !ok {
return false
}
-
- if v1.Len() == v2.Len() && v1.Len() > 0 {
- for i := 0; i < v1.Len() && i < v2.Len(); i++ {
- k1 := v1.Index(i)
- k2 := v2.Index(i)
- if k1.Type().Comparable() &&
- k2.Type().Comparable() &&
- k1.CanInterface() && k2.CanInterface() &&
- k1.Interface() != k2.Interface() {
- return false
- }
- }
- return true
+ if a2, ok := b.([]string); ok {
+ return verifyPrincipals(a1, a2)
}
- return false
+ s2, ok := b.(string)
+ return ok && len(a1) == 1 && a1[0] == s2
}
diff --git a/jwt/eq_test.go b/jwt/eq_test.go
new file mode 100644
index 0000000..5f9d4fd
--- /dev/null
+++ b/jwt/eq_test.go
@@ -0,0 +1,26 @@
+package jwt_test
+
+import (
+ "testing"
+
+ "github.com/SermoDigital/jose/jwt"
+)
+
+func TestValidAudience(t *testing.T) {
+ tests := [...]struct {
+ a interface{}
+ b interface{}
+ v bool
+ }{
+ 0: {"https://www.google.com", "https://www.google.com", true},
+ 1: {[]string{"example.com", "google.com"}, []string{"example.com"}, false},
+ 2: {500, 43, false},
+ 3: {"google.com", "facebook.com", false},
+ 4: {[]string{"example.com"}, []string{"example.com", "foo.com"}, true},
+ }
+ for i, v := range tests {
+ if x := jwt.ValidAudience(v.a, v.b); x != v.v {
+ t.Fatalf("#%d: wanted %t, got %t", i, v.v, x)
+ }
+ }
+}
diff --git a/jwt/jwt.go b/jwt/jwt.go
index bd84259..d29c43a 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -1,6 +1,10 @@
package jwt
-import "github.com/SermoDigital/jose/crypto"
+import (
+ "time"
+
+ "github.com/SermoDigital/jose/crypto"
+)
// JWT represents a JWT per RFC 7519.
// It's described as an interface instead of a physical structure
@@ -33,18 +37,12 @@ type ValidateFunc func(Claims) error
// Validator represents some of the validation options.
type Validator struct {
- Expected Claims // If non-nil, these are required to match.
- EXP float64 // EXPLeeway
- NBF float64 // NBFLeeway
- Fn ValidateFunc // See ValidateFunc for more information.
-
- _ struct{}
-}
+ Expected Claims // If non-nil, these are required to match.
+ EXP time.Duration // EXPLeeway
+ NBF time.Duration // NBFLeeway
+ Fn ValidateFunc // See ValidateFunc for more information.
-var defaultClaims = []string{
- "iss", "sub", "aud",
- "exp", "nbf", "iat",
- "jti",
+ _ struct{} // Require explicitly-named struct fields.
}
// Validate validates the JWT based on the expected claims in v.
@@ -71,7 +69,8 @@ func (v *Validator) Validate(j JWT) error {
}
if aud, ok := v.Expected.Audience(); ok {
- if aud2, _ := j.Claims().Audience(); !eq(aud, aud2){
+ aud2, ok := j.Claims().Audience()
+ if !ok || !ValidAudience(aud, aud2) {
return ErrInvalidAUDClaim
}
}
@@ -111,21 +110,21 @@ func (v *Validator) SetAudience(aud string) {
// SetExpiration sets the "exp" claim per
// https://tools.ietf.org/html/rfc7519#section-4.1.4
-func (v *Validator) SetExpiration(exp float64) {
+func (v *Validator) SetExpiration(exp time.Time) {
v.expect()
v.Expected.Set("exp", exp)
}
// SetNotBefore sets the "nbf" claim per
// https://tools.ietf.org/html/rfc7519#section-4.1.5
-func (v *Validator) SetNotBefore(nbf float64) {
+func (v *Validator) SetNotBefore(nbf time.Time) {
v.expect()
v.Expected.Set("nbf", nbf)
}
// SetIssuedAt sets the "iat" claim per
// https://tools.ietf.org/html/rfc7519#section-4.1.6
-func (v *Validator) SetIssuedAt(iat float64) {
+func (v *Validator) SetIssuedAt(iat time.Time) {
v.expect()
v.Expected.Set("iat", iat)
}