diff --git a/internal/datastore/postgres/common/bulk.go b/internal/datastore/postgres/common/bulk.go index 9ba4ebea90..51de5a8073 100644 --- a/internal/datastore/postgres/common/bulk.go +++ b/internal/datastore/postgres/common/bulk.go @@ -2,7 +2,6 @@ package common import ( "context" - "database/sql" "github.com/jackc/pgx/v5" @@ -29,11 +28,10 @@ func (tg *tupleSourceAdapter) Next() bool { // Values returns the values for the current row. func (tg *tupleSourceAdapter) Values() ([]any, error) { - var caveatName sql.NullString + var caveatName string var caveatContext map[string]any if tg.current.Caveat != nil { - caveatName.String = tg.current.Caveat.CaveatName - caveatName.Valid = true + caveatName = tg.current.Caveat.CaveatName caveatContext = tg.current.Caveat.Context.AsMap() } diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index a8f689a921..dd34218c26 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -725,9 +725,7 @@ func exactRelationshipDifferentCaveatClause(r *core.RelationTuple) sq.And { colUsersetRelation: r.Subject.Relation, }, sq.Or{ - sq.NotEq{ - colCaveatContextName: caveatName, - }, + sq.Expr(fmt.Sprintf(`%s IS DISTINCT FROM ?`, colCaveatContextName), caveatName), sq.NotEq{ colCaveatContext: caveatContext, }, diff --git a/pkg/datastore/test/bulk.go b/pkg/datastore/test/bulk.go index a5fa31c576..1e20e3eb19 100644 --- a/pkg/datastore/test/bulk.go +++ b/pkg/datastore/test/bulk.go @@ -119,6 +119,79 @@ func BulkUploadAlreadyExistsSameCallErrorTest(t *testing.T, tester DatastoreTest grpcutil.RequireStatus(t, codes.AlreadyExists, err) } +func BulkUploadEditCaveat(t *testing.T, tester DatastoreTester) { + tc := 10 + require := require.New(t) + ctx := context.Background() + + rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) + bulkSource := testfixtures.NewBulkTupleGenerator( + testfixtures.DocumentNS.Name, + "caveated_viewer", + testfixtures.UserNS.Name, + tc, + t, + ) + + lastRevision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + loaded, err := rwt.BulkLoad(ctx, bulkSource) + require.NoError(err) + require.Equal(uint64(tc), loaded) + return err + }) + require.NoError(err) + + iter, err := ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: testfixtures.DocumentNS.Name, + }) + require.NoError(err) + defer iter.Close() + + updates := make([]*core.RelationTupleUpdate, 0, tc) + + for found := iter.Next(); found != nil; found = iter.Next() { + updates = append(updates, &core.RelationTupleUpdate{ + Operation: core.RelationTupleUpdate_TOUCH, + Tuple: &core.RelationTuple{ + ResourceAndRelation: found.ResourceAndRelation, + Subject: found.Subject, + Caveat: &core.ContextualizedCaveat{ + CaveatName: testfixtures.CaveatDef.Name, + Context: nil, + }, + }, + }) + } + + require.Equal(tc, len(updates)) + + lastRevision, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + err := rwt.WriteRelationships(ctx, updates) + require.NoError(err) + return err + }) + require.NoError(err) + + iter, err = ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: testfixtures.DocumentNS.Name, + }) + require.NoError(err) + defer iter.Close() + + foundChanged := 0 + + for found := iter.Next(); found != nil; found = iter.Next() { + require.NotNil(found.Caveat) + require.NotEmpty(found.Caveat.CaveatName) + foundChanged++ + } + + require.Equal(tc, foundChanged) +} + func BulkUploadAlreadyExistsErrorTest(t *testing.T, tester DatastoreTester) { require := require.New(t) ctx := context.Background() diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index a38751cc33..4cc10ef019 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -140,6 +140,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories) t.Run("TestBulkUploadErrors", func(t *testing.T) { BulkUploadErrorsTest(t, tester) }) t.Run("TestBulkUploadAlreadyExistsError", func(t *testing.T) { BulkUploadAlreadyExistsErrorTest(t, tester) }) t.Run("TestBulkUploadAlreadyExistsSameCallError", func(t *testing.T) { BulkUploadAlreadyExistsSameCallErrorTest(t, tester) }) + t.Run("BulkUploadEditCaveat", func(t *testing.T) { BulkUploadEditCaveat(t, tester) }) if !except.Stats() { t.Run("TestStats", func(t *testing.T) { StatsTest(t, tester) })