Skip to content

Commit

Permalink
feat: fix refresh token reuse revocation
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Nov 20, 2023
1 parent 8565d26 commit 494b921
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 76 deletions.
4 changes: 2 additions & 2 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h

if token.Revoked {
activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx)
if terr != nil {
if terr != nil && !models.IsNotFoundError(terr) {
return internalServerError(terr.Error())
}

Expand Down Expand Up @@ -199,7 +199,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}
}

return oauthError("invalid_grant", "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)
return storage.NewCommitWithError(oauthError("invalid_grant", "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID))
}
}
}
Expand Down
174 changes: 100 additions & 74 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,84 +362,110 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
u, err := models.NewUser("", "[email protected]", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving foo user")
func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
originalSecurity := ts.API.config.Security

first, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
require.NoError(ts.T(), err)
second, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, first)
require.NoError(ts.T(), err)
ts.API.config.Security.RefreshTokenRotationEnabled = true
ts.API.config.Security.RefreshTokenReuseInterval = 0

cases := []struct {
desc string
refreshTokenRotationEnabled bool
reuseInterval int
refreshToken string
expectedCode int
expectedBody map[string]interface{}
}{
{
desc: "Valid refresh within reuse interval",
refreshTokenRotationEnabled: true,
reuseInterval: 30,
refreshToken: second.Token,
expectedCode: http.StatusOK,
expectedBody: map[string]interface{}{
"refresh_token": "some-new-refresh-token",
},
},
{
desc: "Invalid refresh outside reuse interval",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: first.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token: Already Used",
},
},
{
desc: "Invalid refresh, revoke third token",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: first.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token: Already Used",
},
},
defer func() {
ts.API.config.Security = originalSecurity
}()

refreshTokens := []string{
ts.RefreshToken.Token,
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.Security.RefreshTokenRotationEnabled = c.refreshTokenRotationEnabled
ts.Config.Security.RefreshTokenReuseInterval = c.reuseInterval
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": c.refreshToken,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), c.expectedCode, w.Code)

data := make(map[string]interface{})
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
for k, v := range c.expectedBody {
if k == "refresh_token" {
require.NotEmpty(ts.T(), v, data[k])
} else {
require.Equal(ts.T(), v, data[k])
}
}
})
for i := 0; i < 3; i += 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[len(refreshTokens)-1],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusOK, w.Code)

var response struct {
RefreshToken string `json:"refresh_token"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))

refreshTokens = append(refreshTokens, response.RefreshToken)
}

// ensure that the 4 refresh tokens are setup correctly
for i, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)

if i == len(refreshTokens)-1 {
require.False(ts.T(), token.Revoked)
} else {
require.True(ts.T(), token.Revoked)
}
}

// try to reuse the first (earliest) refresh token which should trigger the family revocation logic
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[0],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusBadRequest, w.Code)

var response struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), response.Error, "invalid_grant")
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used")

// ensure that the refresh tokens are marked as revoked in the database
for _, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)

require.True(ts.T(), token.Revoked)
}

// finally ensure that none of the refresh tokens can be reused any
// more, starting with the previously valid one
for i := len(refreshTokens) - 1; i >= 0; i -= 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[i],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i)

var response struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), response.Error, "invalid_grant", "For refresh token %d", i)
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used", "For refresh token %d", i)
}
}

Expand Down
4 changes: 4 additions & 0 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error {
update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec()
}
if err != nil {
if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) {
return nil
}

return err
}
return nil
Expand Down
4 changes: 4 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*Refr
var activeRefreshToken RefreshToken

if err := tx.Q().Where("session_id = ? and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil {
if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) {
return nil, RefreshTokenNotFoundError{}
}

return nil, err
}

Expand Down

0 comments on commit 494b921

Please sign in to comment.