Skip to content

Commit

Permalink
allow multiple audiences
Browse files Browse the repository at this point in the history
  • Loading branch information
gruberlu committed Dec 17, 2024
1 parent bc8bdca commit 48c6492
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
8 changes: 4 additions & 4 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ func WithExpirationRequired() ParserOption {
}
}

// WithAudience configures the validator to require the specified audience in
// the `aud` claim. Validation will fail if the audience is not listed in the
// token or the `aud` claim is missing.
// WithAudience configures the validator to require ONE of the specified
// audiences to be present in the `aud` claim. Validation will fail if none of
// the audiences is listed in the token or the `aud` claim is missing.
//
// NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim,
// if an audience is expected.
func WithAudience(aud string) ParserOption {
func WithAudience(aud ...string) ParserOption {
return func(p *Parser) {
p.validator.expectedAud = aud
}
Expand Down
14 changes: 8 additions & 6 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type Validator struct {

// expectedAud contains the audience this token expects. Supplying an empty
// string will disable aud checking.
expectedAud string
expectedAud []string

// expectedIss contains the issuer this token expects. Supplying an empty
// string will disable iss checking.
Expand Down Expand Up @@ -120,7 +120,7 @@ func (v *Validator) Validate(claims Claims) error {
}

// If we have an expected audience, we also require the audience claim
if v.expectedAud != "" {
if len(v.expectedAud) > 0 {
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
errs = append(errs, err)
}
Expand Down Expand Up @@ -226,7 +226,7 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool)
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error {
func (v *Validator) verifyAudience(claims Claims, cmp []string, required bool) error {
aud, err := claims.GetAudience()
if err != nil {
return err
Expand All @@ -241,10 +241,12 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err

var stringClaims string
for _, a := range aud {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
result = true
for _, c := range cmp {
if subtle.ConstantTimeCompare([]byte(a), []byte(c)) != 0 {
result = true
}
stringClaims = stringClaims + a
}
stringClaims = stringClaims + a
}

// case where "" is sent in one or many aud claims
Expand Down
49 changes: 48 additions & 1 deletion validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func Test_Validator_Validate(t *testing.T) {
leeway time.Duration
timeFunc func() time.Time
verifyIat bool
expectedAud string
expectedAud []string
expectedIss string
expectedSub string
}
Expand Down Expand Up @@ -259,3 +259,50 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
})
}
}

func Test_Validator_verifyAudience(t *testing.T) {
type fields struct {
expectedAud []string
}
type args struct {
claims Claims
cmp []string
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "single value in aud claim",
fields: fields{expectedAud: []string{"me", "you"}},
args: args{claims: MapClaims{"aud": "me"}, cmp: []string{"me"}},
wantErr: nil,
},
{
name: "multiple values in aud claim",
fields: fields{expectedAud: []string{"me"}},
args: args{claims: MapClaims{"aud": []string{"me", "you"}}, cmp: []string{"me"}},
wantErr: nil,
},
{
name: "claims with invalid audience",
fields: fields{expectedAud: []string{"me"}},
args: args{claims: MapClaims{"aud": "you"}, cmp: []string{"me"}},
wantErr: ErrTokenInvalidAudience,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
expectedAud: tt.fields.expectedAud,
}
err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

0 comments on commit 48c6492

Please sign in to comment.