diff --git a/internal/api/external.go b/internal/api/external.go index 948b0f03cf..8b1a89ac6c 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -156,6 +156,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re var grantParams models.GrantParams var err error + grantParams.FillGrantParams(r) + if providerType == "twitter" { // future OAuth1.0 providers will use this method oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index f443339018..fa561798d7 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -262,6 +262,8 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { var grantParams models.GrantParams + grantParams.FillGrantParams(r) + if !notAfter.IsZero() { grantParams.SessionNotAfter = ¬After } diff --git a/internal/api/signup.go b/internal/api/signup.go index efe30139a3..62097302da 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -108,6 +108,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { var user *models.User var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + params.Aud = a.requestAud(ctx, r) switch params.Provider { diff --git a/internal/api/token.go b/internal/api/token.go index afbf6057e6..30f82c37ab 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -113,6 +113,9 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri var user *models.User var grantParams models.GrantParams var provider string + + grantParams.FillGrantParams(r) + if params.Email != "" { provider = "email" if !config.External.Email.Enabled { @@ -180,6 +183,12 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) db := a.db.WithContext(ctx) var grantParams models.GrantParams + // There is a slight problem with this as it will pick-up the + // User-Agent and IP addresses from the server if used on the server + // side. Currently there's no mechanism to distinguish, but the server + // can be told to at least propagate the User-Agent header. + grantParams.FillGrantParams(r) + params := &PKCEGrantParams{} body, err := getBodyBytes(r) if err != nil { @@ -245,7 +254,6 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) } return sendJSON(w, http.StatusOK, token) - } func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *uuid.UUID, config *conf.JWTConfiguration) (string, int64, error) { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 3c8553f468..dd00ea99bd 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -193,6 +193,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R var token *AccessTokenResponse var grantParams models.GrantParams + grantParams.FillGrantParams(r) + if err := db.Transaction(func(tx *storage.Connection) error { var user *models.User var terr error diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 26094efa18..1cd7bbe0a8 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -10,6 +10,7 @@ import ( "github.com/supabase/gotrue/internal/metering" "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" + "github.com/supabase/gotrue/internal/utilities" ) const retryLoopDuration = 5.0 @@ -50,7 +51,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h for retry && time.Since(retryStart).Seconds() < retryLoopDuration { retry = false - user, _, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) + user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) if err != nil { if models.IsNotFoundError(err) { return oauthError("invalid_grant", "Invalid Refresh Token: Refresh Token Not Found") @@ -80,6 +81,14 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h if !notAfter.IsZero() && a.Now().After(notAfter) { return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired") } + + if config.Sessions.InactivityTimeout != nil { + timesOutAt := session.LastRefreshedAt(&token.UpdatedAt).Add(*config.Sessions.InactivityTimeout) + + if timesOutAt.Before(a.Now()) { + return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)") + } + } } // Basic checks above passed, now we need to serialize access @@ -172,6 +181,27 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return internalServerError("error generating jwt token").WithInternalError(terr) } + refreshedAt := a.Now() + session.RefreshedAt = &refreshedAt + + userAgent := r.Header.Get("User-Agent") + if userAgent != "" { + session.UserAgent = &userAgent + } else { + session.UserAgent = nil + } + + ipAddress := utilities.GetIPAddress(r) + if ipAddress != "" { + session.IP = &ipAddress + } else { + session.IP = nil + } + + if terr := session.UpdateRefresh(tx); terr != nil { + return internalServerError("failed to update session information").WithInternalError(terr) + } + newTokenResponse = &AccessTokenResponse{ Token: tokenString, TokenType: "bearer", diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 654920bb09..5909602297 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -94,6 +94,40 @@ func (ts *TokenTestSuite) TestSessionTimebox() { assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.ErrorDescription) } +func (ts *TokenTestSuite) TestSessionInactivityTimeout() { + inactivityTimeout := 10 * time.Second + + ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout + ts.API.overrideTime = func() time.Time { + return time.Now().Add(inactivityTimeout).Add(time.Second) + } + + defer func() { + ts.API.overrideTime = nil + }() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.ErrorDescription) +} + func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() { var buffer bytes.Buffer diff --git a/internal/api/verify.go b/internal/api/verify.go index 30c20618f5..dc5a5e0d4b 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -133,6 +133,9 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa token *AccessTokenResponse authCode string ) + + grantParams.FillGrantParams(r) + flowType := models.ImplicitFlow var authenticationMethod models.AuthenticationMethod if strings.HasPrefix(params.Token, PKCEPrefix) { @@ -228,6 +231,8 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP ) var isSingleConfirmationResponse = false + grantParams.FillGrantParams(r) + err := db.Transaction(func(tx *storage.Connection) error { var terr error aud := a.requestAud(ctx, r) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 098ca89839..628f5b6ff2 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -112,7 +112,8 @@ func (a *APIConfiguration) Validate() error { } type SessionsConfiguration struct { - Timebox *time.Duration `json:"timebox"` + Timebox *time.Duration `json:"timebox"` + InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"` } func (c *SessionsConfiguration) Validate() error { diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go index 3a15b0b0ea..1707a9bb2f 100644 --- a/internal/models/refresh_token.go +++ b/internal/models/refresh_token.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/supabase/gotrue/internal/crypto" "github.com/supabase/gotrue/internal/storage" + "github.com/supabase/gotrue/internal/utilities" ) // RefreshToken is the database model for refresh tokens. @@ -41,6 +42,14 @@ type GrantParams struct { FactorID *uuid.UUID SessionNotAfter *time.Time + + UserAgent string + IP string +} + +func (g *GrantParams) FillGrantParams(r *http.Request) { + g.UserAgent = r.Header.Get("User-Agent") + g.IP = utilities.GetIPAddress(r) } // GrantAuthenticatedUser creates a refresh token for the provided user. @@ -110,7 +119,6 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok if oldToken != nil { token.Parent = storage.NullString(oldToken.Token) token.SessionId = oldToken.SessionId - } if token.SessionId == nil { @@ -129,6 +137,14 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok session.NotAfter = params.SessionNotAfter } + if params.UserAgent != "" { + session.UserAgent = ¶ms.UserAgent + } + + if params.IP != "" { + session.IP = ¶ms.IP + } + if err := tx.Create(session); err != nil { return nil, errors.Wrap(err, "error creating new session") } diff --git a/internal/models/sessions.go b/internal/models/sessions.go index 6007562c41..4090e1d99a 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -69,6 +69,10 @@ type Session struct { FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` AAL *string `json:"aal" db:"aal"` + + RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"` + UserAgent *string `json:"user_agent,omitempty" db:"user_agent"` + IP *string `json:"ip,omitempty" db:"ip"` } func (Session) TableName() string { @@ -76,6 +80,30 @@ func (Session) TableName() string { return tableName } +func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time { + refreshedAt := s.RefreshedAt + + if refreshedAt == nil || refreshedAt.IsZero() { + if refreshTokenTime != nil { + rtt := *refreshTokenTime + + if rtt.IsZero() { + return s.CreatedAt + } else if rtt.After(s.CreatedAt) { + return rtt + } + } + + return s.CreatedAt + } + + return *refreshedAt +} + +func (s *Session) UpdateRefresh(tx *storage.Connection) error { + return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") +} + func NewSession() (*Session, error) { id := uuid.Must(uuid.NewV4()) diff --git a/internal/utilities/request.go b/internal/utilities/request.go index 0313798e53..3c419fa6e3 100644 --- a/internal/utilities/request.go +++ b/internal/utilities/request.go @@ -24,7 +24,12 @@ func GetIPAddress(r *http.Request) string { for _, ip := range ips { if ip != "" { - return ip + parsed := net.ParseIP(ip) + if parsed == nil { + continue + } + + return parsed.String() } } } diff --git a/migrations/20231027141322_add_session_refresh_columns.up.sql b/migrations/20231027141322_add_session_refresh_columns.up.sql new file mode 100644 index 0000000000..79efba9bcf --- /dev/null +++ b/migrations/20231027141322_add_session_refresh_columns.up.sql @@ -0,0 +1,4 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists refreshed_at timestamp without time zone, + add column if not exists user_agent text, + add column if not exists ip inet;