diff --git a/internal/api/api.go b/internal/api/api.go
index e3d7f133c3..767daa28a6 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -14,6 +14,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/supabase/gotrue/internal/conf"
 	"github.com/supabase/gotrue/internal/mailer"
+	"github.com/supabase/gotrue/internal/models"
 	"github.com/supabase/gotrue/internal/observability"
 	"github.com/supabase/gotrue/internal/storage"
 )
@@ -93,7 +94,14 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
 	r.Use(recoverer)
 
 	if globalConfig.DB.CleanupEnabled {
-		r.UseBypass(api.databaseCleanup)
+		cleanup := &models.Cleanup{
+			SessionTimebox:           globalConfig.Sessions.Timebox,
+			SessionInactivityTimeout: globalConfig.Sessions.InactivityTimeout,
+		}
+
+		cleanup.Setup()
+
+		r.UseBypass(api.databaseCleanup(cleanup))
 	}
 
 	r.Get("/health", api.HealthCheck)
diff --git a/internal/api/middleware.go b/internal/api/middleware.go
index 2230939522..0ccb51a5a7 100644
--- a/internal/api/middleware.go
+++ b/internal/api/middleware.go
@@ -233,26 +233,28 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont
 	return ctx, nil
 }
 
-func (a *API) databaseCleanup(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		next.ServeHTTP(w, r)
+func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			next.ServeHTTP(w, r)
 
-		switch r.Method {
-		case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
-			// continue
+			switch r.Method {
+			case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
+				// continue
 
-		default:
-			return
-		}
+			default:
+				return
+			}
 
-		db := a.db.WithContext(r.Context())
-		log := observability.GetLogEntry(r)
+			db := a.db.WithContext(r.Context())
+			log := observability.GetLogEntry(r)
 
-		affectedRows, err := models.Cleanup(db)
-		if err != nil {
-			log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
-		} else if affectedRows > 0 {
-			log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
-		}
-	})
+			affectedRows, err := cleanup.Clean(db)
+			if err != nil {
+				log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
+			} else if affectedRows > 0 {
+				log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
+			}
+		})
+	}
 }
diff --git a/internal/models/cleanup.go b/internal/models/cleanup.go
index eddaf090f2..b1e686afcb 100644
--- a/internal/models/cleanup.go
+++ b/internal/models/cleanup.go
@@ -3,6 +3,7 @@ package models
 import (
 	"fmt"
 	"sync/atomic"
+	"time"
 
 	"github.com/sirupsen/logrus"
 	"go.opentelemetry.io/otel/attribute"
@@ -14,19 +15,31 @@ import (
 	"github.com/supabase/gotrue/internal/storage"
 )
 
-// cleanupNext holds an atomically incrementing value that determines which of
-// the cleanupStatements will be run next.
-var cleanupNext uint32
+// Cleanup holds the configuration and state used to remove stale entities
+// from the database in small batches.
+type Cleanup struct {
+	// SessionTimebox and SessionInactivityTimeout are optional. When non-nil,
+	// Setup adds statements that delete sessions past these limits.
+	SessionTimebox           *time.Duration
+	SessionInactivityTimeout *time.Duration
 
-// cleanupStatements holds all of the possible cleanup raw SQL. Only one at a
-// time is executed using cleanupNext % len(cleanupStatements).
-var CleanupStatements []string
+	// cleanupStatements holds all of the possible cleanup raw SQL. Only one at
+	// a time is executed using cleanupNext % len(cleanupStatements).
+	cleanupStatements []string
 
-// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
-// cleaned up rows.
-var cleanupAffectedRows otelasyncint64instrument.Counter
+	// cleanupNext holds an atomically incrementing value that determines which of
+	// the cleanupStatements will be run next.
+	cleanupNext uint32
 
-func init() {
+	// cleanupAffectedRows tracks an OpenTelemetry metric on the total number of
+	// cleaned up rows.
+	cleanupAffectedRows otelasyncint64instrument.Counter
+}
+
+// Setup builds the list of cleanup statements, adding the session statements
+// enabled by SessionTimebox and SessionInactivityTimeout, and obtains the
+// affected rows counter. Call it once before the first call to Clean.
+func (c *Cleanup) Setup() {
 	tableRefreshTokens := RefreshToken{}.TableName()
 	tableSessions := Session{}.TableName()
 	tableRelayStates := SAMLRelayState{}.TableName()
@@ -37,7 +50,7 @@ func init() {
 	// as this makes sure that only rows that are not being used in another
 	// transaction are deleted. These deletes are thus very quick and
 	// efficient, as they don't wait on other transactions.
-	CleanupStatements = append(CleanupStatements,
+	c.cleanupStatements = append(c.cleanupStatements,
 		fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens),
 		fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens),
 		// sessions are deleted after 72 hours to allow refresh tokens
@@ -49,14 +62,32 @@ func init() {
 		fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableMFAChallenges, tableMFAChallenges),
 	)
 
-	var err error
-	cleanupAffectedRows, err = metricglobal.Meter("gotrue").AsyncInt64().Counter(
+	if c.SessionTimebox != nil {
+		timeboxSeconds := int((*c.SessionTimebox).Seconds())
+
+		c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where created_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, timeboxSeconds))
+	}
+
+	if c.SessionInactivityTimeout != nil {
+		inactivitySeconds := int((*c.SessionInactivityTimeout).Seconds())
+
+		// delete sessions with a refreshed_at column
+		c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where refreshed_at is not null and refreshed_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, inactivitySeconds))
+
+		// delete sessions without a refreshed_at column by looking for
+		// unrevoked refresh_tokens
+		c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select %q.id as id from %q, %q where %q.session_id = %q.id and %q.refreshed_at is null and %q.revoked is false and %q.updated_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked)", tableSessions, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, inactivitySeconds))
+	}
+
+	cleanupAffectedRows, err := metricglobal.Meter("gotrue").AsyncInt64().Counter(
 		"gotrue_cleanup_affected_rows",
 		metricinstrument.WithDescription("Number of affected rows from cleaning up stale entities"),
 	)
 	if err != nil {
 		logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_rows counter metric")
 	}
+
+	c.cleanupAffectedRows = cleanupAffectedRows
 }
 
-// Cleanup removes stale entities in the database. You can call it on each
+// Clean removes stale entities in the database. You can call it on each
@@ -65,7 +96,7 @@ func init() {
 // not affect performance of other database jobs. Note that calling this does
 // not clean up the whole database, but does a small piecemeal clean up each
 // time when called.
-func Cleanup(db *storage.Connection) (int, error) {
+func (c *Cleanup) Clean(db *storage.Connection) (int, error) {
 	ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup")
 	defer span.End()
 
@@ -73,8 +104,14 @@ func Cleanup(db *storage.Connection) (int, error) {
 	defer span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows)))
 
+	// Setup has not populated any statements, so there is nothing to run;
+	// returning here also avoids a modulo-by-zero panic below.
+	if len(c.cleanupStatements) == 0 {
+		return affectedRows, nil
+	}
+
 	if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error {
-		nextIndex := atomic.AddUint32(&cleanupNext, 1) % uint32(len(CleanupStatements))
-		statement := CleanupStatements[nextIndex]
+		nextIndex := atomic.AddUint32(&c.cleanupNext, 1) % uint32(len(c.cleanupStatements))
+		statement := c.cleanupStatements[nextIndex]
 
 		count, terr := tx.RawQuery(statement).ExecWithCount()
 		if terr != nil {
@@ -88,8 +125,8 @@ func Cleanup(db *storage.Connection) (int, error) {
 		return affectedRows, err
 	}
 
-	if cleanupAffectedRows != nil {
-		cleanupAffectedRows.Observe(ctx, int64(affectedRows))
+	if c.cleanupAffectedRows != nil {
+		c.cleanupAffectedRows.Observe(ctx, int64(affectedRows))
 	}
 
 	return affectedRows, nil
diff --git a/internal/models/cleanup_test.go b/internal/models/cleanup_test.go
index c38cf72f36..79052c5167 100644
--- a/internal/models/cleanup_test.go
+++ b/internal/models/cleanup_test.go
@@ -1,8 +1,8 @@
 package models
 
 import (
-	"fmt"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -10,29 +10,24 @@ import (
 	"github.com/supabase/gotrue/internal/storage/test"
 )
 
-func TestCleanupSQL(t *testing.T) {
+func TestCleanup(t *testing.T) {
 	globalConfig, err := conf.LoadGlobal(modelsTestConfig)
 	require.NoError(t, err)
 	conn, err := test.SetupDBConnection(globalConfig)
 	require.NoError(t, err)
 
-	for _, statement := range CleanupStatements {
-		_, err := conn.RawQuery(statement).ExecWithCount()
-		require.NoError(t, err, statement)
+	sessionTimebox := 10 * time.Second
+	sessionInactivityTimeout := 5 * time.Second
+
+	cleanup := &Cleanup{
+		SessionTimebox:           &sessionTimebox,
+		SessionInactivityTimeout: &sessionInactivityTimeout,
 	}
-}
 
-func TestCleanup(t *testing.T) {
-	globalConfig, err := conf.LoadGlobal(modelsTestConfig)
-	require.NoError(t, err)
-	conn, err := test.SetupDBConnection(globalConfig)
-	require.NoError(t, err)
+	cleanup.Setup()
 
-	for _, statement := range CleanupStatements {
-		_, err := Cleanup(conn)
-		if err != nil {
-			fmt.Printf("%v %t\n", err, err)
-		}
-		require.NoError(t, err, statement)
+	for i := 0; i < 100; i += 1 {
+		_, err := cleanup.Clean(conn)
+		require.NoError(t, err)
 	}
 }