Skip to content

Commit

Permalink
feat: add cleanup for session timebox and inactivity timeout (supabas…
Browse files Browse the repository at this point in the history
…e#1298)

Builds on top of supabase#1288.

Adds cleanup for timeboxed sessions and sessions that have expired due
to inactivity timeout.

It achieves backward compatibility with sessions that have `null` in
`refreshed_at` by looking at the `updated_at` column of the refresh
tokens table. This approach is the one that puts the least strain on the
database, having considered backfilling (very expensive at least
`O(nlogn)` over the whole refresh tokens table).
  • Loading branch information
hf authored and LashaJini committed Nov 13, 2024
1 parent 8862647 commit a674018
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 54 deletions.
10 changes: 9 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/mailer"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/observability"
"github.com/supabase/gotrue/internal/storage"
)
Expand Down Expand Up @@ -93,7 +94,14 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.Use(recoverer)

if globalConfig.DB.CleanupEnabled {
r.UseBypass(api.databaseCleanup)
cleanup := &models.Cleanup{
SessionTimebox: globalConfig.Sessions.Timebox,
SessionInactivityTimeout: globalConfig.Sessions.InactivityTimeout,
}

cleanup.Setup()

r.UseBypass(api.databaseCleanup(cleanup))
}

r.Get("/health", api.HealthCheck)
Expand Down
38 changes: 20 additions & 18 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,26 +233,28 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont
return ctx, nil
}

func (a *API) databaseCleanup(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)

switch r.Method {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
// continue
switch r.Method {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
// continue

default:
return
}
default:
return
}

db := a.db.WithContext(r.Context())
log := observability.GetLogEntry(r)
db := a.db.WithContext(r.Context())
log := observability.GetLogEntry(r)

affectedRows, err := models.Cleanup(db)
if err != nil {
log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
} else if affectedRows > 0 {
log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
}
})
affectedRows, err := cleanup.Clean(db)
if err != nil {
log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
} else if affectedRows > 0 {
log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
}
})
}
}
58 changes: 40 additions & 18 deletions internal/models/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package models
import (
"fmt"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
Expand All @@ -14,19 +15,22 @@ import (
"github.com/supabase/gotrue/internal/storage"
)

// cleanupNext holds an atomically incrementing value that determines which of
// the cleanupStatements will be run next.
var cleanupNext uint32
type Cleanup struct {
SessionTimebox *time.Duration
SessionInactivityTimeout *time.Duration

// cleanupStatements holds all of the possible cleanup raw SQL. Only one at a
// time is executed using cleanupNext % len(cleanupStatements).
var CleanupStatements []string
cleanupStatements []string

// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
// cleaned up rows.
var cleanupAffectedRows otelasyncint64instrument.Counter
// cleanupNext holds an atomically incrementing value that determines which of
// the cleanupStatements will be run next.
cleanupNext uint32

func init() {
// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
// cleaned up rows.
cleanupAffectedRows otelasyncint64instrument.Counter
}

func (c *Cleanup) Setup() {
tableRefreshTokens := RefreshToken{}.TableName()
tableSessions := Session{}.TableName()
tableRelayStates := SAMLRelayState{}.TableName()
Expand All @@ -37,7 +41,7 @@ func init() {
// as this makes sure that only rows that are not being used in another
// transaction are deleted. These deletes are thus very quick and
// efficient, as they don't wait on other transactions.
CleanupStatements = append(CleanupStatements,
c.cleanupStatements = append(c.cleanupStatements,
fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens),
fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens),
// sessions are deleted after 72 hours to allow refresh tokens
Expand All @@ -49,14 +53,32 @@ func init() {
fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableMFAChallenges, tableMFAChallenges),
)

var err error
cleanupAffectedRows, err = metricglobal.Meter("gotrue").AsyncInt64().Counter(
if c.SessionTimebox != nil {
timeboxSeconds := int((*c.SessionTimebox).Seconds())

c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where created_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, timeboxSeconds))
}

if c.SessionInactivityTimeout != nil {
inactivitySeconds := int((*c.SessionTimebox).Seconds())

// delete sessions with a refreshed_at column
c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where refreshed_at is not null and refreshed_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, inactivitySeconds))

// delete sessions without a refreshed_at column by looking for
// unrevoked refresh_tokens
c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select %q.id as id from %q, %q where %q.session_id = %q.id and %q.refreshed_at is null and %q.revoked is false and %q.updated_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked)", tableSessions, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, inactivitySeconds))
}

cleanupAffectedRows, err := metricglobal.Meter("gotrue").AsyncInt64().Counter(
"gotrue_cleanup_affected_rows",
metricinstrument.WithDescription("Number of affected rows from cleaning up stale entities"),
)
if err != nil {
logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_rows counter metric")
}

c.cleanupAffectedRows = cleanupAffectedRows
}

// Cleanup removes stale entities in the database. You can call it on each
Expand All @@ -65,16 +87,16 @@ func init() {
// not affect performance of other database jobs. Note that calling this does
// not clean up the whole database, but does a small piecemeal clean up each
// time when called.
func Cleanup(db *storage.Connection) (int, error) {
func (c *Cleanup) Clean(db *storage.Connection) (int, error) {
ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup")
defer span.End()

affectedRows := 0
defer span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows)))

if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error {
nextIndex := atomic.AddUint32(&cleanupNext, 1) % uint32(len(CleanupStatements))
statement := CleanupStatements[nextIndex]
nextIndex := atomic.AddUint32(&c.cleanupNext, 1) % uint32(len(c.cleanupStatements))
statement := c.cleanupStatements[nextIndex]

count, terr := tx.RawQuery(statement).ExecWithCount()
if terr != nil {
Expand All @@ -88,8 +110,8 @@ func Cleanup(db *storage.Connection) (int, error) {
return affectedRows, err
}

if cleanupAffectedRows != nil {
cleanupAffectedRows.Observe(ctx, int64(affectedRows))
if c.cleanupAffectedRows != nil {
c.cleanupAffectedRows.Observe(ctx, int64(affectedRows))
}

return affectedRows, nil
Expand Down
29 changes: 12 additions & 17 deletions internal/models/cleanup_test.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,33 @@
package models

import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/storage/test"
)

func TestCleanupSQL(t *testing.T) {
func TestCleanup(t *testing.T) {
globalConfig, err := conf.LoadGlobal(modelsTestConfig)
require.NoError(t, err)
conn, err := test.SetupDBConnection(globalConfig)
require.NoError(t, err)

for _, statement := range CleanupStatements {
_, err := conn.RawQuery(statement).ExecWithCount()
require.NoError(t, err, statement)
sessionTimebox := 10 * time.Second
sessionInactivityTimeout := 5 * time.Second

cleanup := &Cleanup{
SessionTimebox: &sessionTimebox,
SessionInactivityTimeout: &sessionInactivityTimeout,
}
}

func TestCleanup(t *testing.T) {
globalConfig, err := conf.LoadGlobal(modelsTestConfig)
require.NoError(t, err)
conn, err := test.SetupDBConnection(globalConfig)
require.NoError(t, err)
cleanup.Setup()

for _, statement := range CleanupStatements {
_, err := Cleanup(conn)
if err != nil {
fmt.Printf("%v %t\n", err, err)
}
require.NoError(t, err, statement)
for i := 0; i < 100; i += 1 {
_, err := cleanup.Clean(conn)
require.NoError(t, err)
}
}

0 comments on commit a674018

Please sign in to comment.