diff --git a/internal/datastore/memdb/watch.go b/internal/datastore/memdb/watch.go index 528eb37fba..262fbadf6f 100644 --- a/internal/datastore/memdb/watch.go +++ b/internal/datastore/memdb/watch.go @@ -32,7 +32,7 @@ func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, opt var stagedUpdates []*datastore.RevisionChanges var watchChan <-chan struct{} var err error - stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn) + stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options) if err != nil { errs <- err return @@ -40,10 +40,6 @@ func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, opt // Write the staged updates to the channel for _, changeToWrite := range stagedUpdates { - if len(changeToWrite.RelationshipChanges) == 0 { - continue - } - select { case updates <- changeToWrite: default: @@ -72,7 +68,7 @@ func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, opt return updates, errs } -func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64) ([]*datastore.RevisionChanges, int64, <-chan struct{}, error) { +func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]*datastore.RevisionChanges, int64, <-chan struct{}, error) { mdb.RLock() defer mdb.RUnlock() @@ -88,7 +84,23 @@ func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64) ([]* lastRevision := currentTxn for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() { change := changeRaw.(*changelog) - changes = append(changes, &change.changes) + + if options.Content&datastore.WatchRelationships == datastore.WatchRelationships && len(change.changes.RelationshipChanges) > 0 { + changes = append(changes, &change.changes) + } + + if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints && change.revisionNanos > lastRevision { + changes = append(changes, &datastore.RevisionChanges{ + Revision: revisions.NewForTimestamp(change.revisionNanos), + IsCheckpoint: true, + }) + } + + if options.Content&datastore.WatchSchema == datastore.WatchSchema && + len(change.changes.ChangedDefinitions) > 0 || len(change.changes.DeletedCaveats) > 0 || len(change.changes.DeletedNamespaces) > 0 { + changes = append(changes, &change.changes) + } + lastRevision = change.revisionNanos } diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index aa7bbd0f6d..808fd4e671 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -100,7 +100,7 @@ func TestMySQLDatastoreDSNWithoutParseTime(t *testing.T) { func TestMySQL8Datastore(t *testing.T) { b := testdatastore.RunMySQLForTestingWithOptions(t, testdatastore.MySQLTesterOptions{MigrateForNewDatastore: true}, "") dst := datastoreTester{b: b, t: t} - test.AllWithExceptions(t, test.DatastoreTesterFunc(dst.createDatastore), test.WithCategories(test.WatchSchemaCategory)) + test.AllWithExceptions(t, test.DatastoreTesterFunc(dst.createDatastore), test.WithCategories(test.WatchSchemaCategory, test.WatchCheckpointsCategory)) additionalMySQLTests(t, b) } diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index b9b20318e6..0addafb726 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -56,12 +56,18 @@ func (c Categories) WatchSchema() bool { return ok } +func (c Categories) WatchCheckpoints() bool { + _, ok := c[WatchCheckpointsCategory] + return ok +} + var noException = Categories{} const ( - GCCategory = "GC" - WatchCategory = "Watch" - WatchSchemaCategory = "WatchSchema" + GCCategory = "GC" + WatchCategory = "Watch" + WatchSchemaCategory = "WatchSchema" + WatchCheckpointsCategory = "WatchCheckpoints" ) func WithCategories(cats ...string) Categories { @@ -138,6 +144,10 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories) t.Run("TestWatchSchema", func(t *testing.T) { WatchSchemaTest(t, tester) }) t.Run("TestWatchAll", func(t *testing.T) { WatchAllTest(t, tester) }) } + + if !except.Watch() && !except.WatchCheckpoints() { + t.Run("TestWatchCheckpoints", func(t *testing.T) { WatchCheckpointsTest(t, tester) }) + } } // All runs all generic datastore tests on a DatastoreTester. diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index f4f6babbe9..dde08ecc58 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -607,3 +607,50 @@ func verifyMixedUpdates( require.False(expectDisconnect, "all changes verified without expected disconnect") } + +func WatchCheckpointsTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 16) + require.NoError(err) + + setupDatastore(ds, require) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lowestRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchOptions{ + Content: datastore.WatchCheckpoints | datastore.WatchRelationships, + CheckpointInterval: 100 * time.Millisecond, + }) + require.Zero(len(errchan)) + + afterTouchRevision, err := common.WriteTuples(ctx, ds, core.RelationTupleUpdate_TOUCH, + tuple.Parse("document:firstdoc#viewer@user:tom"), + ) + require.NoError(err) + verifyCheckpointUpdate(require, afterTouchRevision, changes) +} + +func verifyCheckpointUpdate( + require *require.Assertions, + expectedRevision datastore.Revision, + changes <-chan *datastore.RevisionChanges, +) { + changeWait := time.NewTimer(waitForChangesTimeout) + for { + select { + case change, ok := <-changes: + require.True(ok) + if change.IsCheckpoint { + require.True(change.Revision.Equal(change.Revision) || change.Revision.GreaterThan(expectedRevision)) + return + } + case <-changeWait.C: + require.Fail("Timed out", "waited for checkpoint") + } + } +}