Skip to content

Commit

Permalink
feat: add inactivity-timeout to sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Nov 6, 2023
1 parent 9a1f461 commit 79330c2
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 5 deletions.
2 changes: 2 additions & 0 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var grantParams models.GrantParams
var err error

grantParams.FillGrantParams(r)

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
Expand Down
2 changes: 2 additions & 0 deletions internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {

var grantParams models.GrantParams

grantParams.FillGrantParams(r)

if !notAfter.IsZero() {
grantParams.SessionNotAfter = &notAfter
}
Expand Down
3 changes: 3 additions & 0 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {

var user *models.User
var grantParams models.GrantParams

grantParams.FillGrantParams(r)

params.Aud = a.requestAud(ctx, r)

switch params.Provider {
Expand Down
10 changes: 9 additions & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
var user *models.User
var grantParams models.GrantParams
var provider string

grantParams.FillGrantParams(r)

if params.Email != "" {
provider = "email"
if !config.External.Email.Enabled {
Expand Down Expand Up @@ -180,6 +183,12 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
db := a.db.WithContext(ctx)
var grantParams models.GrantParams

// There is a slight problem with this as it will pick-up the
// User-Agent and IP addresses from the server if used on the server
// side. Currently there's no mechanism to distinguish, but the server
// can be told to at least propagate the User-Agent header.
grantParams.FillGrantParams(r)

params := &PKCEGrantParams{}
body, err := getBodyBytes(r)
if err != nil {
Expand Down Expand Up @@ -245,7 +254,6 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
}

return sendJSON(w, http.StatusOK, token)

}

func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *uuid.UUID, config *conf.JWTConfiguration) (string, int64, error) {
Expand Down
2 changes: 2 additions & 0 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
var token *AccessTokenResponse
var grantParams models.GrantParams

grantParams.FillGrantParams(r)

if err := db.Transaction(func(tx *storage.Connection) error {
var user *models.User
var terr error
Expand Down
32 changes: 31 additions & 1 deletion internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/supabase/gotrue/internal/metering"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
"github.com/supabase/gotrue/internal/utilities"
)

const retryLoopDuration = 5.0
Expand Down Expand Up @@ -50,7 +51,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
for retry && time.Since(retryStart).Seconds() < retryLoopDuration {
retry = false

user, _, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false)
user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false)
if err != nil {
if models.IsNotFoundError(err) {
return oauthError("invalid_grant", "Invalid Refresh Token: Refresh Token Not Found")
Expand Down Expand Up @@ -80,6 +81,14 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
if !notAfter.IsZero() && a.Now().After(notAfter) {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired")
}

if config.Sessions.InactivityTimeout != nil {
timesOutAt := session.LastRefreshedAt(&token.UpdatedAt).Add(*config.Sessions.InactivityTimeout)

if timesOutAt.Before(a.Now()) {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")
}
}
}

// Basic checks above passed, now we need to serialize access
Expand Down Expand Up @@ -172,6 +181,27 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError("error generating jwt token").WithInternalError(terr)
}

refreshedAt := a.Now()
session.RefreshedAt = &refreshedAt

userAgent := r.Header.Get("User-Agent")
if userAgent != "" {
session.UserAgent = &userAgent
} else {
session.UserAgent = nil
}

ipAddress := utilities.GetIPAddress(r)
if ipAddress != "" {
session.IP = &ipAddress
} else {
session.IP = nil
}

if terr := session.UpdateRefresh(tx); terr != nil {
return internalServerError("failed to update session information").WithInternalError(terr)
}

newTokenResponse = &AccessTokenResponse{
Token: tokenString,
TokenType: "bearer",
Expand Down
34 changes: 34 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,40 @@ func (ts *TokenTestSuite) TestSessionTimebox() {
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.ErrorDescription)
}

func (ts *TokenTestSuite) TestSessionInactivityTimeout() {
inactivityTimeout := 10 * time.Second

ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout
ts.API.overrideTime = func() time.Time {
return time.Now().Add(inactivityTimeout).Add(time.Second)
}

defer func() {
ts.API.overrideTime = nil
}()

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)

var firstResult struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), "invalid_grant", firstResult.Error)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.ErrorDescription)
}

func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() {
var buffer bytes.Buffer

Expand Down
5 changes: 5 additions & 0 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
token *AccessTokenResponse
authCode string
)

grantParams.FillGrantParams(r)

flowType := models.ImplicitFlow
var authenticationMethod models.AuthenticationMethod
if strings.HasPrefix(params.Token, PKCEPrefix) {
Expand Down Expand Up @@ -233,6 +236,8 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP
)
var isSingleConfirmationResponse = false

grantParams.FillGrantParams(r)

err := db.Transaction(func(tx *storage.Connection) error {
var terr error
aud := a.requestAud(ctx, r)
Expand Down
3 changes: 2 additions & 1 deletion internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ func (a *APIConfiguration) Validate() error {
}

type SessionsConfiguration struct {
Timebox *time.Duration `json:"timebox"`
Timebox *time.Duration `json:"timebox"`
InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"`
}

func (c *SessionsConfiguration) Validate() error {
Expand Down
18 changes: 17 additions & 1 deletion internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"github.com/supabase/gotrue/internal/crypto"
"github.com/supabase/gotrue/internal/storage"
"github.com/supabase/gotrue/internal/utilities"
)

// RefreshToken is the database model for refresh tokens.
Expand Down Expand Up @@ -41,6 +42,14 @@ type GrantParams struct {
FactorID *uuid.UUID

SessionNotAfter *time.Time

UserAgent string
IP string
}

func (g *GrantParams) FillGrantParams(r *http.Request) {
g.UserAgent = r.Header.Get("User-Agent")
g.IP = utilities.GetIPAddress(r)
}

// GrantAuthenticatedUser creates a refresh token for the provided user.
Expand Down Expand Up @@ -110,7 +119,6 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
if oldToken != nil {
token.Parent = storage.NullString(oldToken.Token)
token.SessionId = oldToken.SessionId

}

if token.SessionId == nil {
Expand All @@ -129,6 +137,14 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
session.NotAfter = params.SessionNotAfter
}

if params.UserAgent != "" {
session.UserAgent = &params.UserAgent
}

if params.IP != "" {
session.IP = &params.IP
}

if err := tx.Create(session); err != nil {
return nil, errors.Wrap(err, "error creating new session")
}
Expand Down
28 changes: 28 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,41 @@ type Session struct {
FactorID *uuid.UUID `json:"factor_id" db:"factor_id"`
AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"`
AAL *string `json:"aal" db:"aal"`

RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"`
UserAgent *string `json:"user_agent,omitempty" db:"user_agent"`
IP *string `json:"ip,omitempty" db:"ip"`
}

func (Session) TableName() string {
tableName := "sessions"
return tableName
}

func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time {
refreshedAt := s.RefreshedAt

if refreshedAt == nil || refreshedAt.IsZero() {
if refreshTokenTime != nil {
rtt := *refreshTokenTime

if rtt.IsZero() {
return s.CreatedAt
} else if rtt.After(s.CreatedAt) {
return rtt
}
}

return s.CreatedAt
}

return *refreshedAt
}

func (s *Session) UpdateRefresh(tx *storage.Connection) error {
return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip")
}

func NewSession() (*Session, error) {
id := uuid.Must(uuid.NewV4())

Expand Down
7 changes: 6 additions & 1 deletion internal/utilities/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ func GetIPAddress(r *http.Request) string {

for _, ip := range ips {
if ip != "" {
return ip
parsed := net.ParseIP(ip)
if parsed == nil {
continue
}

return parsed.String()
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions migrations/20231027141322_add_session_refresh_columns.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
alter table if exists {{ index .Options "Namespace" }}.sessions
add column if not exists refreshed_at timestamp without time zone,
add column if not exists user_agent text,
add column if not exists ip inet;

0 comments on commit 79330c2

Please sign in to comment.