-
Notifications
You must be signed in to change notification settings - Fork 399
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: fix refresh token reuse revocation
- Loading branch information
Showing
4 changed files
with
110 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -362,84 +362,110 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() { | |
assert.Equal(ts.T(), http.StatusBadRequest, w.Code) | ||
} | ||
|
||
func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() { | ||
u, err := models.NewUser("", "[email protected]", "password", ts.Config.JWT.Aud, nil) | ||
require.NoError(ts.T(), err, "Error creating test user model") | ||
t := time.Now() | ||
u.EmailConfirmedAt = &t | ||
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving foo user") | ||
func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() { | ||
originalSecurity := ts.API.config.Security | ||
|
||
first, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) | ||
require.NoError(ts.T(), err) | ||
second, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, first) | ||
require.NoError(ts.T(), err) | ||
ts.API.config.Security.RefreshTokenRotationEnabled = true | ||
ts.API.config.Security.RefreshTokenReuseInterval = 0 | ||
|
||
cases := []struct { | ||
desc string | ||
refreshTokenRotationEnabled bool | ||
reuseInterval int | ||
refreshToken string | ||
expectedCode int | ||
expectedBody map[string]interface{} | ||
}{ | ||
{ | ||
desc: "Valid refresh within reuse interval", | ||
refreshTokenRotationEnabled: true, | ||
reuseInterval: 30, | ||
refreshToken: second.Token, | ||
expectedCode: http.StatusOK, | ||
expectedBody: map[string]interface{}{ | ||
"refresh_token": "some-new-refresh-token", | ||
}, | ||
}, | ||
{ | ||
desc: "Invalid refresh outside reuse interval", | ||
refreshTokenRotationEnabled: true, | ||
reuseInterval: 0, | ||
refreshToken: first.Token, | ||
expectedCode: http.StatusBadRequest, | ||
expectedBody: map[string]interface{}{ | ||
"error": "invalid_grant", | ||
"error_description": "Invalid Refresh Token: Already Used", | ||
}, | ||
}, | ||
{ | ||
desc: "Invalid refresh, revoke third token", | ||
refreshTokenRotationEnabled: true, | ||
reuseInterval: 0, | ||
refreshToken: first.Token, | ||
expectedCode: http.StatusBadRequest, | ||
expectedBody: map[string]interface{}{ | ||
"error": "invalid_grant", | ||
"error_description": "Invalid Refresh Token: Already Used", | ||
}, | ||
}, | ||
defer func() { | ||
ts.API.config.Security = originalSecurity | ||
}() | ||
|
||
refreshTokens := []string{ | ||
ts.RefreshToken.Token, | ||
} | ||
|
||
for _, c := range cases { | ||
ts.Run(c.desc, func() { | ||
ts.Config.Security.RefreshTokenRotationEnabled = c.refreshTokenRotationEnabled | ||
ts.Config.Security.RefreshTokenReuseInterval = c.reuseInterval | ||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ | ||
"refresh_token": c.refreshToken, | ||
})) | ||
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) | ||
req.Header.Set("Content-Type", "application/json") | ||
w := httptest.NewRecorder() | ||
ts.API.handler.ServeHTTP(w, req) | ||
assert.Equal(ts.T(), c.expectedCode, w.Code) | ||
|
||
data := make(map[string]interface{}) | ||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) | ||
for k, v := range c.expectedBody { | ||
if k == "refresh_token" { | ||
require.NotEmpty(ts.T(), v, data[k]) | ||
} else { | ||
require.Equal(ts.T(), v, data[k]) | ||
} | ||
} | ||
}) | ||
for i := 0; i < 3; i += 1 { | ||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ | ||
"refresh_token": refreshTokens[len(refreshTokens)-1], | ||
})) | ||
|
||
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
w := httptest.NewRecorder() | ||
ts.API.handler.ServeHTTP(w, req) | ||
|
||
assert.Equal(ts.T(), http.StatusOK, w.Code) | ||
|
||
var response struct { | ||
RefreshToken string `json:"refresh_token"` | ||
} | ||
|
||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) | ||
|
||
refreshTokens = append(refreshTokens, response.RefreshToken) | ||
} | ||
|
||
// ensure that the 4 refresh tokens are setup correctly | ||
for i, refreshToken := range refreshTokens { | ||
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) | ||
require.NoError(ts.T(), err) | ||
|
||
if i == len(refreshTokens)-1 { | ||
require.False(ts.T(), token.Revoked) | ||
} else { | ||
require.True(ts.T(), token.Revoked) | ||
} | ||
} | ||
|
||
// try to reuse the first (earliest) refresh token which should trigger the family revocation logic | ||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ | ||
"refresh_token": refreshTokens[0], | ||
})) | ||
|
||
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
w := httptest.NewRecorder() | ||
ts.API.handler.ServeHTTP(w, req) | ||
|
||
assert.Equal(ts.T(), http.StatusBadRequest, w.Code) | ||
|
||
var response struct { | ||
Error string `json:"error"` | ||
ErrorDescription string `json:"error_description"` | ||
} | ||
|
||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) | ||
require.Equal(ts.T(), response.Error, "invalid_grant") | ||
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used") | ||
|
||
// ensure that the refresh tokens are marked as revoked in the database | ||
for _, refreshToken := range refreshTokens { | ||
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) | ||
require.NoError(ts.T(), err) | ||
|
||
require.True(ts.T(), token.Revoked) | ||
} | ||
|
||
// finally ensure that none of the refresh tokens can be reused any | ||
// more, starting with the previously valid one | ||
for i := len(refreshTokens) - 1; i >= 0; i -= 1 { | ||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ | ||
"refresh_token": refreshTokens[i], | ||
})) | ||
|
||
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
w := httptest.NewRecorder() | ||
ts.API.handler.ServeHTTP(w, req) | ||
|
||
assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i) | ||
|
||
var response struct { | ||
Error string `json:"error"` | ||
ErrorDescription string `json:"error_description"` | ||
} | ||
|
||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) | ||
require.Equal(ts.T(), response.Error, "invalid_grant", "For refresh token %d", i) | ||
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used", "For refresh token %d", i) | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters