diff --git a/cover.html b/cover.html deleted file mode 100644 index 316b917..0000000 --- a/cover.html +++ /dev/null @@ -1,268 +0,0 @@ - - - - - - - - -
- -
- not tracked - - not covered - covered - -
-
-
- -
package jose
-
-import "encoding/base64"
-
-// Encoder is satisfied if the type can marshal itself into a valid
-// structure for a JWS.
-type Encoder interface {
-        // Base64 implies T -> JSON -> RawURLEncodingBase64
-        Base64() ([]byte, error)
-}
-
-// Base64Decode decodes a base64-encoded byte slice.
-func Base64Decode(b []byte) ([]byte, error) {
-        buf := make([]byte, base64.RawURLEncoding.DecodedLen(len(b)))
-        n, err := base64.RawURLEncoding.Decode(buf, b)
-        return buf[:n], err
-}
-
-// Base64Encode encodes a byte slice.
-func Base64Encode(b []byte) []byte {
-        buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b)))
-        base64.RawURLEncoding.Encode(buf, b)
-        return buf
-}
-
-// EncodeEscape base64-encodes a byte slice but escapes it for JSON.
-// It'll return the format: `"base64"`
-func EncodeEscape(b []byte) []byte {
-        buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))+2)
-        buf[0] = '"'
-        base64.RawURLEncoding.Encode(buf[1:], b)
-        buf[len(buf)-1] = '"'
-        return buf
-}
-
-// DecodeEscaped decodes a base64-encoded byte slice straight from a JSON
-// structure. It assumes it's in the format: `"base64"`, but can handle
-// cases where it's not.
-func DecodeEscaped(b []byte) ([]byte, error) {
-        if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
-                b = b[1 : len(b)-1]
-        }
-        return Base64Decode(b)
-}
-
- - - -
- - - diff --git a/crypto/ecdsa.go b/crypto/ecdsa.go index c518c9d..3ef12ba 100644 --- a/crypto/ecdsa.go +++ b/crypto/ecdsa.go @@ -101,7 +101,9 @@ func (m *SigningMethodECDSA) sum(b []byte) []byte { } // Hasher implements the Hasher method from SigningMethod. -func (m *SigningMethodECDSA) Hasher() crypto.Hash { return m.Hash } +func (m *SigningMethodECDSA) Hasher() crypto.Hash { + return m.Hash +} // MarshalJSON is in case somebody decides to place SigningMethodECDSA // inside the Header, presumably because they (wrongly) decided it was a good diff --git a/crypto/none.go b/crypto/none.go index 2b27af4..db3d139 100644 --- a/crypto/none.go +++ b/crypto/none.go @@ -7,24 +7,28 @@ import ( "io" ) -func init() { crypto.RegisterHash(crypto.Hash(0), h) } +func init() { + crypto.RegisterHash(crypto.Hash(0), h) +} // h is passed to crypto.RegisterHash. -func h() hash.Hash { return &f{Writer: nil} } +func h() hash.Hash { + return &f{Writer: nil} +} type f struct{ io.Writer } // Sum helps implement the hash.Hash interface. -func (f *f) Sum(b []byte) []byte { return nil } +func (_ *f) Sum(b []byte) []byte { return nil } // Reset helps implement the hash.Hash interface. -func (f *f) Reset() {} +func (_ *f) Reset() {} // Size helps implement the hash.Hash interface. -func (f *f) Size() int { return -1 } +func (_ *f) Size() int { return -1 } // BlockSize helps implement the hash.Hash interface. -func (f *f) BlockSize() int { return -1 } +func (_ *f) BlockSize() int { return -1 } // Unsecured is the default "none" algorithm. var Unsecured = &SigningMethodNone{ @@ -40,20 +44,24 @@ type SigningMethodNone struct { } // Verify helps implement the SigningMethod interface. -func (m *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error { +func (_ *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error { return nil } // Sign helps implement the SigningMethod interface. -func (m *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) { +func (_ *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) { return nil, nil } // Alg helps implement the SigningMethod interface. -func (m *SigningMethodNone) Alg() string { return m.Name } +func (m *SigningMethodNone) Alg() string { + return m.Name +} // Hasher helps implement the SigningMethod interface. -func (m *SigningMethodNone) Hasher() crypto.Hash { return m.Hash } +func (m *SigningMethodNone) Hasher() crypto.Hash { + return m.Hash +} // MarshalJSON implements json.Marshaler. // See SigningMethodECDSA.MarshalJSON() for information. diff --git a/crypto/rsa_utils.go b/crypto/rsa_utils.go index 350c394..43aeff3 100644 --- a/crypto/rsa_utils.go +++ b/crypto/rsa_utils.go @@ -9,8 +9,9 @@ import ( // Errors specific to rsa_utils. var ( - ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") - ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key") + ErrKeyMustBePEMEncoded = errors.New("invalid key: Key must be PEM encoded PKCS1 or PKCS8 private key") + ErrNotRSAPrivateKey = errors.New("key is not a valid RSA private key") + ErrNotRSAPublicKey = errors.New("key is not a valid RSA public key") ) // ParseRSAPrivateKeyFromPEM parses a PEM encoded PKCS1 or PKCS8 private key. @@ -62,7 +63,7 @@ func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) { var pkey *rsa.PublicKey var ok bool if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { - return nil, ErrNotRSAPrivateKey + return nil, ErrNotRSAPublicKey } return pkey, nil diff --git a/header.go b/header.go index 4bf64c5..4499a76 100644 --- a/header.go +++ b/header.go @@ -32,7 +32,7 @@ func (h Header) Has(key string) bool { // MarshalJSON implements json.Marshaler for Header. func (h Header) MarshalJSON() ([]byte, error) { - if h == nil || len(h) == 0 { + if len(h) == 0 { return nil, nil } b, err := json.Marshal(map[string]interface{}(h)) @@ -52,23 +52,11 @@ func (h *Header) UnmarshalJSON(b []byte) error { if b == nil { return nil } - b, err := DecodeEscaped(b) if err != nil { return err } - - // Since json.Unmarshal calls UnmarshalJSON, - // calling json.Unmarshal on *p would be infinitely recursive - // A temp variable is needed because &map[string]interface{}(*p) is - // invalid Go. - - tmp := map[string]interface{}(*h) - if err = json.Unmarshal(b, &tmp); err != nil { - return err - } - *h = tmp - return nil + return json.Unmarshal(b, (*map[string]interface{})(h)) } // Protected Headers are base64-encoded after they're marshaled into @@ -131,4 +119,6 @@ func (p *Protected) UnmarshalJSON(b []byte) error { var ( _ json.Marshaler = (Protected)(nil) _ json.Unmarshaler = (*Protected)(nil) + _ json.Marshaler = (Header)(nil) + _ json.Unmarshaler = (*Header)(nil) ) diff --git a/jws/claims.go b/jws/claims.go index 068caa8..4cc616c 100644 --- a/jws/claims.go +++ b/jws/claims.go @@ -2,6 +2,7 @@ package jws import ( "encoding/json" + "time" "github.com/SermoDigital/jose" "github.com/SermoDigital/jose/jwt" @@ -84,19 +85,19 @@ func (c Claims) Audience() ([]string, bool) { // Expiration retrieves claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) Expiration() (float64, bool) { +func (c Claims) Expiration() (time.Time, bool) { return jwt.Claims(c).Expiration() } // NotBefore retrieves claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) NotBefore() (float64, bool) { +func (c Claims) NotBefore() (time.Time, bool) { return jwt.Claims(c).NotBefore() } // IssuedAt retrieves claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) IssuedAt() (float64, bool) { +func (c Claims) IssuedAt() (time.Time, bool) { return jwt.Claims(c).IssuedAt() } @@ -161,19 +162,19 @@ func (c Claims) SetAudience(audience ...string) { // SetExpiration sets claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) SetExpiration(expiration float64) { +func (c Claims) SetExpiration(expiration time.Time) { jwt.Claims(c).SetExpiration(expiration) } // SetNotBefore sets claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) SetNotBefore(notBefore float64) { +func (c Claims) SetNotBefore(notBefore time.Time) { jwt.Claims(c).SetNotBefore(notBefore) } // SetIssuedAt sets claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) SetIssuedAt(issuedAt float64) { +func (c Claims) SetIssuedAt(issuedAt time.Time) { jwt.Claims(c).SetIssuedAt(issuedAt) } diff --git a/jws/jws.go b/jws/jws.go index a0b929d..29ae67c 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -16,17 +16,21 @@ type JWS interface { Payload() interface{} // SetPayload sets the payload with the given value. - SetPayload(interface{}) + SetPayload(p interface{}) // Protected returns the JWS' Protected Header. + Protected() jose.Protected + + // ProtectedAt returns the JWS' Protected Header. // i represents the index of the Protected Header. - // Left empty, it defaults to 0. - Protected(...int) jose.Protected + ProtectedAt(i int) jose.Protected // Header returns the JWS' unprotected Header. - // i represents the index of the Protected Header. - // Left empty, it defaults to 0. - Header(...int) jose.Header + Header() jose.Header + + // HeaderAt returns the JWS' unprotected Header. + // i represents the index of the unprotected Header. + HeaderAt(i int) jose.Header // Verify validates the current JWS' signature as-is. Refer to // ValidateMulti for more information. @@ -74,29 +78,36 @@ type jws struct { } // Payload returns the jws' payload. -func (j *jws) Payload() interface{} { return j.payload.v } +func (j *jws) Payload() interface{} { + return j.payload.v +} // SetPayload sets the jws' raw, unexported payload. -func (j *jws) SetPayload(val interface{}) { j.payload.v = val } +func (j *jws) SetPayload(val interface{}) { + j.payload.v = val +} + +// Protected returns the JWS' Protected Header. +func (j *jws) Protected() jose.Protected { + return j.sb[0].protected +} // Protected returns the JWS' Protected Header. // i represents the index of the Protected Header. // Left empty, it defaults to 0. -func (j *jws) Protected(i ...int) jose.Protected { - if len(i) == 0 { - return j.sb[0].protected - } - return j.sb[i[0]].protected +func (j *jws) ProtectedAt(i int) jose.Protected { + return j.sb[i].protected } // Header returns the JWS' unprotected Header. -// i represents the index of the Protected Header. -// Left empty, it defaults to 0. -func (j *jws) Header(i ...int) jose.Header { - if len(i) == 0 { - return j.sb[0].unprotected - } - return j.sb[i[0]].unprotected +func (j *jws) Header() jose.Header { + return j.sb[0].unprotected +} + +// HeaderAt returns the JWS' unprotected Header. +// |i| is the index of the unprotected Header. +func (j *jws) HeaderAt(i int) jose.Header { + return j.sb[i].unprotected } // sigHead represents the 'signatures' member of the jws' "general" @@ -121,10 +132,7 @@ func (s *sigHead) unmarshal() error { if err := s.protected.UnmarshalJSON(s.Protected); err != nil { return err } - if err := s.unprotected.UnmarshalJSON(s.Unprotected); err != nil { - return err - } - return nil + return s.unprotected.UnmarshalJSON(s.Unprotected) } // New creates a JWS with the provided crypto.SigningMethods. @@ -155,7 +163,6 @@ func (s *sigHead) assignMethod(p jose.Protected) error { if sm == nil { return ErrNoAlgorithm } - s.method = sm return nil } @@ -236,10 +243,10 @@ func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) { if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil { return nil, err } - - g.clean = true } + g.clean = len(g.Signatures) != 0 + return &jws{ payload: &p, plcache: g.Payload, @@ -390,11 +397,12 @@ const ( Compact ) -var parseJumpTable = [^Format(0)]func([]byte, ...json.Unmarshaler) (JWS, error){ - Unknown: Parse, - Flat: ParseFlat, - General: ParseGeneral, - Compact: ParseCompact, +var parseJumpTable = [...]func([]byte, ...json.Unmarshaler) (JWS, error){ + Unknown: Parse, + Flat: ParseFlat, + General: ParseGeneral, + Compact: ParseCompact, + 1<<8 - 1: Parse, // Max uint8. } func init() { diff --git a/jws/jws_serialize.go b/jws/jws_serialize.go index 9cb53af..923fdc2 100644 --- a/jws/jws_serialize.go +++ b/jws/jws_serialize.go @@ -100,36 +100,28 @@ func (j *jws) sign(keys ...interface{}) error { } // cache marshals the payload, but only if it's changed since the last cache. -func (j *jws) cache() error { +func (j *jws) cache() (err error) { if !j.clean { - var err error j.plcache, err = j.payload.Base64() j.clean = err == nil - return err } - return nil + return err } // cache marshals the protected and unprotected headers, but only if // they've changed since their last cache. -func (s *sigHead) cache() error { +func (s *sigHead) cache() (err error) { if !s.clean { - var err error - s.Protected, err = s.protected.Base64() if err != nil { - goto err_return + return err } - s.Unprotected, err = s.unprotected.Base64() if err != nil { - goto err_return + return err } - - err_return: - s.clean = err == nil - return err } + s.clean = true return nil } diff --git a/jws/jws_validate.go b/jws/jws_validate.go index 1948880..d064113 100644 --- a/jws/jws_validate.go +++ b/jws/jws_validate.go @@ -27,40 +27,34 @@ func (j *jws) VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, return j.VerifyMulti(keys, methods, o) } -// IsMultiError returns true if the given error is type MultiError. +// IsMultiError returns true if the given error is type *MultiError. func IsMultiError(err error) bool { - _, ok := err.(MultiError) + _, ok := err.(*MultiError) return ok } // MultiError is a slice of errors. type MultiError []error -func (m MultiError) sanityCheck() error { - if m == nil { - return nil - } - return m -} - // Errors implements the error interface. -func (m MultiError) Error() string { - s, n := "", 0 - for _, e := range m { - if e != nil { +func (m *MultiError) Error() string { + var s string + var n int + for _, err := range *m { + if err != nil { if n == 0 { - s = e.Error() + s = err.Error() } n++ } } switch n { case 0: - return "(0 errors)" + return "" case 1: return s case 2: - return s + " (and 1 other error)" + return s + " and 1 other error" } return fmt.Sprintf("%s (and %d other errors)", s, n-1) } @@ -101,7 +95,7 @@ func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o var o2 SigningOpts if o == nil { - o = &SigningOpts{} + o = new(SigningOpts) } var m MultiError @@ -112,15 +106,20 @@ func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o } else { o2.Inc() if o.Needs(i) { + o.ptr++ o2.Append(i) } } } - if err := o.Validate(&o2); err != nil { - return err + err := o.Validate(&o2) + if err != nil { + m = append(m, err) + } + if len(m) == 0 { + return nil } - return m.sanityCheck() + return &m } // SigningOpts is a struct which holds options for validating @@ -148,30 +147,24 @@ type SigningOpts struct { _ struct{} } -// Append appends x to s's Indices member. +// Append appends x to s' Indices member. func (s *SigningOpts) Append(x int) { s.Indices = append(s.Indices, x) } -// Needs returns true if x resides inside s's Indices member -// for the given index. If true, it increments s's internal -// index. It's used to match two SigningOpts Indices members. +// Needs returns true if x resides inside s' Indices member +// for the given index. It's used to match two SigningOpts Indices members. func (s *SigningOpts) Needs(x int) bool { - if s.ptr < len(s.Indices) && - s.Indices[s.ptr] == x { - s.ptr++ - return true - } - return false + return s.ptr < len(s.Indices) && s.Indices[s.ptr] == x } -// Inc increments s's Number member by one. +// Inc increments s' Number member by one. func (s *SigningOpts) Inc() { s.Number++ } // Validate returns any errors found while validating the -// provided SigningOpts. The receiver validates the parameter `have`. +// provided SigningOpts. The receiver validates |have|. // It'll return an error if the passed SigningOpts' Number member is less -// than s's or if the passed SigningOpts' Indices slice isn't equal to s's. +// than s' or if the passed SigningOpts' Indices slice isn't equal to s'. func (s *SigningOpts) Validate(have *SigningOpts) error { if have.Number < s.Number || (s.Indices != nil && @@ -182,10 +175,7 @@ func (s *SigningOpts) Validate(have *SigningOpts) error { } func eq(a, b []int) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil || len(a) != len(b) { + if len(a) != len(b) { return false } for i := range a { diff --git a/jws/jwt.go b/jws/jwt.go index 67f18f7..29b75b8 100644 --- a/jws/jwt.go +++ b/jws/jwt.go @@ -10,7 +10,10 @@ import ( // NewJWT creates a new JWT with the given claims. func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT { - j := New(claims, method).(*jws) + j, ok := New(claims, method).(*jws) + if !ok { + panic("jws.NewJWT: runtime panic: New(...).(*jws) != true") + } j.sb[0].protected.Set("typ", "JWT") j.isJWT = true return j @@ -64,7 +67,9 @@ func ParseJWT(encoded []byte) (jwt.JWT, error) { } // IsJWT returns true if the JWS is a JWT. -func (j *jws) IsJWT() bool { return j.isJWT } +func (j *jws) IsJWT() bool { + return j.isJWT +} func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Validator) error { if j.isJWT { @@ -80,7 +85,7 @@ func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Valida if err := v1.Validate(j); err != nil { return err } - return jwt.Claims(c).Validate(float64(time.Now().Unix()), v1.EXP, v1.NBF) + return jwt.Claims(c).Validate(time.Now(), v1.EXP, v1.NBF) } } return ErrIsNotJWT @@ -96,9 +101,8 @@ func Conv(fn func(Claims) error) jwt.ValidateFunc { } } -// NewValidator returns a pointer to a jwt.Validator structure containing -// the info to be used in the validation of a JWT. -func NewValidator(c Claims, exp, nbf float64, fn func(Claims) error) *jwt.Validator { +// NewValidator returns a jwt.Validator. +func NewValidator(c Claims, exp, nbf time.Duration, fn func(Claims) error) *jwt.Validator { return &jwt.Validator{ Expected: jwt.Claims(c), EXP: exp, diff --git a/jws/jwt_test.go b/jws/jwt_test.go index a8acb10..5ea4a9b 100644 --- a/jws/jwt_test.go +++ b/jws/jwt_test.go @@ -61,11 +61,18 @@ func TestJWTValidator(t *testing.T) { t.Error(err) } - d := float64(time.Now().Add(1 * time.Hour).Unix()) + d := time.Hour fn := func(c Claims) error { + + scopes, ok := c.Get("scopes").([]interface{}) + + if !ok { + return errors.New("Unexpected scopes type. Expected string") + } + if c.Get("name") != "Eric" && c.Get("admin") != true && - c.Get("scopes").([]string)[0] != "user.account.info" { + scopes[0] != "user.account.info" { return errors.New("invalid") } return nil diff --git a/jws/payload.go b/jws/payload.go index 34ba6d8..58bfd06 100644 --- a/jws/payload.go +++ b/jws/payload.go @@ -37,13 +37,11 @@ func (p *payload) UnmarshalJSON(b []byte) error { if err != nil { return err } - if p.u != nil { err := p.u.UnmarshalJSON(b2) p.v = p.u return err } - return json.Unmarshal(b2, &p.v) } diff --git a/jws/signing_methods.go b/jws/signing_methods.go index 1b6665f..525806f 100644 --- a/jws/signing_methods.go +++ b/jws/signing_methods.go @@ -7,7 +7,7 @@ import ( ) var ( - mu = &sync.RWMutex{} + mu sync.RWMutex signingMethods = map[string]crypto.SigningMethod{ crypto.SigningMethodES256.Alg(): crypto.SigningMethodES256, @@ -33,7 +33,8 @@ var ( // RegisterSigningMethod registers the crypto.SigningMethod in the global map. // This is typically done inside the caller's init function. func RegisterSigningMethod(sm crypto.SigningMethod) { - if GetSigningMethod(sm.Alg()) != nil { + alg := sm.Alg() + if GetSigningMethod(alg) != nil { panic("jose/jws: cannot duplicate signing methods") } @@ -42,7 +43,7 @@ func RegisterSigningMethod(sm crypto.SigningMethod) { } mu.Lock() - signingMethods[sm.Alg()] = sm + signingMethods[alg] = sm mu.Unlock() } @@ -54,8 +55,9 @@ func RemoveSigningMethod(sm crypto.SigningMethod) { } // GetSigningMethod retrieves a crypto.SigningMethod from the global map. -func GetSigningMethod(alg string) crypto.SigningMethod { +func GetSigningMethod(alg string) (method crypto.SigningMethod) { mu.RLock() - defer mu.RUnlock() - return signingMethods[alg] + method = signingMethods[alg] + mu.RUnlock() + return method } diff --git a/jwt/claims.go b/jwt/claims.go index cc135d8..ae41e7d 100644 --- a/jwt/claims.go +++ b/jwt/claims.go @@ -2,6 +2,8 @@ package jwt import ( "encoding/json" + "reflect" + "time" "github.com/SermoDigital/jose" ) @@ -12,25 +14,21 @@ type Claims map[string]interface{} // Validate validates the Claims per the claims found in // https://tools.ietf.org/html/rfc7519#section-4.1 -func (c Claims) Validate(now, expLeeway, nbfLeeway float64) error { +func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error { if exp, ok := c.Expiration(); ok { - if !within(exp, expLeeway, now) { + if now.After(exp.Add(expLeeway)) { return ErrTokenIsExpired } } if nbf, ok := c.NotBefore(); ok { - if !within(nbf, nbfLeeway, now) { + if !now.After(nbf.Add(-nbfLeeway)) { return ErrTokenNotYetValid } } return nil } -func within(val, delta, max float64) bool { - return val > max+delta || val > max-delta -} - // Get retrieves the value corresponding with key from the Claims. func (c Claims) Get(key string) interface{} { if c == nil { @@ -63,7 +61,7 @@ func (c Claims) MarshalJSON() ([]byte, error) { return json.Marshal(map[string]interface{}(c)) } -// Base64 implements the Encoder interface. +// Base64 implements the jose.Encoder interface. func (c Claims) Base64() ([]byte, error) { b, err := c.MarshalJSON() if err != nil { @@ -86,7 +84,7 @@ func (c *Claims) UnmarshalJSON(b []byte) error { // Since json.Unmarshal calls UnmarshalJSON, // calling json.Unmarshal on *p would be infinitely recursive // A temp variable is needed because &map[string]interface{}(*p) is - // invalid Go. + // invalid Go. (Address of unaddressable object and all that...) tmp := map[string]interface{}(*c) if err = json.Unmarshal(b, &tmp); err != nil { @@ -113,6 +111,8 @@ func (c Claims) Subject() (string, bool) { // Audience retrieves claim "aud" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.3 func (c Claims) Audience() ([]string, bool) { + // Audience claim must be stringy. That is, it may be one string + // or multiple strings but it should not be anything else. E.g. an int. switch t := c.Get("aud").(type) { case string: return []string{t}, true @@ -144,23 +144,20 @@ func stringify(a ...interface{}) ([]string, bool) { // Expiration retrieves claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) Expiration() (float64, bool) { - v, ok := c.Get("exp").(float64) - return v, ok +func (c Claims) Expiration() (time.Time, bool) { + return c.GetTime("exp") } // NotBefore retrieves claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) NotBefore() (float64, bool) { - v, ok := c.Get("nbf").(float64) - return v, ok +func (c Claims) NotBefore() (time.Time, bool) { + return c.GetTime("nbf") } // IssuedAt retrieves claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) IssuedAt() (float64, bool) { - v, ok := c.Get("iat").(float64) - return v, ok +func (c Claims) IssuedAt() (time.Time, bool) { + return c.GetTime("iat") } // JWTID retrieves claim "jti" per its type in @@ -215,20 +212,20 @@ func (c Claims) SetAudience(audience ...string) { // SetExpiration sets claim "exp" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (c Claims) SetExpiration(expiration float64) { - c.Set("exp", expiration) +func (c Claims) SetExpiration(expiration time.Time) { + c.SetTime("exp", expiration) } // SetNotBefore sets claim "nbf" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (c Claims) SetNotBefore(notBefore float64) { - c.Set("nbf", notBefore) +func (c Claims) SetNotBefore(notBefore time.Time) { + c.SetTime("nbf", notBefore) } // SetIssuedAt sets claim "iat" per its type in // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (c Claims) SetIssuedAt(issuedAt float64) { - c.Set("iat", issuedAt) +func (c Claims) SetIssuedAt(issuedAt time.Time) { + c.SetTime("iat", issuedAt) } // SetJWTID sets claim "jti" per its type in @@ -237,6 +234,41 @@ func (c Claims) SetJWTID(uniqueID string) { c.Set("jti", uniqueID) } +// zero pre-allocs the zero-time value +var zero = time.Time{} + +// GetTime returns a UNIX time for the given key. +// +// It converts an int, int32, int64, uint, uint32, uint64 or float64 value +// into a UNIX time (epoch seconds). float32 does not have sufficient +// precision to store a UNIX time. +// +// Numeric values parsed from JSON will always be stored as float64 since +// Claims is a map[string]interface{}. However, internally the values may be +// stored directly in the claims map as different types. +func (c Claims) GetTime(key string) (time.Time, bool) { + x := c.Get(key) + if x == nil { + return zero, false + } + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + return time.Unix(v.Int(), 0), true + case reflect.Uint, reflect.Uint32, reflect.Uint64: + return time.Unix(int64(v.Uint()), 0), true + case reflect.Float64: + return time.Unix(int64(v.Float()), 0), true + default: + return zero, false + } +} + +// SetTime stores a UNIX time for the given key. +func (c Claims) SetTime(key string, t time.Time) { + c.Set(key, t.Unix()) +} + var ( _ json.Marshaler = (Claims)(nil) _ json.Unmarshaler = (*Claims)(nil) diff --git a/jwt/claims_test.go b/jwt/claims_test.go index c5edd70..b653785 100644 --- a/jwt/claims_test.go +++ b/jwt/claims_test.go @@ -2,9 +2,11 @@ package jwt_test import ( "testing" + "time" "github.com/SermoDigital/jose/crypto" "github.com/SermoDigital/jose/jws" + "github.com/SermoDigital/jose/jwt" ) func TestMultipleAudienceBug_AfterMarshal(t *testing.T) { @@ -83,3 +85,115 @@ func TestSingleAudienceFix_AfterMarshal(t *testing.T) { t.Logf("aud Value: %s", aud) t.Logf("aud Type : %T", aud) } + +func TestValidate(t *testing.T) { + now := time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC) + before, after := now.Add(-time.Minute), now.Add(time.Minute) + leeway := 10 * time.Second + + exp := func(t time.Time) jwt.Claims { + return jwt.Claims{"exp": t.Unix()} + } + nbf := func(t time.Time) jwt.Claims { + return jwt.Claims{"nbf": t.Unix()} + } + + var tests = []struct { + desc string + c jwt.Claims + now time.Time + expLeeway time.Duration + nbfLeeway time.Duration + err error + }{ + // test for nbf < now <= exp + {desc: "exp == nil && nbf == nil", c: jwt.Claims{}, now: now, err: nil}, + + {desc: "now > exp", now: now, c: exp(before), err: jwt.ErrTokenIsExpired}, + {desc: "now = exp", now: now, c: exp(now), err: nil}, + {desc: "now < exp", now: now, c: exp(after), err: nil}, + + {desc: "nbf < now", c: nbf(before), now: now, err: nil}, + {desc: "nbf = now", c: nbf(now), now: now, err: jwt.ErrTokenNotYetValid}, + {desc: "nbf > now", c: nbf(after), now: now, err: jwt.ErrTokenNotYetValid}, + + // test for nbf-x < now <= exp+y + {desc: "now < exp+x", now: now.Add(leeway - time.Second), expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now = exp+x", now: now.Add(leeway), expLeeway: leeway, c: exp(now), err: nil}, + {desc: "now > exp+x", now: now.Add(leeway + time.Second), expLeeway: leeway, c: exp(now), err: jwt.ErrTokenIsExpired}, + + {desc: "nbf-x > now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway + time.Second), err: nil}, + {desc: "nbf-x = now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway), err: jwt.ErrTokenNotYetValid}, + {desc: "nbf-x < now", c: nbf(now), nbfLeeway: leeway, now: now.Add(-leeway - time.Second), err: jwt.ErrTokenNotYetValid}, + } + + for i, tt := range tests { + if got, want := tt.c.Validate(tt.now, tt.expLeeway, tt.nbfLeeway), tt.err; got != want { + t.Errorf("%d - %q: got %v want %v", i, tt.desc, got, want) + } + } +} + +func TestGetAndSetTime(t *testing.T) { + now := time.Now() + nowUnix := now.Unix() + c := jwt.Claims{ + "int": int(nowUnix), + "int32": int32(nowUnix), + "int64": int64(nowUnix), + "uint": uint(nowUnix), + "uint32": uint32(nowUnix), + "uint64": uint64(nowUnix), + "float64": float64(nowUnix), + } + c.SetTime("setTime", now) + for k := range c { + v, ok := c.GetTime(k) + if got, want := v, time.Unix(nowUnix, 0); !ok || !got.Equal(want) { + t.Errorf("%s: got %v want %v", k, got, want) + } + } +} + +// TestTimeValuesThroughJSON verifies that the time values +// that are set via the Set{IssuedAt,NotBefore,Expiration}() +// methods can actually be parsed back +func TestTimeValuesThroughJSON(t *testing.T) { + now := time.Unix(time.Now().Unix(), 0) + + c := jws.Claims{} + c.SetIssuedAt(now) + c.SetNotBefore(now) + c.SetExpiration(now) + + // serialize to JWT + tok := jws.NewJWT(c, crypto.SigningMethodHS256) + b, err := tok.Serialize([]byte("key")) + if err != nil { + t.Fatal(err) + } + + // parse the JWT again + tok2, err := jws.ParseJWT(b) + if err != nil { + t.Fatal(err) + } + c2 := tok2.Claims() + + iat, ok1 := c2.IssuedAt() + nbf, ok2 := c2.NotBefore() + exp, ok3 := c2.Expiration() + if !ok1 || !ok2 || !ok3 { + t.Fatal("got false want true") + } + + if got, want := iat, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "iat", got, want) + } + if got, want := nbf, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "nbf", got, want) + } + if got, want := exp, now; !got.Equal(want) { + t.Errorf("%s: got %v want %v", "exp", got, want) + } +} diff --git a/jwt/eq.go b/jwt/eq.go index a7a37a9..3113269 100644 --- a/jwt/eq.go +++ b/jwt/eq.go @@ -1,52 +1,47 @@ package jwt -import "reflect" +func verifyPrincipals(pcpls, auds []string) bool { + // "Each principal intended to process the JWT MUST + // identify itself with a value in the audience claim." + // - https://tools.ietf.org/html/rfc7519#section-4.1.3 -// eq returns true if the two types are either strings -// or comparable slices. -func eq(a, b interface{}) bool { - t1 := reflect.TypeOf(a) - t2 := reflect.TypeOf(b) - - if t1.Kind() == t2.Kind() { - switch t1.Kind() { - case reflect.Slice: - return eqSlice(a, b) - case reflect.String: - return reflect.ValueOf(a).String() == - reflect.ValueOf(b).String() + found := -1 + for i, p := range pcpls { + for _, v := range auds { + if p == v { + found++ + break + } + } + if found != i { + return false } } - return false + return true } -// eqSlice returns true if the two interfaces are both slices -// and are equal. For example: https://play.golang.org/p/5VLMwNE3i- -func eqSlice(a, b interface{}) bool { - if a == nil || b == nil { - return false +// ValidAudience returns true iff: +// - a and b are strings and a == b +// - a is string, b is []string and a is in b +// - a is []string, b is []string and all of a is in b +// - a is []string, b is string and len(a) == 1 and a[0] == b +func ValidAudience(a, b interface{}) bool { + s1, ok := a.(string) + if ok { + if s2, ok := b.(string); ok { + return s1 == s2 + } + a2, ok := b.([]string) + return ok && verifyPrincipals([]string{s1}, a2) } - v1 := reflect.ValueOf(a) - v2 := reflect.ValueOf(b) - - if v1.Kind() != reflect.Slice || - v2.Kind() != reflect.Slice { + a1, ok := a.([]string) + if !ok { return false } - - if v1.Len() == v2.Len() && v1.Len() > 0 { - for i := 0; i < v1.Len() && i < v2.Len(); i++ { - k1 := v1.Index(i) - k2 := v2.Index(i) - if k1.Type().Comparable() && - k2.Type().Comparable() && - k1.CanInterface() && k2.CanInterface() && - k1.Interface() != k2.Interface() { - return false - } - } - return true + if a2, ok := b.([]string); ok { + return verifyPrincipals(a1, a2) } - return false + s2, ok := b.(string) + return ok && len(a1) == 1 && a1[0] == s2 } diff --git a/jwt/eq_test.go b/jwt/eq_test.go new file mode 100644 index 0000000..5f9d4fd --- /dev/null +++ b/jwt/eq_test.go @@ -0,0 +1,26 @@ +package jwt_test + +import ( + "testing" + + "github.com/SermoDigital/jose/jwt" +) + +func TestValidAudience(t *testing.T) { + tests := [...]struct { + a interface{} + b interface{} + v bool + }{ + 0: {"https://www.google.com", "https://www.google.com", true}, + 1: {[]string{"example.com", "google.com"}, []string{"example.com"}, false}, + 2: {500, 43, false}, + 3: {"google.com", "facebook.com", false}, + 4: {[]string{"example.com"}, []string{"example.com", "foo.com"}, true}, + } + for i, v := range tests { + if x := jwt.ValidAudience(v.a, v.b); x != v.v { + t.Fatalf("#%d: wanted %t, got %t", i, v.v, x) + } + } +} diff --git a/jwt/jwt.go b/jwt/jwt.go index bd84259..d29c43a 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -1,6 +1,10 @@ package jwt -import "github.com/SermoDigital/jose/crypto" +import ( + "time" + + "github.com/SermoDigital/jose/crypto" +) // JWT represents a JWT per RFC 7519. // It's described as an interface instead of a physical structure @@ -33,18 +37,12 @@ type ValidateFunc func(Claims) error // Validator represents some of the validation options. type Validator struct { - Expected Claims // If non-nil, these are required to match. - EXP float64 // EXPLeeway - NBF float64 // NBFLeeway - Fn ValidateFunc // See ValidateFunc for more information. - - _ struct{} -} + Expected Claims // If non-nil, these are required to match. + EXP time.Duration // EXPLeeway + NBF time.Duration // NBFLeeway + Fn ValidateFunc // See ValidateFunc for more information. -var defaultClaims = []string{ - "iss", "sub", "aud", - "exp", "nbf", "iat", - "jti", + _ struct{} // Require explicitly-named struct fields. } // Validate validates the JWT based on the expected claims in v. @@ -71,7 +69,8 @@ func (v *Validator) Validate(j JWT) error { } if aud, ok := v.Expected.Audience(); ok { - if aud2, _ := j.Claims().Audience(); !eq(aud, aud2){ + aud2, ok := j.Claims().Audience() + if !ok || !ValidAudience(aud, aud2) { return ErrInvalidAUDClaim } } @@ -111,21 +110,21 @@ func (v *Validator) SetAudience(aud string) { // SetExpiration sets the "exp" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.4 -func (v *Validator) SetExpiration(exp float64) { +func (v *Validator) SetExpiration(exp time.Time) { v.expect() v.Expected.Set("exp", exp) } // SetNotBefore sets the "nbf" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.5 -func (v *Validator) SetNotBefore(nbf float64) { +func (v *Validator) SetNotBefore(nbf time.Time) { v.expect() v.Expected.Set("nbf", nbf) } // SetIssuedAt sets the "iat" claim per // https://tools.ietf.org/html/rfc7519#section-4.1.6 -func (v *Validator) SetIssuedAt(iat float64) { +func (v *Validator) SetIssuedAt(iat time.Time) { v.expect() v.Expected.Set("iat", iat) }