diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index de1d0bb834..2e91e2c90e 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -64,31 +64,18 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h } if session != nil { - var notAfter time.Time + result := session.CheckValidity(retryStart, &token.UpdatedAt, config.Sessions.Timebox, config.Sessions.InactivityTimeout) - if session.NotAfter != nil { - notAfter = *session.NotAfter - } - - if config.Sessions.Timebox != nil { - sessionEndsAt := session.CreatedAt.Add((*config.Sessions.Timebox).Abs()) + switch result { + case models.SessionValid: + // do nothing - if notAfter.IsZero() || notAfter.After(sessionEndsAt) { - notAfter = sessionEndsAt - } - } + case models.SessionTimedOut: + return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)") - if !notAfter.IsZero() && a.Now().After(notAfter) { + default: return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired") } - - if config.Sessions.InactivityTimeout != nil { - timesOutAt := session.LastRefreshedAt(&token.UpdatedAt).Add(*config.Sessions.InactivityTimeout) - - if timesOutAt.Before(a.Now()) { - return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)") - } - } } // Basic checks above passed, now we need to serialize access @@ -120,6 +107,60 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return internalServerError(terr.Error()) } + if a.config.Sessions.SinglePerUser { + sessions, terr := models.FindAllSessionsForUser(tx, user.ID, true /* forUpdate */) + if models.IsNotFoundError(terr) { + // because forUpdate was set, and the + // previous check outside the + // transaction found a user and + // session, but now we're getting a + // IsNotFoundError, this means that the + // user is locked and we need to retry + // in a few milliseconds + retry = true + return terr + } else if terr != nil { + return internalServerError(terr.Error()) + } + + sessionTag := session.DetermineTag(config.Sessions.Tags) + + // go through all sessions of the user and + // check if the current session is the user's + // most recently refreshed valid session + for _, s := range sessions { + if s.ID == session.ID { + // current session, skip it + continue + } + + if s.CheckValidity(retryStart, nil, config.Sessions.Timebox, config.Sessions.InactivityTimeout) != models.SessionValid { + // session is not valid so it + // can't be regarded as active + // on the user + continue + } + + if s.DetermineTag(config.Sessions.Tags) != sessionTag { + // if tags are specified, + // ignore sessions with a + // mismatching tag + continue + } + + // since token is not the refresh token + // of s, we can't use it's UpdatedAt + // time to compare! + if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) { + // session is not the most + // recently active one + return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") + } + } + + // this session is the user's active session + } + // refresh token row and session are locked at this // point, cannot be concurrently refreshed diff --git a/internal/api/token_test.go b/internal/api/token_test.go index b07d8ad27a..dbb9b013d0 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -70,6 +70,7 @@ func (ts *TokenTestSuite) TestSessionTimebox() { defer func() { ts.API.overrideTime = nil + ts.API.config.Sessions.Timebox = nil }() var buffer bytes.Buffer @@ -176,6 +177,48 @@ func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() { assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken) } +func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() { + ts.API.config.Sessions.SinglePerUser = true + defer func() { + ts.API.config.Sessions.SinglePerUser = false + }() + + firstRefreshToken := ts.RefreshToken + + // just in case to give some delay between first and second session creation + time.Sleep(10 * time.Millisecond) + + secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{}) + + require.NoError(ts.T(), err) + + require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId) + require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": firstRefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser) + + var firstResult struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), "invalid_grant", firstResult.Error) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.ErrorDescription) +} + func (ts *TokenTestSuite) TestRateLimitTokenRefresh() { var buffer bytes.Buffer req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 47624e2637..2686ba253c 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -114,6 +114,9 @@ func (a *APIConfiguration) Validate() error { type SessionsConfiguration struct { Timebox *time.Duration `json:"timebox"` InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"` + + SinglePerUser bool `json:"single_per_user" split_words:"true"` + Tags []string `json:"tags,omitempty"` } func (c *SessionsConfiguration) Validate() error { diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go index 1707a9bb2f..4235c4113b 100644 --- a/internal/models/refresh_token.go +++ b/internal/models/refresh_token.go @@ -42,6 +42,7 @@ type GrantParams struct { FactorID *uuid.UUID SessionNotAfter *time.Time + SessionTag *string UserAgent string IP string @@ -145,6 +146,10 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok session.IP = ¶ms.IP } + if params.SessionTag != nil && *params.SessionTag != "" { + session.Tag = params.SessionTag + } + if err := tx.Create(session); err != nil { return nil, errors.Wrap(err, "error creating new session") } diff --git a/internal/models/sessions.go b/internal/models/sessions.go index 94bb9cf810..0ec8a1c22c 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -73,6 +73,8 @@ type Session struct { RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"` UserAgent *string `json:"user_agent,omitempty" db:"user_agent"` IP *string `json:"ip,omitempty" db:"ip"` + + Tag *string `json:"tag" db:"tag"` } func (Session) TableName() string { @@ -104,6 +106,54 @@ func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error { return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") } +type SessionValidityReason = int + +const ( + SessionValid SessionValidityReason = iota + SessionPastNotAfter = iota + SessionPastTimebox = iota + SessionTimedOut = iota +) + +func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason { + if s.NotAfter != nil && now.After(*s.NotAfter) { + return SessionPastNotAfter + } + + if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) { + return SessionPastTimebox + } + + if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) { + return SessionTimedOut + } + + return SessionValid +} + +func (s *Session) DetermineTag(tags []string) string { + if len(tags) == 0 { + return "" + } + + if s.Tag == nil { + return tags[0] + } + + tag := *s.Tag + if tag == "" { + return tags[0] + } + + for _, t := range tags { + if t == tag { + return tag + } + } + + return tags[0] +} + func NewSession() (*Session, error) { id := uuid.Must(uuid.NewV4()) @@ -168,6 +218,35 @@ func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Sess return sessions, nil } +// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is +// set, it will first lock on the user row which can be used to prevent issues +// with concurrency. If the lock is acquired, it will return a +// UserNotFoundError and the operation should be retried. If there are no +// sessions for the user, a nil result is returned without an error. +func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) { + if forUpdate { + user := &User{} + if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, UserNotFoundError{} + } + + return nil, err + } + } + + var sessions []*Session + if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, err + } + + return sessions, nil +} + func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error { return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec() } diff --git a/migrations/20231114161723_add_sessions_tag.up.sql b/migrations/20231114161723_add_sessions_tag.up.sql new file mode 100644 index 0000000000..7acf1bb9dd --- /dev/null +++ b/migrations/20231114161723_add_sessions_tag.up.sql @@ -0,0 +1,2 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists tag text;