diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 5d5a2bee9d..99faf7af51 100644 --- a/internal/datastore/common/errors.go +++ b/internal/datastore/common/errors.go @@ -118,3 +118,13 @@ func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err er log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered) return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr) } + +// RevisionUnavailableError is returned when a revision is not available on a replica. +type RevisionUnavailableError struct { + error +} + +// NewRevisionUnavailableError creates a new RevisionUnavailableError. +func NewRevisionUnavailableError(err error) error { + return RevisionUnavailableError{err} +} diff --git a/internal/datastore/context.go b/internal/datastore/context.go index 96ed9f684d..4ef59874a0 100644 --- a/internal/datastore/context.go +++ b/internal/datastore/context.go @@ -33,7 +33,7 @@ func SeparateContextWithTracing(ctx context.Context) context.Context { // // This is useful for datastores that do not want to close connections when a // cancel or deadline occurs. -func NewSeparatingContextDatastoreProxy(d datastore.Datastore) datastore.Datastore { +func NewSeparatingContextDatastoreProxy(d datastore.Datastore) datastore.StrictReadDatastore { return &ctxProxy{d} } @@ -47,6 +47,20 @@ func (p *ctxProxy) ReadWriteTx( return p.delegate.ReadWriteTx(ctx, f, opts...) } +func (p *ctxProxy) IsStrictReadModeEnabled() bool { + ds := p.delegate + unwrapped, ok := p.delegate.(datastore.UnwrappableDatastore) + if ok { + ds = unwrapped.Unwrap() + } + + if srm, ok := ds.(datastore.StrictReadDatastore); ok { + return srm.IsStrictReadModeEnabled() + } + + return false +} + func (p *ctxProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { return p.delegate.OptimizedRevision(SeparateContextWithTracing(ctx)) } diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 8428233f1b..6bc70bd64a 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -62,6 +62,8 @@ const ( noLastInsertID = 0 seedingTimeout = 10 * time.Second + primaryInstanceID = -1 + // https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html#error_er_lock_wait_timeout errMysqlLockWaitTimeout = 1205 @@ -102,7 +104,21 @@ type sqlFilter interface { // URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2... // See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html func NewMySQLDatastore(ctx context.Context, uri string, options ...Option) (datastore.Datastore, error) { - ds, err := newMySQLDatastore(ctx, uri, options...) + ds, err := newMySQLDatastore(ctx, uri, primaryInstanceID, options...) + if err != nil { + return nil, err + } + + return datastoreinternal.NewSeparatingContextDatastoreProxy(ds), nil +} + +func NewReadOnlyMySQLDatastore( + ctx context.Context, + url string, + index uint32, + options ...Option, +) (datastore.ReadOnlyDatastore, error) { + ds, err := newMySQLDatastore(ctx, url, int(index), options...) 
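The new RevisionUnavailableError in errors.go embeds the error it wraps, so callers anywhere up the stack can match it with errors.As and fall back to the primary. A minimal caller-side sketch of that pattern, assuming repo-internal imports and hypothetical readFromReplica/readFromPrimary helpers:

```go
package example

import (
	"context"
	"errors"

	"github.com/authzed/spicedb/internal/datastore/common"
)

// readWithFallback is a hypothetical helper, not part of this change: it shows how the
// error type is matched. Because RevisionUnavailableError embeds the underlying error,
// errors.As finds it anywhere in the chain.
func readWithFallback(ctx context.Context, readFromReplica, readFromPrimary func(context.Context) error) error {
	err := readFromReplica(ctx)

	var unavailable common.RevisionUnavailableError
	if errors.As(err, &unavailable) {
		// The replica has not yet replayed the requested revision; retry on the primary.
		return readFromPrimary(ctx)
	}

	return err
}
```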
if err != nil { return nil, err } @@ -110,7 +126,8 @@ func NewMySQLDatastore(ctx context.Context, uri string, options ...Option) (data return datastoreinternal.NewSeparatingContextDatastoreProxy(ds), nil } -func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Datastore, error) { +func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, options ...Option) (*Datastore, error) { + isPrimary := replicaIndex == primaryInstanceID config, err := generateConfig(options) if err != nil { return nil, fmt.Errorf(errUnableToInstantiate, err) @@ -162,14 +179,21 @@ func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Dat return nil, common.RedactAndLogSensitiveConnString(ctx, "NewMySQLDatastore: unable to instrument connector", err, uri) } + dbName := "spicedb" + if replicaIndex != primaryInstanceID { + dbName = fmt.Sprintf("spicedb_replica_%d", replicaIndex) + } + db = sql.OpenDB(connector) - collector := sqlstats.NewStatsCollector("spicedb", db) + collector := sqlstats.NewStatsCollector(dbName, db) if err := prometheus.Register(collector); err != nil { return nil, fmt.Errorf(errUnableToInstantiate, err) } - if err := common.RegisterGCMetrics(); err != nil { - return nil, fmt.Errorf(errUnableToInstantiate, err) + if isPrimary { + if err := common.RegisterGCMetrics(); err != nil { + return nil, fmt.Errorf(errUnableToInstantiate, err) + } } } else { db = sql.OpenDB(connector) @@ -256,19 +280,21 @@ func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Dat } // Start a goroutine for garbage collection. - if store.gcInterval > 0*time.Minute && config.gcEnabled { - store.gcGroup, store.gcCtx = errgroup.WithContext(store.gcCtx) - store.gcGroup.Go(func() error { - return common.StartGarbageCollector( - store.gcCtx, - store, - store.gcInterval, - store.gcWindow, - store.gcTimeout, - ) - }) - } else { - log.Warn().Msg("datastore background garbage collection disabled") + if isPrimary { + if store.gcInterval > 0*time.Minute && config.gcEnabled { + store.gcGroup, store.gcCtx = errgroup.WithContext(store.gcCtx) + store.gcGroup.Go(func() error { + return common.StartGarbageCollector( + store.gcCtx, + store, + store.gcInterval, + store.gcWindow, + store.gcTimeout, + ) + }) + } else { + log.Warn().Msg("datastore background garbage collection disabled") + } } return store, nil diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 2d421aeee6..23390ba1eb 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -48,7 +48,7 @@ type datastoreTester struct { func (dst *datastoreTester) createDatastore(revisionQuantization, gcInterval, gcWindow time.Duration, _ uint16) (datastore.Datastore, error) { ctx := context.Background() ds := dst.b.NewDatastore(dst.t, func(engine, uri string) datastore.Datastore { - ds, err := newMySQLDatastore(ctx, uri, + ds, err := newMySQLDatastore(ctx, uri, primaryInstanceID, RevisionQuantization(revisionQuantization), GCWindow(gcWindow), GCInterval(gcInterval), @@ -82,7 +82,7 @@ func createDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestF return func(t *testing.T) { ctx := context.Background() ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { - ds, err := newMySQLDatastore(ctx, uri, options...) + ds, err := newMySQLDatastore(ctx, uri, primaryInstanceID, options...) 
require.NoError(t, err) return ds }) @@ -568,6 +568,7 @@ func QuantizedRevisionTest(t *testing.T, b testdatastore.RunningEngineForTest) { ds, err := newMySQLDatastore( ctx, uri, + primaryInstanceID, RevisionQuantization(5*time.Second), GCWindow(24*time.Hour), WatchBufferLength(1), diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index 0eba2ab3e4..9fe42d6085 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -25,6 +25,7 @@ type postgresOptions struct { enablePrometheusStats bool analyzeBeforeStatistics bool gcEnabled bool + readStrictMode bool migrationPhase string @@ -61,6 +62,7 @@ const ( defaultMaxRetries = 10 defaultGCEnabled = true defaultCredentialsProviderName = "" + defaultReadStrictMode = false ) // Option provides the facility to configure how clients within the @@ -80,6 +82,7 @@ func generateConfig(options []Option) (postgresOptions, error) { maxRetries: defaultMaxRetries, gcEnabled: defaultGCEnabled, credentialsProviderName: defaultCredentialsProviderName, + readStrictMode: defaultReadStrictMode, queryInterceptor: nil, } @@ -103,6 +106,15 @@ func generateConfig(options []Option) (postgresOptions, error) { return computed, nil } +// ReadStrictMode sets whether strict mode is used for reads in the Postgres reader. If enabled, +// an assertion is added into the WHERE clause of all read queries to ensure that the revision +// being read is available on the read connection. +// +// Strict mode is disabled by default, as the default behavior is to read from the primary. +func ReadStrictMode(readStrictMode bool) Option { + return func(po *postgresOptions) { po.readStrictMode = readStrictMode } +} + // ReadConnHealthCheckInterval is the frequency at which both idle and max // lifetime connections are checked, and also the frequency at which the // minimum number of connections is checked. diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index e528fccc87..b1980c5822 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "strconv" "sync/atomic" "time" @@ -30,6 +31,7 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/spiceerrors" ) func init() { @@ -80,6 +82,8 @@ const ( tracingDriverName = "postgres-tracing" gcBatchDeleteSize = 1000 + + primaryInstanceID = -1 ) var livingTupleConstraints = []string{"uq_relation_tuple_living_xid", "pk_relation_tuple"} @@ -122,7 +126,24 @@ func NewPostgresDatastore( url string, options ...Option, ) (datastore.Datastore, error) { - ds, err := newPostgresDatastore(ctx, url, options...) + ds, err := newPostgresDatastore(ctx, url, primaryInstanceID, options...) + if err != nil { + return nil, err + } + + return datastoreinternal.NewSeparatingContextDatastoreProxy(ds), nil +} + +// NewReadOnlyPostgresDatastore initializes a SpiceDB datastore that uses a PostgreSQL +// database by leveraging manual book-keeping to implement revisioning. This version is +// read only and does not allow for write transactions. +func NewReadOnlyPostgresDatastore( + ctx context.Context, + url string, + index uint32, + options ...Option, +) (datastore.StrictReadDatastore, error) { + ds, err := newPostgresDatastore(ctx, url, int(index), options...) 
if err != nil { return nil, err } @@ -133,8 +154,10 @@ func NewPostgresDatastore( func newPostgresDatastore( ctx context.Context, pgURL string, + replicaIndex int, options ...Option, ) (datastore.Datastore, error) { + isPrimary := replicaIndex == primaryInstanceID config, err := generateConfig(options) if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) @@ -170,12 +193,15 @@ func newPostgresDatastore( return nil } - writePoolConfig := pgConfig.Copy() - config.writePoolOpts.ConfigurePgx(writePoolConfig) + var writePoolConfig *pgxpool.Config + if isPrimary { + writePoolConfig = pgConfig.Copy() + config.writePoolOpts.ConfigurePgx(writePoolConfig) - writePoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { - RegisterTypes(conn.TypeMap()) - return nil + writePoolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + RegisterTypes(conn.TypeMap()) + return nil + } } if credentialsProvider != nil { @@ -185,7 +211,10 @@ func newPostgresDatastore( return err } readPoolConfig.BeforeConnect = getToken - writePoolConfig.BeforeConnect = getToken + + if isPrimary { + writePoolConfig.BeforeConnect = getToken + } } if config.migrationPhase != "" { @@ -202,9 +231,14 @@ func newPostgresDatastore( return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) } - writePool, err := pgxpool.NewWithConfig(initializationContext, writePoolConfig) - if err != nil { - return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) + var writePool *pgxpool.Pool + + if isPrimary { + wp, err := pgxpool.NewWithConfig(initializationContext, writePoolConfig) + if err != nil { + return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) + } + writePool = wp } // Verify that the server supports commit timestamps @@ -221,20 +255,29 @@ func newPostgresDatastore( } if config.enablePrometheusStats { + replicaIndexStr := strconv.Itoa(replicaIndex) + dbname := "spicedb" + if replicaIndex != primaryInstanceID { + dbname = fmt.Sprintf("spicedb_replica_%s", replicaIndexStr) + } + if err := prometheus.Register(pgxpoolprometheus.NewCollector(readPool, map[string]string{ - "db_name": "spicedb", + "db_name": dbname, "pool_usage": "read", })); err != nil { return nil, err } - if err := prometheus.Register(pgxpoolprometheus.NewCollector(writePool, map[string]string{ - "db_name": "spicedb", - "pool_usage": "write", - })); err != nil { - return nil, err - } - if err := common.RegisterGCMetrics(); err != nil { - return nil, err + + if isPrimary { + if err := prometheus.Register(pgxpoolprometheus.NewCollector(writePool, map[string]string{ + "db_name": "spicedb", + "pool_usage": "write", + })); err != nil { + return nil, err + } + if err := common.RegisterGCMetrics(); err != nil { + return nil, err + } } } @@ -271,7 +314,7 @@ func newPostgresDatastore( ), dburl: pgURL, readPool: pgxcommon.MustNewInterceptorPooler(readPool, config.queryInterceptor), - writePool: pgxcommon.MustNewInterceptorPooler(writePool, config.queryInterceptor), + writePool: nil, /* disabled by default */ watchBufferLength: config.watchBufferLength, watchBufferWriteTimeout: config.watchBufferWriteTimeout, optimizedRevisionQuery: revisionQuery, @@ -286,24 +329,36 @@ func newPostgresDatastore( readTxOptions: pgx.TxOptions{IsoLevel: pgx.RepeatableRead, AccessMode: pgx.ReadOnly}, maxRetries: config.maxRetries, credentialsProvider: credentialsProvider, + isPrimary: isPrimary, + 
inStrictReadMode: config.readStrictMode, + } + + if isPrimary && config.readStrictMode { + return nil, spiceerrors.MustBugf("strict read mode is not supported on primary instances") + } + + if isPrimary { + datastore.writePool = pgxcommon.MustNewInterceptorPooler(writePool, config.queryInterceptor) } datastore.SetOptimizedRevisionFunc(datastore.optimizedRevisionFunc) // Start a goroutine for garbage collection. - if datastore.gcInterval > 0*time.Minute && config.gcEnabled { - datastore.gcGroup, datastore.gcCtx = errgroup.WithContext(datastore.gcCtx) - datastore.gcGroup.Go(func() error { - return common.StartGarbageCollector( - datastore.gcCtx, - datastore, - datastore.gcInterval, - datastore.gcWindow, - datastore.gcTimeout, - ) - }) - } else { - log.Warn().Msg("datastore background garbage collection disabled") + if isPrimary { + if datastore.gcInterval > 0*time.Minute && config.gcEnabled { + datastore.gcGroup, datastore.gcCtx = errgroup.WithContext(datastore.gcCtx) + datastore.gcGroup.Go(func() error { + return common.StartGarbageCollector( + datastore.gcCtx, + datastore, + datastore.gcInterval, + datastore.gcWindow, + datastore.gcTimeout, + ) + }) + } else { + log.Warn().Msg("datastore background garbage collection disabled") + } } return datastore, nil @@ -325,6 +380,8 @@ type pgDatastore struct { readTxOptions pgx.TxOptions maxRetries uint8 watchEnabled bool + isPrimary bool + inStrictReadMode bool credentialsProvider datastore.CredentialsProvider @@ -334,10 +391,18 @@ type pgDatastore struct { gcHasRun atomic.Bool } +func (pgd *pgDatastore) IsStrictReadModeEnabled() bool { + return pgd.inStrictReadMode +} + func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Reader { rev := revRaw.(postgresRevision) queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) + if pgd.inStrictReadMode { + queryFuncs = strictReaderQueryFuncs{wrapped: queryFuncs, revision: rev} + } + executor := common.QueryExecutor{ Executor: pgxcommon.NewPGXExecutor(queryFuncs), } @@ -356,6 +421,10 @@ func (pgd *pgDatastore) ReadWriteTx( fn datastore.TxUserFunc, opts ...options.RWTOptionsOption, ) (datastore.Revision, error) { + if !pgd.isPrimary { + return datastore.NoRevision, spiceerrors.MustBugf("read-write transaction not supported on read-only datastore") + } + config := options.NewRWTOptionsWithOptions(opts...) 
var err error @@ -535,7 +604,11 @@ func (pgd *pgDatastore) Close() error { } pgd.readPool.Close() - pgd.writePool.Close() + + if pgd.writePool != nil { + pgd.writePool.Close() + } + return nil } diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index 151dc281da..cdb9f0097a 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -86,7 +86,7 @@ func testPostgresDatastore(t *testing.T, pc []postgresConfig) { test.All(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { - ds, err := newPostgresDatastore(ctx, uri, + ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID, RevisionQuantization(revisionQuantization), GCWindow(gcWindow), GCInterval(gcInterval), @@ -176,6 +176,16 @@ func testPostgresDatastore(t *testing.T, pc []postgresConfig) { WatchBufferLength(50), MigrationPhase(config.migrationPhase), )) + + t.Run("TestStrictReadMode", createReplicaDatastoreTest( + b, + StrictReadModeTest, + RevisionQuantization(0), + GCWindow(1000*time.Second), + WatchBufferLength(50), + MigrationPhase(config.migrationPhase), + ReadStrictMode(true), + )) } t.Run("OTelTracing", createDatastoreTest( @@ -203,7 +213,7 @@ func testPostgresDatastoreWithoutCommitTimestamps(t *testing.T, pc []postgresCon // NOTE: watch API requires the commit timestamps, so we skip those tests here. test.AllWithExceptions(t, test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { - ds, err := newPostgresDatastore(ctx, uri, + ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID, RevisionQuantization(revisionQuantization), GCWindow(gcWindow), GCInterval(gcInterval), @@ -225,7 +235,21 @@ func createDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestF return func(t *testing.T) { ctx := context.Background() ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { - ds, err := newPostgresDatastore(ctx, uri, options...) + ds, err := newPostgresDatastore(ctx, uri, primaryInstanceID, options...) + require.NoError(t, err) + return ds + }) + defer ds.Close() + + tf(t, ds) + } +} + +func createReplicaDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestFunc, options ...Option) func(*testing.T) { + return func(t *testing.T) { + ctx := context.Background() + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := newPostgresDatastore(ctx, uri, 42, options...) 
require.NoError(t, err) return ds }) @@ -635,6 +659,7 @@ func QuantizedRevisionTest(t *testing.T, b testdatastore.RunningEngineForTest) { ds, err := newPostgresDatastore( ctx, uri, + primaryInstanceID, RevisionQuantization(5*time.Second), GCWindow(24*time.Hour), WatchBufferLength(1), @@ -1136,6 +1161,7 @@ func WatchNotEnabledTest(t *testing.T, _ testdatastore.RunningEngineForTest, pgV ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore { ctx := context.Background() ds, err := newPostgresDatastore(ctx, uri, + primaryInstanceID, RevisionQuantization(0), GCWindow(time.Millisecond*1), WatchBufferLength(1), @@ -1162,6 +1188,7 @@ func BenchmarkPostgresQuery(b *testing.B) { ds := testdatastore.RunPostgresForTesting(b, "", migrate.Head, pgversion.MinimumSupportedPostgresVersion, false).NewDatastore(b, func(engine, uri string) datastore.Datastore { ctx := context.Background() ds, err := newPostgresDatastore(ctx, uri, + primaryInstanceID, RevisionQuantization(0), GCWindow(time.Millisecond*1), WatchBufferLength(1), @@ -1197,6 +1224,7 @@ func datastoreWithInterceptorAndTestData(t *testing.T, interceptor pgcommon.Quer ds := testdatastore.RunPostgresForTestingWithCommitTimestamps(t, "", migrate.Head, false, pgVersion, false).NewDatastore(t, func(engine, uri string) datastore.Datastore { ctx := context.Background() ds, err := newPostgresDatastore(ctx, uri, + primaryInstanceID, RevisionQuantization(0), GCWindow(time.Millisecond*1), WatchBufferLength(1), @@ -1406,6 +1434,38 @@ func RepairTransactionsTest(t *testing.T, ds datastore.Datastore) { require.Greater(t, currentMaximumID, 12345) } +func StrictReadModeTest(t *testing.T, ds datastore.Datastore) { + require := require.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lowestRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + // Perform a read at the head revision, which should succeed. + reader := ds.SnapshotReader(lowestRevision) + it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(err) + it.Close() + + // Perform a read at a manually constructed revision beyond head, which should fail. + badRev := postgresRevision{ + snapshot: pgSnapshot{ + xmax: 9999999999999999999, + }, + } + + _, err = ds.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.Error(err) + require.ErrorContains(err, "is not available on the replica") + require.ErrorAs(err, &common.RevisionUnavailableError{}) +} + func NullCaveatWatchTest(t *testing.T, ds datastore.Datastore) { require := require.New(t) diff --git a/internal/datastore/postgres/strictreader.go b/internal/datastore/postgres/strictreader.go new file mode 100644 index 0000000000..b656e237ca --- /dev/null +++ b/internal/datastore/postgres/strictreader.go @@ -0,0 +1,65 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/authzed/spicedb/internal/datastore/common" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" +) + +const pgInvalidArgument = "22023" + +// strictReaderQueryFuncs wraps a DBFuncQuerier and adds a strict read assertion to all queries. 
+// This assertion ensures that the transaction is not reading from the future or from a +// transaction that has not been committed on the replica. +type strictReaderQueryFuncs struct { + wrapped pgxcommon.DBFuncQuerier + revision postgresRevision +} + +func (srqf strictReaderQueryFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, args ...any) error { + // NOTE: it is *required* for the pgx.QueryExecModeSimpleProtocol to be added as pgx will otherwise wrap + // the query as a prepared statement, which does *not* support running more than a single statement at a time. + return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) +} + +func (srqf strictReaderQueryFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, args ...any) error { + return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) +} + +func (srqf strictReaderQueryFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, args ...any) error { + return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) +} + +func (srqf strictReaderQueryFuncs) rewriteError(err error) error { + if err == nil { + return nil + } + + var pgerr *pgconn.PgError + if errors.As(err, &pgerr) { + if (pgerr.Code == pgInvalidArgument && strings.Contains(pgerr.Message, "is in the future")) || + strings.Contains(pgerr.Message, "replica missing revision") { + return common.NewRevisionUnavailableError(fmt.Errorf("revision %s is not available on the replica", srqf.revision.String())) + } + } + + return err +} + +func (srqf strictReaderQueryFuncs) addAssertToSQL(sql string) string { + // The assertion checks that the transaction is not reading from the future or from a + // transaction that is still in-progress on the replica. If the transaction is not yet + // available on the replica at all, the call to `pg_xact_status` will fail with an invalid + // argument error and a message indicating that the xid "is in the future". If the transaction + // does exist, but has not yet been committed (or aborted), the call to `pg_xact_status` will return + // "in progress". rewriteError will catch these cases and return a RevisionUnavailableError. + assertion := fmt.Sprintf(`do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$;`, srqf.revision.snapshot.xmin-1) + return assertion + sql +} diff --git a/internal/datastore/proxy/cachedcheckrev.go b/internal/datastore/proxy/cachedcheckrev.go new file mode 100644 index 0000000000..4ca32d5c32 --- /dev/null +++ b/internal/datastore/proxy/cachedcheckrev.go @@ -0,0 +1,44 @@ +package proxy + +import ( + "context" + "sync/atomic" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// newCachedCheckRevision wraps a datastore with a cache that will avoid checking the revision +// if the last checked revision is at least as fresh as the one specified. 
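To make the rewrite performed by strictreader.go concrete, the sketch below shows the statement that addAssertToSQL prepends for an assumed snapshot whose xmin is 1234, placed in front of an assumed relationship query. The assertion and the query travel to the server as two statements in a single round trip, which is why pgx.QueryExecModeSimpleProtocol is forced: the prepared-statement protocol accepts only one statement at a time.

```go
package example

import "fmt"

func main() {
	const xmin = 1234                                                          // assumed snapshot xmin, for illustration only
	const query = `SELECT namespace, object_id, relation FROM relation_tuple` // assumed query

	// Mirrors addAssertToSQL: pg_xact_status(xmin-1) raises an "is in the future" error if
	// the replica has never seen the transaction, and the assert fails with
	// "replica missing revision" if the transaction is still in progress there.
	assertion := fmt.Sprintf(
		`do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$;`,
		xmin-1,
	)

	fmt.Println(assertion + query)
}
```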
+func newCachedCheckRevision(ds datastore.ReadOnlyDatastore) datastore.ReadOnlyDatastore { + return &cachedCheckRevision{ + ReadOnlyDatastore: ds, + lastCheckRevision: atomic.Pointer[datastore.Revision]{}, + } +} + +type cachedCheckRevision struct { + datastore.ReadOnlyDatastore + lastCheckRevision atomic.Pointer[datastore.Revision] +} + +func (c *cachedCheckRevision) CheckRevision(ctx context.Context, rev datastore.Revision) error { + // Check if we've already seen a revision at least as fresh as that specified. If so, we can skip the check. + lastChecked := c.lastCheckRevision.Load() + if lastChecked != nil { + lastCheckedRev := *lastChecked + if lastCheckedRev.Equal(rev) || lastCheckedRev.GreaterThan(rev) { + return nil + } + } + + err := c.ReadOnlyDatastore.CheckRevision(ctx, rev) + if err != nil { + return err + } + + if lastChecked == nil || rev.LessThan(*lastChecked) { + c.lastCheckRevision.CompareAndSwap(lastChecked, &rev) + } + + return nil +} diff --git a/internal/datastore/proxy/cachedcheckrev_test.go b/internal/datastore/proxy/cachedcheckrev_test.go new file mode 100644 index 0000000000..fbabb50be5 --- /dev/null +++ b/internal/datastore/proxy/cachedcheckrev_test.go @@ -0,0 +1,56 @@ +package proxy + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/revisionparsing" +) + +func TestCachedCheckRevision(t *testing.T) { + ds := &fakeBrokenDatastore{checkCount: 0} + + wrapped := newCachedCheckRevision(ds) + err := wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("10")) + require.NoError(t, err) + + // Check again for the same revision, should not call the underlying datastore. + err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("10")) + require.NoError(t, err) + + // Check again for a lesser revision, should not call the underlying datastore. + err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("10")) + require.NoError(t, err) + + // Check again for a higher revision, which should call the datastore. + err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("11")) + require.Error(t, err) + + err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("11")) + require.Error(t, err) + + err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("12")) + require.Error(t, err) + + // Ensure the older revision still works. 
+ err = wrapped.CheckRevision(context.Background(), revisionparsing.MustParseRevisionForTest("10")) + require.NoError(t, err) +} + +type fakeBrokenDatastore struct { + fakeDatastore + checkCount int +} + +func (f *fakeBrokenDatastore) CheckRevision(_ context.Context, _ datastore.Revision) error { + if f.checkCount == 0 { + f.checkCount++ + return nil + } + + return fmt.Errorf("broken") +} diff --git a/internal/datastore/proxy/replicated.go b/internal/datastore/proxy/replicated.go new file mode 100644 index 0000000000..589ed51277 --- /dev/null +++ b/internal/datastore/proxy/replicated.go @@ -0,0 +1,384 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/authzed/spicedb/internal/datastore/common" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// NewCheckingReplicatedDatastore creates a new datastore that writes to the provided primary and reads +// from the provided replicas. The replicas are chosen in a round-robin fashion. If a replica does +// not have the requested revision, the primary is used instead. +// +// NOTE: Be *very* careful when using this function. It is not safe to use this function without +// knowledge of the layout of the underlying datastore and its replicas. +// +// Replicas will be checked for the requested revision before reading from them, which means that the +// read pool for the replicas *must* point to a *stable* instance of the datastore (not a load balancer). +// That means that *each* replica node in the database must be configured as its own replica to SpiceDB, +// with each URI given distinctly. +func NewCheckingReplicatedDatastore(primary datastore.Datastore, replicas ...datastore.ReadOnlyDatastore) (datastore.Datastore, error) { + if len(replicas) == 0 { + log.Debug().Msg("No replicas provided, using primary as read source") + return primary, nil + } + + cachingReplicas := make([]datastore.ReadOnlyDatastore, 0, len(replicas)) + for _, replica := range replicas { + cachingReplicas = append(cachingReplicas, newCachedCheckRevision(replica)) + } + + log.Debug().Int("replica-count", len(replicas)).Msg("Using replicas for reads") + return &checkingReplicatedDatastore{ + primary, + cachingReplicas, + 0, + }, nil +} + +// NewStrictReplicatedDatastore creates a new datastore that writes to the provided primary and reads +// from the provided replicas. The replicas are chosen in a round-robin fashion. If a replica does +// not have the requested revision, the primary is used instead. +// +// Unlike NewCheckingReplicatedDatastore, this function does not check the replicas for the requested +// revision before reading from them; instead, a revision check is inserted into the SQL for each read. +// This is useful when the read pool points to a load balancer that can transparently handle the request. +// In this case, the primary will be used as a fallback if the replica does not have the requested revision. +// The replica(s) supplied to this proxy *must*, therefore, have strict read mode enabled, to ensure the +// query will fail with a RevisionUnavailableError if the revision is not available. 
+func NewStrictReplicatedDatastore(primary datastore.Datastore, replicas ...datastore.StrictReadDatastore) (datastore.Datastore, error) { + if len(replicas) == 0 { + log.Debug().Msg("No replicas provided, using primary as read source") + return primary, nil + } + + cachingReplicas := make([]datastore.ReadOnlyDatastore, 0, len(replicas)) + for _, replica := range replicas { + if !replica.IsStrictReadModeEnabled() { + return nil, fmt.Errorf("replica %v does not have strict read mode enabled", replica) + } + + cachingReplicas = append(cachingReplicas, newCachedCheckRevision(replica)) + } + + log.Debug().Int("replica-count", len(replicas)).Msg("Using replicas for reads") + return &strictReplicatedDatastore{ + primary, + cachingReplicas, + 0, + }, nil +} + +func selectReplica[T any](replicas []T, lastReplica *uint64) T { + if len(replicas) == 1 { + return replicas[0] + } + + var swapped bool + var next uint64 + for !swapped { + last := *lastReplica + next = (*lastReplica + 1) % uint64(len(replicas)) + swapped = atomic.CompareAndSwapUint64(lastReplica, last, next) + } + + log.Trace().Uint64("replica", next).Msg("choosing replica for read") + return replicas[next] +} + +type checkingReplicatedDatastore struct { + datastore.Datastore + replicas []datastore.ReadOnlyDatastore + + lastReplica uint64 +} + +// SnapshotReader creates a read-only handle that reads the datastore at the specified revision. +// Any errors establishing the reader will be returned by subsequent calls. +func (rd *checkingReplicatedDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { + replica := selectReplica(rd.replicas, &rd.lastReplica) + return &checkingStableReader{ + rev: revision, + replica: replica, + primary: rd.Datastore, + } +} + +type strictReplicatedDatastore struct { + datastore.Datastore + replicas []datastore.ReadOnlyDatastore + + lastReplica uint64 +} + +// SnapshotReader creates a read-only handle that reads the datastore at the specified revision. +// Any errors establishing the reader will be returned by subsequent calls. +func (rd *strictReplicatedDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { + replica := selectReplica(rd.replicas, &rd.lastReplica) + return &strictReadReplicatedReader{ + rev: revision, + replica: replica, + primary: rd.Datastore, + } +} + +// checkingStableReader is a reader that will check the replica for the requested revision before +// reading from it. If the replica does not have the requested revision, the primary will be used +// instead. Only supported for a stable replica within each pool. +type checkingStableReader struct { + rev datastore.Revision + replica datastore.ReadOnlyDatastore + primary datastore.Datastore + + // chosePrimaryForTest is used for testing to determine if the primary was used for the read. 
+ chosePrimaryForTest bool + + chosenReader datastore.Reader + choose sync.Once +} + +func (rr *checkingStableReader) ReadCaveatByName(ctx context.Context, name string) (caveat *core.CaveatDefinition, lastWritten datastore.Revision, err error) { + if err := rr.determineSource(ctx); err != nil { + return nil, datastore.NoRevision, err + } + + return rr.chosenReader.ReadCaveatByName(ctx, name) +} + +func (rr *checkingStableReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.ListAllCaveats(ctx) +} + +func (rr *checkingStableReader) LookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.LookupCaveatsWithNames(ctx, names) +} + +func (rr *checkingStableReader) QueryRelationships( + ctx context.Context, + filter datastore.RelationshipsFilter, + options ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.QueryRelationships(ctx, filter, options...) +} + +func (rr *checkingStableReader) ReverseQueryRelationships( + ctx context.Context, + subjectsFilter datastore.SubjectsFilter, + options ...options.ReverseQueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.ReverseQueryRelationships(ctx, subjectsFilter, options...) +} + +func (rr *checkingStableReader) ReadNamespaceByName(ctx context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { + if err := rr.determineSource(ctx); err != nil { + return nil, datastore.NoRevision, err + } + + return rr.chosenReader.ReadNamespaceByName(ctx, nsName) +} + +func (rr *checkingStableReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.ListAllNamespaces(ctx) +} + +func (rr *checkingStableReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.LookupNamespacesWithNames(ctx, nsNames) +} + +func (rr *checkingStableReader) CountRelationships(ctx context.Context, filter string) (int, error) { + if err := rr.determineSource(ctx); err != nil { + return 0, err + } + + return rr.chosenReader.CountRelationships(ctx, filter) +} + +func (rr *checkingStableReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.LookupCounters(ctx) +} + +// determineSource will choose the replica or primary to read from based on the revision, by checking +// if the replica contains the revision. If the replica does not contain the revision, the primary +// will be used instead. +func (rr *checkingStableReader) determineSource(ctx context.Context) error { + var finalError error + rr.choose.Do(func() { + // If the revision is not known to the replica, use the primary instead. 
+ if err := rr.replica.CheckRevision(ctx, rr.rev); err != nil { + var irr datastore.ErrInvalidRevision + if errors.As(err, &irr) { + if irr.Reason() == datastore.CouldNotDetermineRevision { + log.Trace().Str("revision", rr.rev.String()).Err(err).Msg("replica does not contain the requested revision, using primary") + rr.chosenReader = rr.primary.SnapshotReader(rr.rev) + rr.chosePrimaryForTest = true + return + } + } + finalError = err + return + } + log.Trace().Str("revision", rr.rev.String()).Msg("replica contains the requested revision") + + rr.chosenReader = rr.replica.SnapshotReader(rr.rev) + rr.chosePrimaryForTest = false + }) + + return finalError +} + +// strictReadReplicatedReader is a reader that will use the replica for reads without itself checking for +// the requested revision. If the replica does not have the requested revision, the primary will be +// used instead. This is useful when the read pool points to a load balancer that can transparently +// handle the request. In this case, the primary will be used as a fallback if the replica does not +// have the requested revision. The replica(s) supplied to this proxy *must*, therefore, have strict +// read mode enabled, to ensure the query will fail with a RevisionUnavailableError if the revision is +// not available. +type strictReadReplicatedReader struct { + rev datastore.Revision + replica datastore.ReadOnlyDatastore + primary datastore.Datastore +} + +func (rr *strictReadReplicatedReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + sr := rr.replica.SnapshotReader(rr.rev) + caveat, lastWritten, err := sr.ReadCaveatByName(ctx, name) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("caveat", name).Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).ReadCaveatByName(ctx, name) + } + return caveat, lastWritten, err +} + +func (rr *strictReadReplicatedReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + sr := rr.replica.SnapshotReader(rr.rev) + caveats, err := sr.ListAllCaveats(ctx) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).ListAllCaveats(ctx) + } + return caveats, err +} + +func (rr *strictReadReplicatedReader) LookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + sr := rr.replica.SnapshotReader(rr.rev) + caveats, err := sr.LookupCaveatsWithNames(ctx, names) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).LookupCaveatsWithNames(ctx, names) + } + return caveats, err +} + +func (rr *strictReadReplicatedReader) QueryRelationships( + ctx context.Context, + filter datastore.RelationshipsFilter, + options ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + sr := rr.replica.SnapshotReader(rr.rev) + relationships, err := sr.QueryRelationships(ctx, filter, options...) 
+ if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).QueryRelationships(ctx, filter, options...) + } + return relationships, err +} + +func (rr *strictReadReplicatedReader) ReverseQueryRelationships( + ctx context.Context, + subjectsFilter datastore.SubjectsFilter, + options ...options.ReverseQueryOptionsOption, +) (datastore.RelationshipIterator, error) { + sr := rr.replica.SnapshotReader(rr.rev) + relationships, err := sr.ReverseQueryRelationships(ctx, subjectsFilter, options...) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).ReverseQueryRelationships(ctx, subjectsFilter, options...) + } + return relationships, err +} + +func (rr *strictReadReplicatedReader) ReadNamespaceByName(ctx context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { + sr := rr.replica.SnapshotReader(rr.rev) + namespace, lastWritten, err := sr.ReadNamespaceByName(ctx, nsName) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("namespace", nsName).Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).ReadNamespaceByName(ctx, nsName) + } + return namespace, lastWritten, err +} + +func (rr *strictReadReplicatedReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + sr := rr.replica.SnapshotReader(rr.rev) + namespaces, err := sr.ListAllNamespaces(ctx) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).ListAllNamespaces(ctx) + } + return namespaces, err +} + +func (rr *strictReadReplicatedReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { + sr := rr.replica.SnapshotReader(rr.rev) + namespaces, err := sr.LookupNamespacesWithNames(ctx, nsNames) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).LookupNamespacesWithNames(ctx, nsNames) + } + return namespaces, err +} + +func (rr *strictReadReplicatedReader) CountRelationships(ctx context.Context, filter string) (int, error) { + sr := rr.replica.SnapshotReader(rr.rev) + count, err := sr.CountRelationships(ctx, filter) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return rr.primary.SnapshotReader(rr.rev).CountRelationships(ctx, filter) + } + return count, err +} + +func (rr *strictReadReplicatedReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + sr := rr.replica.SnapshotReader(rr.rev) + counters, err := sr.LookupCounters(ctx) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using 
primary") + return rr.primary.SnapshotReader(rr.rev).LookupCounters(ctx) + } + return counters, err +} diff --git a/internal/datastore/proxy/replicated_test.go b/internal/datastore/proxy/replicated_test.go new file mode 100644 index 0000000000..39a23c90ab --- /dev/null +++ b/internal/datastore/proxy/replicated_test.go @@ -0,0 +1,211 @@ +package proxy + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/revisionparsing" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func TestReplicatedReaderWithOnlyPrimary(t *testing.T) { + primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} + + replicated, err := NewStrictReplicatedDatastore(primary) + require.NoError(t, err) + + require.Equal(t, primary, replicated) +} + +func TestReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) { + primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + + replicated, err := NewCheckingReplicatedDatastore(primary, replica) + require.NoError(t, err) + + // Try at revision 1, which should use the replica. + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + ns, err := reader.ListAllNamespaces(context.Background()) + require.NoError(t, err) + require.Equal(t, 0, len(ns)) + + require.False(t, reader.(*checkingStableReader).chosePrimaryForTest) + + // Try at revision 2, which should use the primary. + reader = replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("2")) + ns, err = reader.ListAllNamespaces(context.Background()) + require.NoError(t, err) + require.Equal(t, 0, len(ns)) + + require.True(t, reader.(*checkingStableReader).chosePrimaryForTest) +} + +func TestReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) { + primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + + replicated, err := NewCheckingReplicatedDatastore(primary, replica) + require.NoError(t, err) + + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + ns, err := reader.LookupNamespacesWithNames(context.Background(), []string{"ns1"}) + require.NoError(t, err) + require.Equal(t, 1, len(ns)) +} + +func TestReplicatedReaderReturnsExpectedError(t *testing.T) { + for _, requireCheck := range []bool{true, false} { + t.Run(fmt.Sprintf("requireCheck=%v", requireCheck), func(t *testing.T) { + primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + + var ds datastore.Datastore + if requireCheck { + r, err := NewCheckingReplicatedDatastore(primary, replica) + require.NoError(t, err) + ds = r + } else { + r, err := NewStrictReplicatedDatastore(primary, replica) + ds = r + require.NoError(t, err) + } + + // Try at revision 1, which should use the replica. 
+ reader := ds.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + _, _, err := reader.ReadNamespaceByName(context.Background(), "expecterror") + require.Error(t, err) + require.ErrorContains(t, err, "raising an expected error") + }) + } +} + +type fakeDatastore struct { + isPrimary bool + revision datastore.Revision +} + +func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { + return fakeSnapshotReader{ + revision: revision, + isPrimary: f.isPrimary, + } +} + +func (f fakeDatastore) ReadWriteTx(_ context.Context, _ datastore.TxUserFunc, _ ...options.RWTOptionsOption) (datastore.Revision, error) { + return nil, nil +} + +func (f fakeDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { + return nil, nil +} + +func (f fakeDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { + return nil, nil +} + +func (f fakeDatastore) CheckRevision(_ context.Context, rev datastore.Revision) error { + if rev.GreaterThan(f.revision) { + return datastore.NewInvalidRevisionErr(rev, datastore.CouldNotDetermineRevision) + } + + return nil +} + +func (f fakeDatastore) RevisionFromString(_ string) (datastore.Revision, error) { + return nil, nil +} + +func (f fakeDatastore) Watch(_ context.Context, _ datastore.Revision, _ datastore.WatchOptions) (<-chan *datastore.RevisionChanges, <-chan error) { + return nil, nil +} + +func (f fakeDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) { + return datastore.ReadyState{}, nil +} + +func (f fakeDatastore) Features(_ context.Context) (*datastore.Features, error) { + return nil, nil +} + +func (f fakeDatastore) Statistics(_ context.Context) (datastore.Stats, error) { + return datastore.Stats{}, nil +} + +func (f fakeDatastore) Close() error { + return nil +} + +func (f fakeDatastore) IsStrictReadModeEnabled() bool { + return true +} + +type fakeSnapshotReader struct { + revision datastore.Revision + isPrimary bool +} + +func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) { + if fsr.isPrimary { + return []datastore.RevisionedDefinition[*corev1.NamespaceDefinition]{ + { + Definition: &corev1.NamespaceDefinition{ + Name: "ns1", + }, + LastWrittenRevision: revisionparsing.MustParseRevisionForTest("2"), + }, + }, nil + } + + if !fsr.isPrimary && fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) { + return nil, common.NewRevisionUnavailableError(fmt.Errorf("revision not available")) + } + + return nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *corev1.NamespaceDefinition, lastWritten datastore.Revision, err error) { + if nsName == "expecterror" { + return nil, nil, fmt.Errorf("raising an expected error") + } + + return nil, nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) LookupCaveatsWithNames(_ context.Context, names []string) ([]datastore.RevisionedDefinition[*corev1.CaveatDefinition], error) { + return nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) ReadCaveatByName(_ context.Context, name string) (caveat *corev1.CaveatDefinition, lastWritten datastore.Revision, err error) { + return nil, nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) ListAllCaveats(context.Context) ([]datastore.RevisionedDefinition[*corev1.CaveatDefinition], error) { + return nil, fmt.Errorf("not implemented") +} + +func 
(fakeSnapshotReader) ListAllNamespaces(context.Context) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) { + return nil, nil +} + +func (fakeSnapshotReader) QueryRelationships(context.Context, datastore.RelationshipsFilter, ...options.QueryOptionsOption) (datastore.RelationshipIterator, error) { + return nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) ReverseQueryRelationships(context.Context, datastore.SubjectsFilter, ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { + return nil, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) CountRelationships(ctx context.Context, filter string) (int, error) { + return -1, fmt.Errorf("not implemented") +} + +func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + return nil, fmt.Errorf("not implemented") +} diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 4a0347afb6..0cd772d856 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -23,6 +23,8 @@ import ( type engineBuilderFunc func(ctx context.Context, options Config) (datastore.Datastore, error) +const MaxReplicaCount = 16 + const ( MemoryEngine = "memory" PostgresEngine = "postgres" @@ -108,6 +110,11 @@ type Config struct { EnableDatastoreMetrics bool `debugmap:"visible"` DisableStats bool `debugmap:"visible"` + // Read Replicas + ReadReplicaConnPool ConnPoolConfig `debugmap:"visible"` + ReadReplicaURIs []string `debugmap:"sensitive"` + ReadReplicaCredentialsProviderName string `debugmap:"visible"` + // Bootstrap BootstrapFiles []string `debugmap:"visible-format"` BootstrapFileContents map[string][]byte `debugmap:"visible"` @@ -170,11 +177,15 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt flagSet.StringVar(&opts.URI, flagName("datastore-conn-uri"), defaults.URI, `connection string used by remote datastores (e.g. "postgres://postgres:password@localhost:5432/spicedb")`) flagSet.StringVar(&opts.CredentialsProviderName, flagName("datastore-credentials-provider-name"), defaults.CredentialsProviderName, fmt.Sprintf(`retrieve datastore credentials dynamically using (%s)`, datastore.CredentialsProviderOptions())) + flagSet.StringArrayVar(&opts.ReadReplicaURIs, flagName("datastore-read-replica-conn-uri"), []string{}, "connection string used by remote datastores for read replicas (e.g. \"postgres://postgres:password@localhost:5432/spicedb\"). 
Only supported for postgres and mysql.") + flagSet.StringVar(&opts.ReadReplicaCredentialsProviderName, flagName("datastore-read-replica-credentials-provider-name"), defaults.CredentialsProviderName, fmt.Sprintf(`retrieve datastore credentials dynamically using (%s)`, datastore.CredentialsProviderOptions())) + var legacyConnPool ConnPoolConfig RegisterConnPoolFlagsWithPrefix(flagSet, "datastore-conn", DefaultReadConnPool(), &legacyConnPool) deprecateUnifiedConnFlags(flagSet) RegisterConnPoolFlagsWithPrefix(flagSet, "datastore-conn-pool-read", &legacyConnPool, &opts.ReadConnPool) RegisterConnPoolFlagsWithPrefix(flagSet, "datastore-conn-pool-write", DefaultWriteConnPool(), &opts.WriteConnPool) + RegisterConnPoolFlagsWithPrefix(flagSet, "datastore-read-replica-conn-pool", DefaultReadConnPool(), &opts.ReadReplicaConnPool) normalizeFunc := flagSet.GetNormalizeFunc() flagSet.SetNormalizeFunc(func(f *pflag.FlagSet, name string) pflag.NormalizedName { @@ -247,6 +258,8 @@ func DefaultDatastoreConfig() *Config { MaxRevisionStalenessPercent: .1, // 10% ReadConnPool: *DefaultReadConnPool(), WriteConnPool: *DefaultWriteConnPool(), + ReadReplicaConnPool: *DefaultReadConnPool(), + ReadReplicaURIs: []string{}, ReadOnly: false, MaxRetries: 10, OverlapKey: "key", @@ -361,6 +374,10 @@ func NewDatastore(ctx context.Context, options ...ConfigOption) (datastore.Datas } func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { + if len(opts.ReadReplicaURIs) > 0 { + return nil, errors.New("read replicas are not supported for the CockroachDB datastore engine") + } + return crdb.NewCRDBDatastore( ctx, opts.URI, @@ -392,6 +409,52 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er } func newPostgresDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { + primary, err := newPostgresPrimaryDatastore(ctx, opts) + if err != nil { + return nil, fmt.Errorf("failed to create primary datastore: %w", err) + } + + if len(opts.ReadReplicaURIs) > MaxReplicaCount { + return nil, fmt.Errorf("too many read replicas, max is %d", MaxReplicaCount) + } + + replicas := make([]datastore.StrictReadDatastore, 0, len(opts.ReadReplicaURIs)) + for index, replicaURI := range opts.ReadReplicaURIs { + replica, err := newPostgresReplicaDatastore(ctx, uint32(index), replicaURI, opts) + if err != nil { + return nil, err + } + replicas = append(replicas, replica) + } + + return proxy.NewStrictReplicatedDatastore(primary, replicas...) 
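For reference, here is a consolidated sketch of what the Postgres path wires together, with assumed connection strings: every replica is opened read-only with strict reads forced on, and the strict replicated proxy falls back to the primary whenever a replica reports a revision as unavailable.

```go
package example

import (
	"context"

	"github.com/authzed/spicedb/internal/datastore/postgres"
	"github.com/authzed/spicedb/internal/datastore/proxy"
	"github.com/authzed/spicedb/pkg/datastore"
)

// replicatedPostgres is a hypothetical, repo-internal helper mirroring newPostgresDatastore above.
func replicatedPostgres(ctx context.Context, primaryURI string, replicaURIs []string) (datastore.Datastore, error) {
	primary, err := postgres.NewPostgresDatastore(ctx, primaryURI)
	if err != nil {
		return nil, err
	}

	replicas := make([]datastore.StrictReadDatastore, 0, len(replicaURIs))
	for i, uri := range replicaURIs {
		// Strict read mode is mandatory for Postgres replicas: a too-new revision becomes a
		// RevisionUnavailableError instead of a silently stale read.
		replica, err := postgres.NewReadOnlyPostgresDatastore(ctx, uri, uint32(i), postgres.ReadStrictMode(true))
		if err != nil {
			return nil, err
		}
		replicas = append(replicas, replica)
	}

	return proxy.NewStrictReplicatedDatastore(primary, replicas...)
}
```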
+} + +func commonPostgresDatastoreOptions(opts Config) []postgres.Option { + return []postgres.Option{ + postgres.EnableTracing(), + postgres.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), + postgres.MaxRetries(uint8(opts.MaxRetries)), + } +} + +func newPostgresReplicaDatastore(ctx context.Context, replicaIndex uint32, replicaURI string, opts Config) (datastore.StrictReadDatastore, error) { + pgOpts := []postgres.Option{ + postgres.CredentialsProviderName(opts.ReadReplicaCredentialsProviderName), + postgres.ReadConnsMaxOpen(opts.ReadReplicaConnPool.MaxOpenConns), + postgres.ReadConnsMinOpen(opts.ReadReplicaConnPool.MinOpenConns), + postgres.ReadConnMaxIdleTime(opts.ReadReplicaConnPool.MaxIdleTime), + postgres.ReadConnMaxLifetime(opts.ReadReplicaConnPool.MaxLifetime), + postgres.ReadConnMaxLifetimeJitter(opts.ReadReplicaConnPool.MaxLifetimeJitter), + postgres.ReadConnHealthCheckInterval(opts.ReadReplicaConnPool.HealthCheckInterval), + postgres.ReadStrictMode( /* strict read mode is required for Postgres read replicas */ true), + } + + pgOpts = append(pgOpts, commonPostgresDatastoreOptions(opts)...) + return postgres.NewReadOnlyPostgresDatastore(ctx, replicaURI, replicaIndex, pgOpts...) +} + +func newPostgresPrimaryDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { pgOpts := []postgres.Option{ postgres.CredentialsProviderName(opts.CredentialsProviderName), postgres.GCWindow(opts.GCWindow), @@ -412,17 +475,20 @@ func newPostgresDatastore(ctx context.Context, opts Config) (datastore.Datastore postgres.WriteConnHealthCheckInterval(opts.WriteConnPool.HealthCheckInterval), postgres.GCInterval(opts.GCInterval), postgres.GCMaxOperationTime(opts.GCMaxOperationTime), - postgres.EnableTracing(), postgres.WatchBufferLength(opts.WatchBufferLength), postgres.WatchBufferWriteTimeout(opts.WatchBufferWriteTimeout), - postgres.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), - postgres.MaxRetries(uint8(opts.MaxRetries)), postgres.MigrationPhase(opts.MigrationPhase), } + + pgOpts = append(pgOpts, commonPostgresDatastoreOptions(opts)...) return postgres.NewPostgresDatastore(ctx, opts.URI, pgOpts...) } func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { + if len(opts.ReadReplicaURIs) > 0 { + return nil, errors.New("read replicas are not supported for the Spanner datastore engine") + } + return spanner.NewSpannerDatastore( ctx, opts.URI, @@ -444,6 +510,53 @@ func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, } func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { + primary, err := newMySQLPrimaryDatastore(ctx, opts) + if err != nil { + return nil, err + } + + if len(opts.ReadReplicaURIs) > MaxReplicaCount { + return nil, fmt.Errorf("too many read replicas, max is %d", MaxReplicaCount) + } + + replicas := make([]datastore.ReadOnlyDatastore, 0, len(opts.ReadReplicaURIs)) + for index, replicaURI := range opts.ReadReplicaURIs { + replica, err := newMySQLReplicaDatastore(ctx, uint32(index), replicaURI, opts) + if err != nil { + return nil, err + } + replicas = append(replicas, replica) + } + + return proxy.NewCheckingReplicatedDatastore(primary, replicas...) 
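The MySQL path differs only in which proxy it uses: MySQL has no strict read mode, so each replica is asked via CheckRevision whether it holds the revision before it is read, which is why every replica URI must point at a single stable node rather than a load balancer. A compact sketch with assumed URIs (imports as in the Postgres sketch, with the mysql package in place of postgres):

```go
// replicatedMySQL is a hypothetical, repo-internal helper mirroring newMySQLDatastore above.
func replicatedMySQL(ctx context.Context, primaryURI string, replicaURIs []string) (datastore.Datastore, error) {
	primary, err := mysql.NewMySQLDatastore(ctx, primaryURI)
	if err != nil {
		return nil, err
	}

	replicas := make([]datastore.ReadOnlyDatastore, 0, len(replicaURIs))
	for i, uri := range replicaURIs {
		replica, err := mysql.NewReadOnlyMySQLDatastore(ctx, uri, uint32(i))
		if err != nil {
			return nil, err
		}
		replicas = append(replicas, replica)
	}

	// CheckRevision-based fallback: the proxy verifies each replica has the requested
	// revision before reading from it, and uses the primary when it does not.
	return proxy.NewCheckingReplicatedDatastore(primary, replicas...)
}
```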
+
+func commonMySQLDatastoreOptions(opts Config) []mysql.Option {
+    return []mysql.Option{
+        mysql.TablePrefix(opts.TablePrefix),
+        mysql.MaxRetries(uint8(opts.MaxRetries)),
+        mysql.OverrideLockWaitTimeout(1),
+        mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics),
+        mysql.MaxRevisionStalenessPercent(opts.MaxRevisionStalenessPercent),
+        mysql.RevisionQuantization(opts.RevisionQuantization),
+    }
+}
+
+func newMySQLReplicaDatastore(ctx context.Context, replicaIndex uint32, replicaURI string, opts Config) (datastore.ReadOnlyDatastore, error) {
+    mysqlOpts := []mysql.Option{
+        mysql.MaxOpenConns(opts.ReadReplicaConnPool.MaxOpenConns),
+        mysql.ConnMaxIdleTime(opts.ReadReplicaConnPool.MaxIdleTime),
+        mysql.ConnMaxLifetime(opts.ReadReplicaConnPool.MaxLifetime),
+        mysql.WatchBufferLength(opts.WatchBufferLength),
+        mysql.WatchBufferWriteTimeout(opts.WatchBufferWriteTimeout),
+        mysql.CredentialsProviderName(opts.ReadReplicaCredentialsProviderName),
+    }
+
+    mysqlOpts = append(mysqlOpts, commonMySQLDatastoreOptions(opts)...)
+    return mysql.NewReadOnlyMySQLDatastore(ctx, replicaURI, replicaIndex, mysqlOpts...)
+}
+
+func newMySQLPrimaryDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) {
     mysqlOpts := []mysql.Option{
         mysql.GCInterval(opts.GCInterval),
         mysql.GCWindow(opts.GCWindow),
@@ -453,20 +566,20 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e
         mysql.MaxOpenConns(opts.ReadConnPool.MaxOpenConns),
         mysql.ConnMaxIdleTime(opts.ReadConnPool.MaxIdleTime),
         mysql.ConnMaxLifetime(opts.ReadConnPool.MaxLifetime),
-        mysql.RevisionQuantization(opts.RevisionQuantization),
-        mysql.MaxRevisionStalenessPercent(opts.MaxRevisionStalenessPercent),
-        mysql.TablePrefix(opts.TablePrefix),
         mysql.WatchBufferLength(opts.WatchBufferLength),
         mysql.WatchBufferWriteTimeout(opts.WatchBufferWriteTimeout),
-        mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics),
-        mysql.MaxRetries(uint8(opts.MaxRetries)),
-        mysql.OverrideLockWaitTimeout(1),
         mysql.CredentialsProviderName(opts.CredentialsProviderName),
     }
+
+    mysqlOpts = append(mysqlOpts, commonMySQLDatastoreOptions(opts)...)
     return mysql.NewMySQLDatastore(ctx, opts.URI, mysqlOpts...)
 }
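
Both engines now build their option lists in two layers: role-specific options first, then a shared commonXDatastoreOptions slice appended on top, so the primary and every replica agree on settings such as table prefix, retries, and metrics. A minimal sketch of that functional-options layering follows; the option names and values here are illustrative, not the real mysql/postgres option packages.

package dsoptions

import "time"

// config collects connection tunables in this sketch.
type config struct {
	maxOpenConns int
	maxIdleTime  time.Duration
	tablePrefix  string
	maxRetries   uint8
}

// Option is a functional option, mirroring the mysql.Option / postgres.Option pattern above.
type Option func(*config)

func MaxOpenConns(n int) Option              { return func(c *config) { c.maxOpenConns = n } }
func ConnMaxIdleTime(d time.Duration) Option { return func(c *config) { c.maxIdleTime = d } }
func TablePrefix(p string) Option            { return func(c *config) { c.tablePrefix = p } }
func MaxRetries(r uint8) Option              { return func(c *config) { c.maxRetries = r } }

// commonOptions plays the role of commonMySQLDatastoreOptions above: settings that are
// identical for the primary and every replica live in one place. Values are illustrative.
func commonOptions() []Option {
	return []Option{
		TablePrefix("spicedb_"),
		MaxRetries(8),
	}
}

// newConfig applies role-specific options first and then appends the shared ones,
// mirroring mysqlOpts = append(mysqlOpts, commonMySQLDatastoreOptions(opts)...).
func newConfig(roleSpecific ...Option) config {
	opts := append(roleSpecific, commonOptions()...)
	c := config{}
	for _, opt := range opts {
		opt(&c)
	}
	return c
}

// Example: a replica pool reuses the shared options but tunes its own connection limits.
func replicaConfig() config {
	return newConfig(
		MaxOpenConns(20),
		ConnMaxIdleTime(30*time.Minute),
	)
}
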

 func newMemoryDatstore(_ context.Context, opts Config) (datastore.Datastore, error) {
+    if len(opts.ReadReplicaURIs) > 0 {
+        return nil, errors.New("read replicas are not supported for the in-memory datastore engine")
+    }
+
     log.Warn().Msg("in-memory datastore is not persistent and not feasible to run in a high availability fashion")
     return memdb.NewMemdbDatastore(opts.WatchBufferLength, opts.RevisionQuantization, opts.GCWindow)
 }
diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go
index be686e26ea..779180dd58 100644
--- a/pkg/cmd/datastore/zz_generated.options.go
+++ b/pkg/cmd/datastore/zz_generated.options.go
@@ -43,6 +43,9 @@ func (c *Config) ToOption() ConfigOption {
         to.ReadOnly = c.ReadOnly
         to.EnableDatastoreMetrics = c.EnableDatastoreMetrics
         to.DisableStats = c.DisableStats
+        to.ReadReplicaConnPool = c.ReadReplicaConnPool
+        to.ReadReplicaURIs = c.ReadReplicaURIs
+        to.ReadReplicaCredentialsProviderName = c.ReadReplicaCredentialsProviderName
         to.BootstrapFiles = c.BootstrapFiles
         to.BootstrapFileContents = c.BootstrapFileContents
         to.BootstrapOverwrite = c.BootstrapOverwrite
@@ -86,6 +89,9 @@ func (c Config) DebugMap() map[string]any {
     debugMap["ReadOnly"] = helpers.DebugValue(c.ReadOnly, false)
     debugMap["EnableDatastoreMetrics"] = helpers.DebugValue(c.EnableDatastoreMetrics, false)
     debugMap["DisableStats"] = helpers.DebugValue(c.DisableStats, false)
+    debugMap["ReadReplicaConnPool"] = helpers.DebugValue(c.ReadReplicaConnPool, false)
+    debugMap["ReadReplicaURIs"] = helpers.SensitiveDebugValue(c.ReadReplicaURIs)
+    debugMap["ReadReplicaCredentialsProviderName"] = helpers.DebugValue(c.ReadReplicaCredentialsProviderName, false)
     debugMap["BootstrapFiles"] = helpers.DebugValue(c.BootstrapFiles, true)
     debugMap["BootstrapFileContents"] = helpers.DebugValue(c.BootstrapFileContents, false)
     debugMap["BootstrapOverwrite"] = helpers.DebugValue(c.BootstrapOverwrite, false)
@@ -214,6 +220,34 @@ func WithDisableStats(disableStats bool) ConfigOption {
     }
 }

+// WithReadReplicaConnPool returns an option that can set ReadReplicaConnPool on a Config
+func WithReadReplicaConnPool(readReplicaConnPool ConnPoolConfig) ConfigOption {
+    return func(c *Config) {
+        c.ReadReplicaConnPool = readReplicaConnPool
+    }
+}
+
+// WithReadReplicaURIs returns an option that can append ReadReplicaURIss to Config.ReadReplicaURIs
+func WithReadReplicaURIs(readReplicaURIs string) ConfigOption {
+    return func(c *Config) {
+        c.ReadReplicaURIs = append(c.ReadReplicaURIs, readReplicaURIs)
+    }
+}
+
+// SetReadReplicaURIs returns an option that can set ReadReplicaURIs on a Config
+func SetReadReplicaURIs(readReplicaURIs []string) ConfigOption {
+    return func(c *Config) {
+        c.ReadReplicaURIs = readReplicaURIs
+    }
+}
+
+// WithReadReplicaCredentialsProviderName returns an option that can set ReadReplicaCredentialsProviderName on a Config
+func WithReadReplicaCredentialsProviderName(readReplicaCredentialsProviderName string) ConfigOption {
+    return func(c *Config) {
+        c.ReadReplicaCredentialsProviderName = readReplicaCredentialsProviderName
+    }
+}
+
 // WithBootstrapFiles returns an option that can append BootstrapFiless to Config.BootstrapFiles
 func WithBootstrapFiles(bootstrapFiles string) ConfigOption {
     return func(c *Config) {
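
The generated With/Set helpers above follow the repository's existing pattern: With* appends one element for slice-valued fields, while Set* replaces the whole slice. A small usage sketch follows; the import path and replica URIs are assumptions for illustration, and in a real server these options would be handed to NewDatastore(ctx, options...) rather than applied by hand.

package main

import (
	"fmt"

	dscfg "github.com/authzed/spicedb/pkg/cmd/datastore" // assumed import path for pkg/cmd/datastore
)

func main() {
	// Start from the package defaults and layer the read-replica settings on top.
	cfg := dscfg.DefaultDatastoreConfig()

	opts := []dscfg.ConfigOption{
		// WithReadReplicaURIs appends a single URI per call...
		dscfg.WithReadReplicaURIs("postgres://replica-1.internal:5432/spicedb"), // placeholder URI
		dscfg.WithReadReplicaURIs("postgres://replica-2.internal:5432/spicedb"), // placeholder URI
		// ...while SetReadReplicaURIs would replace the whole slice at once.
		dscfg.WithReadReplicaConnPool(*dscfg.DefaultReadConnPool()),
	}
	for _, opt := range opts {
		opt(cfg)
	}

	fmt.Println(cfg.ReadReplicaURIs) // prints the two replica URIs appended above
}
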
diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go
index b9e10b041b..99ba61cf02 100644
--- a/pkg/datastore/datastore.go
+++ b/pkg/datastore/datastore.go
@@ -541,16 +541,12 @@ func (wo WatchOptions) WithCheckpointInterval(interval time.Duration) WatchOptio
     }
 }

-// Datastore represents tuple access for a single namespace.
-type Datastore interface {
+// ReadOnlyDatastore is an interface for reading relationships from the datastore.
+type ReadOnlyDatastore interface {
     // SnapshotReader creates a read-only handle that reads the datastore at the specified revision.
     // Any errors establishing the reader will be returned by subsequent calls.
     SnapshotReader(Revision) Reader

-    // ReadWriteTx starts a read/write transaction, which will be committed if no error is
-    // returned and rolled back if an error is returned.
-    ReadWriteTx(context.Context, TxUserFunc, ...options.RWTOptionsOption) (Revision, error)
-
     // OptimizedRevision gets a revision that will likely already be replicated
     // and will likely be shared amongst many queries.
     OptimizedRevision(ctx context.Context) (Revision, error)
@@ -588,6 +584,23 @@ type Datastore interface {
     Close() error
 }

+// Datastore represents tuple access for a single namespace.
+type Datastore interface {
+    ReadOnlyDatastore
+
+    // ReadWriteTx starts a read/write transaction, which will be committed if no error is
+    // returned and rolled back if an error is returned.
+    ReadWriteTx(context.Context, TxUserFunc, ...options.RWTOptionsOption) (Revision, error)
+}
+
+// StrictReadDatastore is an interface for datastores that support strict read mode.
+type StrictReadDatastore interface {
+    Datastore
+
+    // IsStrictReadModeEnabled returns whether the datastore is in strict read mode.
+    IsStrictReadModeEnabled() bool
+}
+
 type strArray []string

 // MarshalZerologArray implements zerolog array marshalling.
diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go
index 20cccd1fdd..2be74ade23 100644
--- a/pkg/datastore/test/datastore.go
+++ b/pkg/datastore/test/datastore.go
@@ -131,6 +131,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories)
     t.Run("TestRevisionSerialization", func(t *testing.T) { RevisionSerializationTest(t, tester) })
     t.Run("TestSequentialRevisions", func(t *testing.T) { SequentialRevisionsTest(t, tester) })
     t.Run("TestConcurrentRevisions", func(t *testing.T) { ConcurrentRevisionsTest(t, tester) })
+    t.Run("TestCheckRevisions", func(t *testing.T) { CheckRevisionsTest(t, tester) })

     if !except.GC() {
         t.Run("TestRevisionGC", func(t *testing.T) { RevisionGCTest(t, tester) })
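
With the datastore.go split above, read-only consumers can depend on ReadOnlyDatastore while read/write paths keep using Datastore, and strict read mode becomes an optional capability layered on top. The sketch below mirrors that shape with trimmed-down stand-in interfaces (the real ones carry many more methods); the feature-detection helper is an assumption about how a caller might branch on the capability, not code from this change.

package dsiface

import "context"

// Minimal stand-ins for the interface split above.
type ReadOnlyStore interface {
	HeadRevision(ctx context.Context) (uint64, error)
}

type Store interface {
	ReadOnlyStore
	Write(ctx context.Context, key, value string) error
}

type StrictReadStore interface {
	Store
	IsStrictReadModeEnabled() bool
}

// strictReadEnabled feature-detects strict read mode: callers can keep accepting the
// narrow interface and only branch when the capability is actually present.
func strictReadEnabled(ds ReadOnlyStore) bool {
	if srm, ok := ds.(interface{ IsStrictReadModeEnabled() bool }); ok {
		return srm.IsStrictReadModeEnabled()
	}
	return false
}
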
diff --git a/pkg/datastore/test/revisions.go b/pkg/datastore/test/revisions.go
index 1c42d69932..b6eef01533 100644
--- a/pkg/datastore/test/revisions.go
+++ b/pkg/datastore/test/revisions.go
@@ -172,6 +172,46 @@ func RevisionGCTest(t *testing.T, tester DatastoreTester) {
     require.Error(ds.CheckRevision(ctx, previousRev), "expected revision head-1 to be outside GC Window")
 }

+func CheckRevisionsTest(t *testing.T, tester DatastoreTester) {
+    require := require.New(t)
+
+    ds, err := tester.New(0, 1000*time.Second, 300*time.Minute, 1)
+    require.NoError(err)
+
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer cancel()
+
+    // Write a new revision.
+    writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+        return rwt.WriteNamespaces(ctx, ns.Namespace("foo/somethingnew1"))
+    })
+    require.NoError(err)
+    require.NoError(ds.CheckRevision(ctx, writtenRev), "expected written revision to be valid in GC Window")
+
+    head, err := ds.HeadRevision(ctx)
+    require.NoError(err)
+
+    // Check the head revision is valid.
+    require.NoError(ds.CheckRevision(ctx, head), "expected head revision to be valid in GC Window")
+
+    // Write a new revision.
+    writtenRev, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+        return rwt.WriteNamespaces(ctx, ns.Namespace("foo/somethingnew2"))
+    })
+    require.NoError(err)
+    require.NoError(ds.CheckRevision(ctx, writtenRev), "expected written revision to be valid in GC Window")
+
+    // Check the previous head revision is still valid.
+    require.NoError(ds.CheckRevision(ctx, head), "expected previous revision to be valid in GC Window")
+
+    // Get the updated head revision.
+    head, err = ds.HeadRevision(ctx)
+    require.NoError(err)
+
+    // Check the new head revision is valid.
+    require.NoError(ds.CheckRevision(ctx, head), "expected head revision to be valid in GC Window")
+}
+
 func SequentialRevisionsTest(t *testing.T, tester DatastoreTester) {
     require := require.New(t)