Skip to content

Commit

Permalink
feat: add inactivity-timeout to sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Oct 31, 2023
1 parent 2385212 commit 72ca691
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 2 deletions.
10 changes: 10 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ type API struct {
db *storage.Connection
config *conf.GlobalConfiguration
version string

overrideTime func() time.Time
}

func (a *API) Now() time.Time {
if a.overrideTime != nil {
return a.overrideTime()
}

return time.Now()
}

// NewAPI instantiates a new REST API
Expand Down
2 changes: 2 additions & 0 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var grantParams models.GrantParams
var err error

grantParams.FromRequest(r)

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
Expand Down
2 changes: 2 additions & 0 deletions internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {

var grantParams models.GrantParams

grantParams.FromRequest(r)

if !notAfter.IsZero() {
grantParams.SessionNotAfter = &notAfter
}
Expand Down
3 changes: 3 additions & 0 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {

var user *models.User
var grantParams models.GrantParams

grantParams.FromRequest(r)

params.Aud = a.requestAud(ctx, r)

switch params.Provider {
Expand Down
10 changes: 9 additions & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
var user *models.User
var grantParams models.GrantParams
var provider string

grantParams.FromRequest(r)

if params.Email != "" {
provider = "email"
if !config.External.Email.Enabled {
Expand Down Expand Up @@ -180,6 +183,12 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
db := a.db.WithContext(ctx)
var grantParams models.GrantParams

// There is a slight problem with this as it will pick-up the
// User-Agent and IP addresses from the server if used on the server
// side. Currently there's no mechanism to distinguish, but the server
// can be told to at least propagate the User-Agent header.
grantParams.FromRequest(r)

params := &PKCEGrantParams{}
body, err := getBodyBytes(r)
if err != nil {
Expand Down Expand Up @@ -245,7 +254,6 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
}

return sendJSON(w, http.StatusOK, token)

}

func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *uuid.UUID, config *conf.JWTConfiguration) (string, int64, error) {
Expand Down
2 changes: 2 additions & 0 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
var token *AccessTokenResponse
var grantParams models.GrantParams

grantParams.FromRequest(r)

if err := db.Transaction(func(tx *storage.Connection) error {
var user *models.User
var terr error
Expand Down
28 changes: 28 additions & 0 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/supabase/gotrue/internal/metering"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
"github.com/supabase/gotrue/internal/utilities"
)

const retryLoopDuration = 5.0
Expand Down Expand Up @@ -72,6 +73,12 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
if !notAfter.IsZero() && time.Now().UTC().After(notAfter) {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired")
}

if config.Sessions.InactivityTimeout != nil {
if a.Now().Sub(session.LastRefreshedAt()) >= *config.Sessions.InactivityTimeout {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")
}
}
}

// Basic checks above passed, now we need to serialize access
Expand Down Expand Up @@ -164,6 +171,27 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError("error generating jwt token").WithInternalError(terr)
}

refreshedAt := a.Now()
session.RefreshedAt = &refreshedAt

userAgent := r.Header.Get("User-Agent")
if userAgent != "" {
session.UserAgent = &userAgent
} else {
session.UserAgent = nil
}

ipAddress := utilities.GetIPAddress(r)
if ipAddress != "" {
session.IP = &ipAddress
} else {
session.IP = nil
}

if terr := session.UpdateRefresh(tx); terr != nil {
return internalServerError("failed to update session information").WithInternalError(terr)
}

newTokenResponse = &AccessTokenResponse{
Token: tokenString,
TokenType: "bearer",
Expand Down
5 changes: 5 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ func (a *APIConfiguration) Validate() error {
return nil
}

type SessionsConfiguration struct {
InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"`
}

// GlobalConfiguration holds all the configuration that applies to all instances.
type GlobalConfiguration struct {
API APIConfiguration
Expand Down Expand Up @@ -139,6 +143,7 @@ type GlobalConfiguration struct {
DisableSignup bool `json:"disable_signup" split_words:"true"`
Webhook WebhookConfig `json:"webhook" split_words:"true"`
Security SecurityConfiguration `json:"security"`
Sessions SessionsConfiguration `json:"sessions"`
MFA MFAConfiguration `json:"MFA"`
Cookie struct {
Key string `json:"key"`
Expand Down
18 changes: 17 additions & 1 deletion internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"github.com/supabase/gotrue/internal/crypto"
"github.com/supabase/gotrue/internal/storage"
"github.com/supabase/gotrue/internal/utilities"
)

// RefreshToken is the database model for refresh tokens.
Expand Down Expand Up @@ -41,6 +42,14 @@ type GrantParams struct {
FactorID *uuid.UUID

SessionNotAfter *time.Time

UserAgent string
IP string
}

func (g *GrantParams) FromRequest(r *http.Request) {
g.UserAgent = r.Header.Get("User-Agent")
g.IP = utilities.GetIPAddress(r)
}

// GrantAuthenticatedUser creates a refresh token for the provided user.
Expand Down Expand Up @@ -110,7 +119,6 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
if oldToken != nil {
token.Parent = storage.NullString(oldToken.Token)
token.SessionId = oldToken.SessionId

}

if token.SessionId == nil {
Expand All @@ -129,6 +137,14 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
session.NotAfter = params.SessionNotAfter
}

if params.UserAgent != "" {
session.UserAgent = &params.UserAgent
}

if params.IP != "" {
session.IP = &params.IP
}

if err := tx.Create(session); err != nil {
return nil, errors.Wrap(err, "error creating new session")
}
Expand Down
18 changes: 18 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,31 @@ type Session struct {
FactorID *uuid.UUID `json:"factor_id" db:"factor_id"`
AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"`
AAL *string `json:"aal" db:"aal"`

RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"`
UserAgent *string `json:"user_agent,omitempty" db:"user_agent"`
IP *string `json:"ip,omitempty" db:"ip"`
}

func (Session) TableName() string {
tableName := "sessions"
return tableName
}

func (s *Session) LastRefreshedAt() time.Time {
refreshedAt := s.RefreshedAt

if refreshedAt == nil || refreshedAt.IsZero() {
return s.CreatedAt
}

return *refreshedAt
}

func (s *Session) UpdateRefresh(tx *storage.Connection) error {
return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip")
}

func NewSession() (*Session, error) {
id := uuid.Must(uuid.NewV4())

Expand Down
4 changes: 4 additions & 0 deletions migrations/20231027141322_add_session_refresh_columns.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
alter table if exists {{ index .Options "Namespace" }}.sessions
add column if not exists refreshed_at timestamp without time zone,
add column if not exists user_agent text,
add column if not exists ip inet;

0 comments on commit 72ca691

Please sign in to comment.