diff --git a/internal/caveats/run_test.go b/internal/caveats/run_test.go index 390edb1884..2a799a1da9 100644 --- a/internal/caveats/run_test.go +++ b/internal/caveats/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -448,7 +449,7 @@ func TestRunCaveatExpressions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` @@ -507,7 +508,7 @@ func TestRunCaveatExpressions(t *testing.T) { func TestRunCaveatWithMissingMap(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` @@ -536,7 +537,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) { func TestRunCaveatWithEmptyMap(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go new file mode 100644 index 0000000000..4543a46013 --- /dev/null +++ b/internal/datastore/common/relationships.go @@ -0,0 +1,159 @@ +package common + +import ( + "context" + "database/sql" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errUnableToQueryRels = "unable to query relationships: %w" + +// Querier is an interface for querying the database. +type Querier[R Rows] interface { + QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error +} + +// Rows is a common interface for database rows reading. +type Rows interface { + Scan(dest ...any) error + Next() bool + Err() error +} + +type closeRowsWithError interface { + Rows + Close() error +} + +type closeRows interface { + Rows + Close() +} + +// QueryRelationships queries relationships for the given query and transaction. +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R]) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) + sqlString, args, err := builder.SelectSQL() + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName sql.NullString + var caveatCtx C + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + span.AddEvent("Selecting columns") + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect)))) + return func(yield func(tuple.Relationship, error) bool) { + span.AddEvent("Issuing query to database") + err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + span.AddEvent("Query issued to database") + + var r Rows = rows + if crwe, ok := r.(closeRowsWithError); ok { + defer LogOnError(ctx, crwe.Close) + } else if cr, ok := r.(closeRows); ok { + defer cr.Close() + } + + relCount := 0 + for rows.Next() { + if relCount == 0 { + span.AddEvent("First row returned") + } + + if err := rows.Scan(colsToSelect...); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + + if relCount == 0 { + span.AddEvent("First row scanned") + } + + var caveat *corev1.ContextualizedCaveat + if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } + } + } + + var integrity *corev1.RelationshipIntegrity + if integrityKeyID != "" { + integrity = &corev1.RelationshipIntegrity{ + KeyId: integrityKeyID, + Hash: integrityHash, + HashedAt: timestamppb.New(timestamp), + } + } + + if expiration != nil { + // Ensure the expiration is always read in UTC, since some datastores (like CRDB) + // will normalize to local time. + t := expiration.UTC() + expiration = &t + } + + relCount++ + if !yield(tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: resourceObjectType, + ObjectID: resourceObjectID, + Relation: resourceRelation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: subjectObjectType, + ObjectID: subjectObjectID, + Relation: subjectRelation, + }, + }, + OptionalCaveat: caveat, + OptionalExpiration: expiration, + OptionalIntegrity: integrity, + }, nil) { + return nil + } + } + + span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + if err := rows.Err(); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) + } + + return nil + }, sqlString, args...) + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + } + }, nil +} diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go new file mode 100644 index 0000000000..542dc5dbf3 --- /dev/null +++ b/internal/datastore/common/schema.go @@ -0,0 +1,134 @@ +package common + +import ( + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const ( + relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation + relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext + relationshipExpirationColumnCount = 1 // ColExpiration + relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp +) + +// SchemaInformation holds the schema information from the SQL datastore implementation. +// +//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation +type SchemaInformation struct { + RelationshipTableName string `debugmap:"visible"` + + ColNamespace string `debugmap:"visible"` + ColObjectID string `debugmap:"visible"` + ColRelation string `debugmap:"visible"` + ColUsersetNamespace string `debugmap:"visible"` + ColUsersetObjectID string `debugmap:"visible"` + ColUsersetRelation string `debugmap:"visible"` + + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + + ColExpiration string `debugmap:"visible"` + + ColIntegrityKeyID string `debugmap:"visible"` + ColIntegrityHash string `debugmap:"visible"` + ColIntegrityTimestamp string `debugmap:"visible"` + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType `debugmap:"visible"` + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"` + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string `debugmap:"visible"` + + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` + + // IntegrityEnabled is a flag to indicate if the schema has integrity columns. + IntegrityEnabled bool `debugmap:"visible"` + + // ExpirationDisabled is a flag to indicate whether expiration support is disabled. + ExpirationDisabled bool `debugmap:"visible"` +} + +func (si SchemaInformation) debugValidate() { + spiceerrors.DebugAssert(func() bool { + si.mustValidate() + return true + }, "SchemaInformation failed to validate") +} + +func (si SchemaInformation) mustValidate() { + if si.RelationshipTableName == "" { + panic("RelationshipTableName is required") + } + + if si.ColNamespace == "" { + panic("ColNamespace is required") + } + + if si.ColObjectID == "" { + panic("ColObjectID is required") + } + + if si.ColRelation == "" { + panic("ColRelation is required") + } + + if si.ColUsersetNamespace == "" { + panic("ColUsersetNamespace is required") + } + + if si.ColUsersetObjectID == "" { + panic("ColUsersetObjectID is required") + } + + if si.ColUsersetRelation == "" { + panic("ColUsersetRelation is required") + } + + if si.ColCaveatName == "" { + panic("ColCaveatName is required") + } + + if si.ColCaveatContext == "" { + panic("ColCaveatContext is required") + } + + if si.ColExpiration == "" { + panic("ColExpiration is required") + } + + if si.IntegrityEnabled { + if si.ColIntegrityKeyID == "" { + panic("ColIntegrityKeyID is required") + } + + if si.ColIntegrityHash == "" { + panic("ColIntegrityHash is required") + } + + if si.ColIntegrityTimestamp == "" { + panic("ColIntegrityTimestamp is required") + } + } + + if si.NowFunction == "" { + panic("NowFunction is required") + } + + if si.ColumnOptimization == ColumnOptimizationOptionUnknown { + panic("ColumnOptimization is required") + } + + if si.PaginationFilterType == PaginationFilterTypeUnknown { + panic("PaginationFilterType is required") + } + + if si.PlaceholderFormat == nil { + panic("PlaceholderFormat is required") + } +} diff --git a/internal/datastore/common/tuple.go b/internal/datastore/common/sliceiter.go similarity index 100% rename from internal/datastore/common/tuple.go rename to internal/datastore/common/sliceiter.go diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index d886927931..733ca30bba 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -2,8 +2,10 @@ package common import ( "context" + "maps" "math" "strings" + "time" sq "github.com/Masterminds/squirrel" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" @@ -47,115 +49,157 @@ var ( tracer = otel.Tracer("spicedb/internal/datastore/common") ) -// PaginationFilterType is an enumerator +// PaginationFilterType is an enumerator for pagination filter types. type PaginationFilterType uint8 const ( + PaginationFilterTypeUnknown PaginationFilterType = iota + // TupleComparison uses a comparison with a compound key, // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') // which is not compatible with all datastores. - TupleComparison PaginationFilterType = iota + TupleComparison = 1 // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly // filter out already received relationships. Useful for databases that do not support // tuple comparison, or do not execute it efficiently - ExpandedLogicComparison + ExpandedLogicComparison = 2 +) + +// ColumnOptimizationOption is an enumerator for column optimization options. +type ColumnOptimizationOption int + +const ( + ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota + + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. + ColumnOptimizationOptionNone + + // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values. + ColumnOptimizationOptionStaticValues ) -// SchemaInformation holds the schema information from the SQL datastore implementation. -type SchemaInformation struct { - colNamespace string - colObjectID string - colRelation string - colUsersetNamespace string - colUsersetObjectID string - colUsersetRelation string - colCaveatName string - colExpiration string - paginationFilterType PaginationFilterType - nowFunction string -} - -func NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName string, - colExpiration string, - paginationFilterType PaginationFilterType, - nowFunction string, -) SchemaInformation { - return SchemaInformation{ - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colExpiration, - paginationFilterType, - nowFunction, +type columnTracker struct { + SingleValue *string +} + +type columnTrackerMap map[string]columnTracker + +func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { + if r, ok := ctm[columnName]; ok && r.SingleValue != nil { + return true } + return false } // SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated // way to build query objects. type SchemaQueryFilterer struct { - schema SchemaInformation - queryBuilder sq.SelectBuilder - filteringColumnCounts map[string]int - filterMaximumIDCount uint16 + schema SchemaInformation + queryBuilder sq.SelectBuilder + filteringColumnTracker columnTrackerMap + filterMaximumIDCount uint16 + isCustomQuery bool + extraFields []string + fromSuffix string } -// NewSchemaQueryFilterer creates a new SchemaQueryFilterer object. -func NewSchemaQueryFilterer(schema SchemaInformation, initialQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { +// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting +// relationships. This method will automatically filter the columns retrieved from the database, only +// selecting the columns that are not already specified with a single static value in the query. +func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + schema.debugValidate() + if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") } - // Filter out any expired relationships. - initialQuery = initialQuery.Where(sq.Or{ - sq.Eq{schema.colExpiration: nil}, - sq.Expr(schema.colExpiration + " > " + schema.nowFunction + "()"), - }) + queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select() + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: queryBuilder, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: false, + extraFields: extraFields, + } +} + +// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting +// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, +// this method will not auto-filter the columns retrieved from the database. +func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { + schema.debugValidate() + + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } return SchemaQueryFilterer{ - schema: schema, - queryBuilder: initialQuery, - filteringColumnCounts: map[string]int{}, - filterMaximumIDCount: filterMaximumIDCount, + schema: schema, + queryBuilder: startingQuery, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: true, + extraFields: nil, } } +// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query. +func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { + sqf.queryBuilder = filter(sqf.queryBuilder) + return sqf +} + +// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause. +func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { + sqf.fromSuffix = fromSuffix + return sqf +} + func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { - return sqf.queryBuilder + spiceerrors.DebugAssert(func() bool { + return sqf.isCustomQuery + }, "UnderlyingQueryBuilder should only be called on custom queries") + return sqf.queryBuilderWithMaybeExpirationFilter(false) +} + +// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// Note that this adds the clause to the existing builder. +func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder { + if sqf.schema.ExpirationDisabled || skipExpiration { + return sqf.queryBuilder + } + + // Filter out any expired relationships. + return sqf.queryBuilder.Where(sq.Or{ + sq.Eq{sqf.schema.ColExpiration: nil}, + sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"), + }) } func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer { switch order { case options.ByResource: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, ) case options.BySubject: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, ) } @@ -174,47 +218,47 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr columnsAndValues := map[options.SortOrder][]nameAndValue{ options.ByResource: { { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, options.BySubject: { { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, }[order] - switch sqf.schema.paginationFilterType { + switch sqf.schema.PaginationFilterType { case TupleComparison: // For performance reasons, remove any column names that have static values in the query. columnNames := make([]string, 0, len(columnsAndValues)) @@ -222,7 +266,7 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr comparisonSlotCount := 0 for _, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { columnNames = append(columnNames, cav.name) valueSlots = append(valueSlots, cav.value) comparisonSlotCount++ @@ -242,10 +286,10 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr orClause := sq.Or{} for index, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { andClause := sq.And{} for _, previous := range columnsAndValues[0:index] { - if sqf.filteringColumnCounts[previous.name] != 1 { + if !sqf.filteringColumnTracker.hasStaticValue(previous.name) { andClause = append(andClause, sq.Eq{previous.name: previous.value}) } } @@ -266,25 +310,31 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr // FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the // specified type. func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colNamespace: resourceType}) - sqf.recordColumnValue(sqf.schema.colNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType}) + sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType) return sqf } -func (sqf SchemaQueryFilterer) recordColumnValue(colName string) { - if value, ok := sqf.filteringColumnCounts[colName]; ok { - sqf.filteringColumnCounts[colName] = value + 1 - return +func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) { + existing, ok := sqf.filteringColumnTracker[colName] + if ok { + if existing.SingleValue != nil && *existing.SingleValue != colValue { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} + } + } else { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue} } +} - sqf.filteringColumnCounts[colName] = 1 +func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} } // FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the // specified ID. func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colObjectID: objectID}) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID}) + sqf.recordColumnValue(sqf.schema.ColObjectID, objectID) return sqf } @@ -309,7 +359,7 @@ func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (Schema prefix = strings.ReplaceAll(prefix, `\`, `\\`) prefix = strings.ReplaceAll(prefix, "_", `\_`) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.colObjectID: prefix + "%"}) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"}) // NOTE: we do *not* record the use of the resource ID column here, because it is not used // statically and thus is necessary for sorting operations. @@ -332,7 +382,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colObjectID) + builder.WriteString(sqf.schema.ColObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(resourceIds)) @@ -342,7 +392,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema } args = append(args, resourceID) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID) } builder.WriteString("?") @@ -358,8 +408,8 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema // FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the // specified relation. func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colRelation: relation}) - sqf.recordColumnValue(sqf.schema.colRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation}) + sqf.recordColumnValue(sqf.schema.ColRelation, relation) return sqf } @@ -417,9 +467,10 @@ func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.Re } if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration { - csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.colExpiration: nil}) + csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil}) + spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration") } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration { - csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.colExpiration: nil}) + csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil}) } return csqf, nil @@ -440,12 +491,21 @@ func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...data func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { selectorsOrClause := sq.Or{} + // If there is more than a single filter, record all the subjects as varying, as the subjects returned + // can differ for each branch. + // TODO(jschorr): Optimize this further where applicable. + if len(selectors) > 1 { + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) + } + for _, selector := range selectors { selectorClause := sq.And{} if len(selector.OptionalSubjectType) > 0 { - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetNamespace: selector.OptionalSubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType) } if len(selector.OptionalSubjectIds) > 0 { @@ -454,7 +514,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor }, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colUsersetObjectID) + builder.WriteString(sqf.schema.ColUsersetObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(selector.OptionalSubjectIds)) @@ -464,7 +524,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor } args = append(args, subjectID) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID) } builder.WriteString("?") @@ -478,8 +538,8 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if !selector.RelationFilter.IsEmpty() { if selector.RelationFilter.OnlyNonEllipsisRelations { - selectorClause = append(selectorClause, sq.NotEq{sqf.schema.colUsersetRelation: datastore.Ellipsis}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) } else { relations := make([]string, 0, 2) if selector.RelationFilter.IncludeEllipsisRelation { @@ -492,14 +552,14 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if len(relations) == 1 { relName := relations[0] - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetRelation: relName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName) } else { orClause := sq.Or{} for _, relationName := range relations { dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis) - orClause = append(orClause, sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName) } selectorClause = append(selectorClause, orClause) @@ -517,27 +577,27 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor // FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with // subjects that match the specified filter. func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetNamespace: filter.SubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType) if filter.OptionalSubjectId != "" { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetObjectID: filter.OptionalSubjectId}) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId}) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId) } if filter.OptionalRelation != nil { dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis) } return sqf } func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colCaveatName: caveatName}) - sqf.recordColumnValue(sqf.schema.colCaveatName) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName}) + sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName) return sqf } @@ -547,21 +607,30 @@ func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { return sqf } -// QueryExecutor is a tuple query runner shared by SQL implementations of the datastore. -type QueryExecutor struct { - Executor ExecuteQueryFunc +// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore. +type QueryRelationshipsExecutor struct { + Executor ExecuteReadRelsQueryFunc } +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) + // ExecuteQuery executes the query. -func (tqs QueryExecutor) ExecuteQuery( +func (exc QueryRelationshipsExecutor) ExecuteQuery( ctx context.Context, query SchemaQueryFilterer, opts ...options.QueryOptionsOption, ) (datastore.RelationshipIterator, error) { + if query.isCustomQuery { + return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries") + } + queryOpts := options.NewQueryOptionsWithOptions(opts...) + // Add sort order. query = query.TupleOrder(queryOpts.Sort) + // Add cursor. if queryOpts.After != nil { if queryOpts.Sort == options.Unsorted { return nil, datastore.ErrCursorsWithoutSorting @@ -570,6 +639,7 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.After(queryOpts.After, queryOpts.Sort) } + // Add limit. var limit uint64 // NOTE: we use a uint here because it lines up with the // assignments in this function, but we set it to MaxInt64 @@ -580,20 +650,193 @@ func (tqs QueryExecutor) ExecuteQuery( limit = *queryOpts.Limit } - toExecute := query.limit(limit) + if limit < math.MaxInt64 { + query = query.limit(limit) + } + + // Add FROM clause. + from := query.schema.RelationshipTableName + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } + + query.queryBuilder = query.queryBuilder.From(from) + + builder := RelationshipsQueryBuilder{ + Schema: query.schema, + SkipCaveats: queryOpts.SkipCaveats, + SkipExpiration: queryOpts.SkipExpiration, + sqlAssertion: queryOpts.SQLAssertion, + filteringValues: query.filteringColumnTracker, + baseQueryBuilder: query, + } + + return exc.Executor(ctx, builder) +} - // Run the query. - sql, args, err := toExecute.queryBuilder.ToSql() +// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading +// relationships. +type RelationshipsQueryBuilder struct { + Schema SchemaInformation + SkipCaveats bool + SkipExpiration bool + + filteringValues columnTrackerMap + baseQueryBuilder SchemaQueryFilterer + sqlAssertion options.Assertion +} + +// withCaveats returns true if caveats should be included in the query. +func (b RelationshipsQueryBuilder) withCaveats() bool { + return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone +} + +// withExpiration returns true if expiration should be included in the query. +func (b RelationshipsQueryBuilder) withExpiration() bool { + return !b.SkipExpiration && !b.Schema.ExpirationDisabled +} + +// integrityEnabled returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) integrityEnabled() bool { + return b.Schema.IntegrityEnabled +} + +// columnCount returns the number of columns that will be selected in the query. +func (b RelationshipsQueryBuilder) columnCount() int { + columnCount := relationshipStandardColumnCount + if b.withCaveats() { + columnCount += relationshipCaveatColumnCount + } + if b.withExpiration() { + columnCount += relationshipExpirationColumnCount + } + if b.integrityEnabled() { + columnCount += relationshipIntegrityColumnCount + } + return columnCount +} + +// SelectSQL returns the SQL and arguments necessary for reading relationships. +func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { + // Set the column names to select. + columnNamesToSelect := make([]string, 0, b.columnCount()) + + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) + + if b.withCaveats() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) + } + + if b.withExpiration() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) + } + + if b.integrityEnabled() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) + } + + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") + } + + sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) + sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) + + sql, args, err := sqlBuilder.ToSql() if err != nil { - return nil, err + return "", nil, err } - return tqs.Executor(ctx, sql, args) + if b.sqlAssertion != nil { + b.sqlAssertion(sql) + } + + return sql, args, nil +} + +// FilteringValuesForTesting returns the filtering values. For test use only. +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker { + return maps.Clone(b.filteringValues) +} + +func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + return append(columns, colName) + } + + if !b.filteringValues.hasStaticValue(colName) { + return append(columns, colName) + } + + return columns +} + +func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + + // If the value is static, set the field to it and return. + if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect } -// ExecuteQueryFunc is a function that can be used to execute a single rendered SQL query. -type ExecuteQueryFunc func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) +// ColumnsToSelect returns the columns to select for a given query. The columns provided are +// the references to the slots in which the values for each relationship will be placed. +func ColumnsToSelect[CN any, CC any, EC any]( + b RelationshipsQueryBuilder, + resourceObjectType *string, + resourceObjectID *string, + resourceRelation *string, + subjectObjectType *string, + subjectObjectID *string, + subjectRelation *string, + caveatName *CN, + caveatCtx *CC, + expiration EC, + + integrityKeyID *string, + integrityHash *[]byte, + timestamp *time.Time, +) ([]any, error) { + colsToSelect := make([]any, 0, b.columnCount()) + + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) + + if b.withCaveats() { + colsToSelect = append(colsToSelect, caveatName, caveatCtx) + } + + if b.withExpiration() { + colsToSelect = append(colsToSelect, expiration) + } + + if b.Schema.IntegrityEnabled { + colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) + } -// TxCleanupFunc is a function that should be executed when the caller of -// TransactionFactory is done with the transaction. -type TxCleanupFunc func(context.Context) + if len(colsToSelect) == 0 { + var unused int + colsToSelect = append(colsToSelect, &unused) + } + + return colsToSelect, nil +} diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 11043512b9..62d6fbeea1 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -1,9 +1,10 @@ package common import ( + "context" + "fmt" "testing" - - "github.com/google/uuid" + "time" "github.com/authzed/spicedb/pkg/datastore/options" @@ -18,297 +19,307 @@ import ( var toCursor = options.ToCursor +type expected struct { + sql string + args []any + staticCols []string +} + func TestSchemaQueryFilterer(t *testing.T) { tests := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - expectedSQL string - expectedArgs []any - expectedColumnCounts map[string]int + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + withExpirationDisabled bool + expectedForTuple expected + expectedForExpanded expected }{ { - "relation filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relation filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToRelation("somerelation") + }, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation"}, + staticCols: []string{"relation"}, + }, + }, + { + name: "relation filter without expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToRelation("somerelation") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ?", - []any{"somerelation"}, - map[string]int{ - "relation": 1, + withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ?", + args: []any{"somerelation"}, + staticCols: []string{"relation"}, }, }, { - "resource ID filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource ID filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceID("someresourceid") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id = ?", - []any{"someresourceid"}, - map[string]int{ - "object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourceid"}, + staticCols: []string{"object_id"}, }, }, { - "resource IDs filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource IDs filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ?", - []any{"someprefix%"}, - map[string]int{}, // object_id is not statically used, so not present in the map + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%"}, + staticCols: []string{}, + }, }, { - "resource IDs prefix filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource IDs prefix filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterToResourceIDs([]string{"someresourceid", "anotherresourceid"}) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?,?)", - []any{"someresourceid", "anotherresourceid"}, - map[string]int{ - "object_id": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourceid", "anotherresourceid"}, + staticCols: []string{}, }, }, { - "resource type filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource type filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, }, }, { - "resource filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ?", - []any{"sometype", "someobj", "somerel"}, - map[string]int{ - "ns": 1, - "object_id": 1, - "relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someobj", "somerel"}, + staticCols: []string{"ns", "object_id", "relation"}, }, }, { - "relationships filter with no IDs or relations", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with no IDs or relations", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, }, }, { - "relationships filter with single ID", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with single ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{"someid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?)", - []any{"sometype", "someid"}, - map[string]int{ - "ns": 1, - "object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someid"}, + staticCols: []string{"ns", "object_id"}, }, }, { - "relationships filter with no IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with no IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, }, }, { - "relationships filter with multiple IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with multiple IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{"someid", "anotherid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?)", - []any{"sometype", "someid", "anotherid"}, - map[string]int{ - "ns": 1, - "object_id": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someid", "anotherid"}, + staticCols: []string{"ns"}, }, }, { - "subjects filter with no IDs or relations", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with no IDs or relations", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?))", - []any{"somesubjectype"}, - map[string]int{ - "subject_ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype"}, + staticCols: []string{"subject_ns"}, }, }, { - "multiple subjects filters with just types", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "multiple subjects filters with just types", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }, datastore.SubjectsSelector{ OptionalSubjectType: "anothersubjectype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?))", - []any{"somesubjectype", "anothersubjectype"}, - map[string]int{ - "subject_ns": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype"}, + staticCols: []string{}, }, }, { - "subjects filter with single ID", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?)))", - []any{"somesubjectype", "somesubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid"}, + staticCols: []string{"subject_ns", "subject_object_id"}, }, }, { - "subjects filter with single ID and no type", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ID and no type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectIds: []string{"somesubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_object_id IN (?)))", - []any{"somesubjectid"}, - map[string]int{ - "subject_object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectid"}, + staticCols: []string{"subject_object_id"}, }, }, { - "empty subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "empty subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{}) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((1=1))", - nil, - map[string]int{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((1=1)) AND (expiration IS NULL OR expiration > NOW())", + args: nil, + staticCols: []string{}, + }, }, { - "subjects filter with multiple IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with multiple IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?)))", - []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, + staticCols: []string{"subject_ns"}, }, }, { - "subjects filter with single ellipsis relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ellipsis relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", - []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "..."}, + staticCols: []string{"subject_ns", "subject_relation"}, }, }, { - "subjects filter with single defined relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single defined relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel"), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", - []any{"somesubjectype", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubrel"}, + staticCols: []string{"subject_ns", "subject_relation"}, }, }, { - "subjects filter with only non-ellipsis", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with only non-ellipsis", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation <> ?))", - []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "..."}, + staticCols: []string{"subject_ns"}, }, }, { - "subjects filter with defined relation and ellipsis", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with defined relation and ellipsis", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?)))", - []any{"somesubjectype", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "...", "somesubrel"}, + staticCols: []string{"subject_ns"}, }, }, { - "subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", - []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"subject_ns"}, }, }, { - "multiple subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "multiple subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors( datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", @@ -326,45 +337,42 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?))", - []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, - map[string]int{ - "subject_ns": 3, - "subject_object_id": 4, - "subject_relation": 5, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, + staticCols: []string{}, }, }, { - "v1 subject filter with namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ?", - []any{"subns"}, - map[string]int{ - "subject_ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns"}, + staticCols: []string{"subject_ns"}, }, }, { - "v1 subject filter with subject id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with subject id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalSubjectId: "subid", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ?", - []any{"subns", "subid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subid"}, + staticCols: []string{"subject_ns", "subject_object_id"}, }, }, { - "v1 subject filter with relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalRelation: &v1.SubjectFilter_RelationFilter{ @@ -372,16 +380,15 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", - []any{"subns", "subrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subrel"}, + staticCols: []string{"subject_ns", "subject_relation"}, }, }, { - "v1 subject filter with empty relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with empty relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalRelation: &v1.SubjectFilter_RelationFilter{ @@ -389,16 +396,15 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", - []any{"subns", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "..."}, + staticCols: []string{"subject_ns", "subject_relation"}, }, }, { - "v1 subject filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalSubjectId: "subid", @@ -407,26 +413,26 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", - []any{"subns", "subid", "somerel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - "subject_relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subid", "somerel"}, + staticCols: []string{"subject_ns", "subject_object_id", "subject_relation"}, }, }, { - "limit", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "limit", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.limit(100) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", - nil, - map[string]int{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", + args: nil, + staticCols: []string{}, + }, }, { - "full resources filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "full resources filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -442,65 +448,95 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", - []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "ns": 1, - "object_id": 2, - "relation": 1, - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"ns", "relation", "subject_ns"}, }, }, { - "order by", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "full resources filter without expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithRelationshipsFilter( + datastore.RelationshipsFilter{ + OptionalResourceType: "someresourcetype", + OptionalResourceIds: []string{"someid", "anotherid"}, + OptionalResourceRelation: "somerelation", + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + OptionalSubjectType: "somesubjectype", + OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, + RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), + }, + }, + }, + ) + }, + withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", + args: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"ns", "relation", "subject_ns"}, + }, + }, + { + name: "order by", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).TupleOrder(options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", - []any{"someresourcetype"}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", + args: []any{"someresourcetype"}, + staticCols: []string{"ns"}, }, }, { - "after with just namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with just namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND ((object_id > ?) OR (object_id = ? AND relation > ?) OR (object_id = ? AND relation = ? AND subject_ns > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "foo", "foo", "viewer", "foo", "viewer", "user", "foo", "viewer", "user", "bar", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, }, }, { - "after with just relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with just relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceRelation: "somerelation", }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, - map[string]int{ - "relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, + staticCols: []string{"relation"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE relation = ? AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "user", "someresourcetype", "foo", "user", "bar", "someresourcetype", "foo", "user", "bar", "..."}, + staticCols: []string{"relation"}, }, }, { - "after with namespace and single resource id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and single resource id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -508,31 +544,40 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", - []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns", "object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND ((relation > ?) OR (relation = ? AND subject_ns > ?) OR (relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "viewer", "viewer", "user", "viewer", "user", "bar", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns", "object_id"}, }, }, { - "after with single resource id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with single resource id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceIds: []string{"one"}, }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, - map[string]int{ - "object_id": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, + staticCols: []string{"object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE object_id IN (?) AND ((ns > ?) OR (ns = ? AND relation > ?) OR (ns = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"one", "someresourcetype", "someresourcetype", "viewer", "someresourcetype", "viewer", "user", "someresourcetype", "viewer", "user", "bar", "someresourcetype", "viewer", "user", "bar", "..."}, + staticCols: []string{"object_id"}, }, }, { - "after with namespace and resource ids", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and resource ids", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -540,16 +585,20 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND ((object_id > ?) OR (object_id = ? AND relation > ?) OR (object_id = ? AND relation = ? AND subject_ns > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "two", "foo", "foo", "viewer", "foo", "viewer", "user", "foo", "viewer", "user", "bar", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, }, }, { - "after with namespace and relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -557,29 +606,38 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", - []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "relation": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, + staticCols: []string{"ns", "relation"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND ((object_id > ?) OR (object_id = ? AND subject_ns > ?) OR (object_id = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "foo", "foo", "user", "foo", "user", "bar", "foo", "user", "bar", "..."}, + staticCols: []string{"ns", "relation"}, }, }, { - "after with subject namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with subject namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, - map[string]int{ - "subject_ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "bar", "someresourcetype", "foo", "viewer", "bar", "..."}, + staticCols: []string{"subject_ns"}, }, }, { - "after with subject namespaces", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with subject namespaces", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { // NOTE: this isn't really valid (it'll return no results), but is a good test to ensure // the duplicate subject type results in the subject type being in the ORDER BY. return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ @@ -588,126 +646,761 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", - []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "subject_ns": 2, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "user", "someresourcetype", "foo", "viewer", "user", "bar", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, }, }, { - "after with resource ID prefix", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix").After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", - []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "user", "someresourcetype", "foo", "viewer", "user", "bar", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, }, { - "order by subject", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).TupleOrder(options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", - []any{"someresourcetype"}, - map[string]int{ - "ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", + args: []any{"someresourcetype"}, + staticCols: []string{"ns"}, }, }, { - "order by subject, after with subject namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, - map[string]int{ - "subject_ns": 1, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "bar", "bar", "someresourcetype", "bar", "someresourcetype", "foo", "bar", "someresourcetype", "foo", "viewer", "bar", "someresourcetype", "foo", "viewer", "..."}, + staticCols: []string{"subject_ns"}, }, }, { - "order by subject, after with subject namespace and subject object id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace and subject object id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"foo"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:bar")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?)", - []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 1}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "someresourcetype", "someresourcetype", "someresource", "someresourcetype", "someresource", "viewer", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, }, { - "order by subject, after with subject namespace and multiple subject object IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace and multiple subject object IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"foo", "bar"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:next")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 2}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "bar", "next", "next", "someresourcetype", "next", "someresourcetype", "someresource", "next", "someresourcetype", "someresource", "viewer", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + }, + { + name: "order by subject, after with subject namespace and multiple subject object IDs and no expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + OptionalSubjectIds: []string{"foo", "bar"}, + }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:next")), options.BySubject) + }, + withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", + args: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?))", + args: []any{"somesubjectype", "foo", "bar", "next", "next", "someresourcetype", "next", "someresourcetype", "someresource", "next", "someresourcetype", "someresource", "viewer", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - base := sq.Select("*") - schema := NewSchemaInformation( - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "expiration", - TupleComparison, - "NOW", - ) - filterer := NewSchemaQueryFilterer(schema, base, 100) + for _, filterType := range []PaginationFilterType{TupleComparison, ExpandedLogicComparison} { + t.Run(fmt.Sprintf("filter type: %v", filterType), func(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(filterType), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + ) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + + ran := test.run(filterer) + foundStaticColumns := []string{} + for col, tracker := range ran.filteringColumnTracker { + if tracker.SingleValue != nil { + foundStaticColumns = append(foundStaticColumns, col) + } + } + + expected := test.expectedForTuple + if filterType == ExpandedLogicComparison && test.expectedForExpanded.sql != "" { + expected = test.expectedForExpanded + } - ran := test.run(filterer) - require.Equal(t, test.expectedColumnCounts, ran.filteringColumnCounts) + require.ElementsMatch(t, expected.staticCols, foundStaticColumns) - sql, args, err := ran.queryBuilder.ToSql() - require.NoError(t, err) - require.Equal(t, test.expectedSQL, sql) - require.Equal(t, test.expectedArgs, args) + ran.queryBuilder = ran.queryBuilderWithMaybeExpirationFilter(test.withExpirationDisabled).Columns("*") + + sql, args, err := ran.queryBuilder.ToSql() + require.NoError(t, err) + require.Equal(t, expected.sql, sql) + require.Equal(t, expected.args, args) + }) + } }) } } -func BenchmarkSchemaFilterer(b *testing.B) { - si := NewSchemaInformation( - "namespace", - "object_id", - "object_relation", - "resource_type", - "resource_id", - "resource_relation", - "caveat_name", - "expiration", - TupleComparison, - "NOW", - ) - sqf := NewSchemaQueryFilterer(si, sq.Select("*"), 100) - var names []string - for i := 0; i < 500; i++ { - names = append(names, uuid.NewString()) +func TestExecuteQuery(t *testing.T) { + tcs := []struct { + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + options []options.QueryOptionsOption + expectedSQL string + expectedArgs []any + expectedStaticColCount int + expectedSkipCaveats bool + expectedSkipExpiration bool + withExpirationDisabled bool + withIntegrityEnabled bool + fromSuffix string + limit uint64 + }{ + { + name: "filter by static resource type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type and resource ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") + }, + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj"}, + expectedStaticColCount: 2, + }, + { + name: "filter by static resource type and resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someprefix%"}, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type and resource IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type, resource ID and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") + }, + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedStaticColCount: 3, + }, + { + name: "filter by static resource type, resource ID, relation and subject type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + }) + }, + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + expectedStaticColCount: 4, + }, + { + name: "filter by static resource type, resource ID, relation, subject type and subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + expectedStaticColCount: 5, + }, + { + name: "filter by static resource type, resource ID, relation, subject type, subject ID and subject relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, + }, + { + name: "filter by static everything without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, + }, + { + name: "filter by static everything (except one field) without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}).FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 5, + }, + { + name: "filter by static resource type with no caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, + }, + { + name: "filter by just subject type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns"}, + expectedStaticColCount: 1, + }, + { + name: "filter by just subject type and subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid"}, + expectedStaticColCount: 2, + }, + { + name: "filter by just subject type and subject relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subrel"}, + expectedStaticColCount: 2, + }, + { + name: "filter by just subject type and subject ID and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "subrel"}, + expectedStaticColCount: 3, + }, + { + name: "filter by multiple subject types, but static subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }).FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "anothersubns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + expectedStaticColCount: 1, + }, + { + name: "multiple subjects filters with just types", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + }, datastore.SubjectsSelector{ + OptionalSubjectType: "anothersubjectype", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedStaticColCount: 0, + }, + { + name: "multiple subjects filters with just types and static resource type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + }, datastore.SubjectsSelector{ + OptionalSubjectType: "anothersubjectype", + }).FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type with expiration disabled", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type with expiration skipped", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: false, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 1, + }, + { + name: "filter by static resource type with expiration skipped and disabled", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 1, + }, + { + name: "with from suffix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples as of tomorrow WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + fromSuffix: "as of tomorrow", + expectedStaticColCount: 1, + }, + { + name: "with limit", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? LIMIT 65", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + limit: 65, + expectedStaticColCount: 1, + }, + { + name: "with integrity", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, integrity_key_id, integrity_hash, integrity_timestamp FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + withIntegrityEnabled: true, + expectedStaticColCount: 1, + }, + { + name: "all columns static with caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with caveats and expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, + }, + { + name: "all columns static without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT 1 FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: -1, + }, + { + name: "one column not static", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterToResourceIDs([]string{"foo", "bar"}) + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id IN (?,?)", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo", "bar"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, + }, + { + name: "resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterWithResourceIDPrefix("foo") + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id LIKE ?", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo%"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, + }, } - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = sqf.FilterToResourceIDs(names) + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + for _, filterType := range []PaginationFilterType{TupleComparison, ExpandedLogicComparison} { + t.Run(fmt.Sprintf("filter type: %v", filterType), func(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithColIntegrityHash("integrity_hash"), + WithColIntegrityKeyID("integrity_key_id"), + WithColIntegrityTimestamp("integrity_timestamp"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(filterType), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + WithIntegrityEnabled(tc.withIntegrityEnabled), + WithExpirationDisabled(tc.withExpirationDisabled), + ) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + filterer = filterer.WithFromSuffix(tc.fromSuffix) + if tc.limit > 0 { + filterer = filterer.limit(tc.limit) + } + + ran := tc.run(filterer) + + var wasRun bool + fake := QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + sql, args, err := builder.SelectSQL() + require.NoError(t, err) + + wasRun = true + require.Equal(t, tc.expectedSQL, sql) + require.Equal(t, tc.expectedArgs, args) + require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) + require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) + + // 6 standard columns for relationships: + // ns, object_id, relation, subject_ns, subject_object_id, subject_relation + expectedColCount := 6 - tc.expectedStaticColCount + if !tc.expectedSkipCaveats { + // caveat, caveat_context + expectedColCount += 2 + } + if !tc.expectedSkipExpiration && !tc.withExpirationDisabled { + // expiration + expectedColCount++ + } + if tc.withIntegrityEnabled { + // integrity_key_id, integrity_hash, integrity_timestamp + expectedColCount += 3 + } + + if tc.expectedStaticColCount == -1 { + // SELECT 1 + expectedColCount = 1 + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName *string + var caveatCtx map[string]any + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + require.NoError(t, err) + require.Equal(t, expectedColCount, len(colsToSelect)) + + return nil, nil + }, + } + _, err := fake.ExecuteQuery(context.Background(), ran, tc.options...) + require.NoError(t, err) + require.True(t, wasRun) + }) + } + }) } } + +func TestNewSchemaQueryFiltererWithStartingQuery(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + WithExpirationDisabled(true), + ) + + sql := sq.StatementBuilder.PlaceholderFormat(sq.AtP) + query := sql.Select("COUNT(*)").From("sometable") + filterer := NewSchemaQueryFiltererWithStartingQuery(*schema, query, 50) + filterer = filterer.MustFilterToResourceIDs([]string{"someid"}) + filterer = filterer.WithAdditionalFilter(func(original sq.SelectBuilder) sq.SelectBuilder { + return original.Where("somecoolclause") + }) + + sqlQuery, args, err := filterer.UnderlyingQueryBuilder().ToSql() + require.NoError(t, err) + + expectedSQL := "SELECT COUNT(*) FROM sometable WHERE object_id IN (@p1) AND somecoolclause" + expectedArgs := []any{"someid"} + require.Equal(t, expectedSQL, sqlQuery) + require.Equal(t, expectedArgs, args) +} diff --git a/internal/datastore/common/zz_generated.schema_options.go b/internal/datastore/common/zz_generated.schema_options.go new file mode 100644 index 0000000000..04b6088a36 --- /dev/null +++ b/internal/datastore/common/zz_generated.schema_options.go @@ -0,0 +1,228 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package common + +import ( + squirrel "github.com/Masterminds/squirrel" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type SchemaInformationOption func(s *SchemaInformation) + +// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set +func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + for _, o := range opts { + o(s) + } + return s +} + +// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults +func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + defaults.MustSet(s) + for _, o := range opts { + o(s) + } + return s +} + +// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation +func (s *SchemaInformation) ToOption() SchemaInformationOption { + return func(to *SchemaInformation) { + to.RelationshipTableName = s.RelationshipTableName + to.ColNamespace = s.ColNamespace + to.ColObjectID = s.ColObjectID + to.ColRelation = s.ColRelation + to.ColUsersetNamespace = s.ColUsersetNamespace + to.ColUsersetObjectID = s.ColUsersetObjectID + to.ColUsersetRelation = s.ColUsersetRelation + to.ColCaveatName = s.ColCaveatName + to.ColCaveatContext = s.ColCaveatContext + to.ColExpiration = s.ColExpiration + to.ColIntegrityKeyID = s.ColIntegrityKeyID + to.ColIntegrityHash = s.ColIntegrityHash + to.ColIntegrityTimestamp = s.ColIntegrityTimestamp + to.PaginationFilterType = s.PaginationFilterType + to.PlaceholderFormat = s.PlaceholderFormat + to.NowFunction = s.NowFunction + to.ColumnOptimization = s.ColumnOptimization + to.IntegrityEnabled = s.IntegrityEnabled + to.ExpirationDisabled = s.ExpirationDisabled + } +} + +// DebugMap returns a map form of SchemaInformation for debugging +func (s SchemaInformation) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false) + debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false) + debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false) + debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false) + debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false) + debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false) + debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false) + debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false) + debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false) + debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false) + debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false) + debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false) + debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false) + debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false) + debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) + debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) + debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) + debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false) + debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) + return debugMap +} + +// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set +func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithOptions configures the receiver SchemaInformation with the passed in options set +func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation +func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.RelationshipTableName = relationshipTableName + } +} + +// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation +func WithColNamespace(colNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColNamespace = colNamespace + } +} + +// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation +func WithColObjectID(colObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColObjectID = colObjectID + } +} + +// WithColRelation returns an option that can set ColRelation on a SchemaInformation +func WithColRelation(colRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColRelation = colRelation + } +} + +// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation +func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetNamespace = colUsersetNamespace + } +} + +// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation +func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetObjectID = colUsersetObjectID + } +} + +// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation +func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetRelation = colUsersetRelation + } +} + +// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation +func WithColCaveatName(colCaveatName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatName = colCaveatName + } +} + +// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation +func WithColCaveatContext(colCaveatContext string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatContext = colCaveatContext + } +} + +// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation +func WithColExpiration(colExpiration string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColExpiration = colExpiration + } +} + +// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation +func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityKeyID = colIntegrityKeyID + } +} + +// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation +func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityHash = colIntegrityHash + } +} + +// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation +func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityTimestamp = colIntegrityTimestamp + } +} + +// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation +func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PaginationFilterType = paginationFilterType + } +} + +// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation +func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PlaceholderFormat = placeholderFormat + } +} + +// WithNowFunction returns an option that can set NowFunction on a SchemaInformation +func WithNowFunction(nowFunction string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.NowFunction = nowFunction + } +} + +// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation +func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColumnOptimization = columnOptimization + } +} + +// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation +func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.IntegrityEnabled = integrityEnabled + } +} + +// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation +func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ExpirationDisabled = expirationDisabled + } +} diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 3f66b95810..94c74f6553 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -23,7 +23,7 @@ var ( ) writeCaveat = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition).Suffix(upsertCaveatSuffix) readCaveat = psql.Select(colCaveatDefinition, colTimestamp) - listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).From(tableCaveat).OrderBy(colCaveatName) + listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).OrderBy(colCaveatName) deleteCaveat = psql.Delete(tableCaveat) ) @@ -35,11 +35,12 @@ const ( ) func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - query := cr.fromBuilder(readCaveat, tableCaveat).Where(sq.Eq{colCaveatName: name}) + query := cr.addFromToQuery(readCaveat.Where(sq.Eq{colCaveatName: name}), tableCaveat) sql, args, err := query.ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) } + cr.assertHasExpectedAsOfSystemTime(sql) var definitionBytes []byte var timestamp time.Time @@ -79,7 +80,7 @@ type bytesAndTimestamp struct { } func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { - caveatsWithNames := cr.fromBuilder(listCaveat, tableCaveat) + caveatsWithNames := cr.addFromToQuery(listCaveat, tableCaveat) if len(caveatNames) > 0 { caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } @@ -88,6 +89,7 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ( if err != nil { return nil, fmt.Errorf(errListCaveats, err) } + cr.assertHasExpectedAsOfSystemTime(sql) var allDefinitionBytes []bytesAndTimestamp diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 88d6e3e9ae..517afe0c2c 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -200,6 +200,33 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas return nil, fmt.Errorf("invalid head migration found for cockroach: %w", err) } + relTableName := tableTuple + if config.withIntegrity { + relTableName = tableTupleWithIntegrity + } + + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(relTableName), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithColIntegrityKeyID(colIntegrityKeyID), + common.WithColIntegrityHash(colIntegrityHash), + common.WithColIntegrityTimestamp(colTimestamp), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), + common.WithIntegrityEnabled(config.withIntegrity), + common.WithExpirationDisabled(config.expirationDisabled), + ) + ds := &crdbDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( config.gcWindow, @@ -221,6 +248,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, + schema: *schema, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -305,6 +333,7 @@ type crdbDatastore struct { overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool gcWindow time.Duration + schema common.SchemaInformation beginChangefeedQuery string transactionNowQuery string @@ -319,15 +348,19 @@ type crdbDatastore struct { } func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(cds.readPool, cds.supportsIntegrity), - } - - fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String()) + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(cds.readPool), + } + return &crdbReader{ + schema: cds.schema, + query: cds.readPool, + executor: executor, + keyer: noOverlapKeyer, + overlapKeySet: nil, + filterMaximumIDCount: cds.filterMaximumIDCount, + withIntegrity: cds.supportsIntegrity, + atSpecificRevision: rev.String(), } - - return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, fromBuilder, cds.filterMaximumIDCount, cds.tableTupleName(), cds.supportsIntegrity} } func (cds *crdbDatastore) ReadWriteTx( @@ -344,8 +377,8 @@ func (cds *crdbDatastore) ReadWriteTx( err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error { querier := pgxcommon.QuerierFuncsFor(tx) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(querier), } // Write metadata onto the transaction. @@ -369,19 +402,19 @@ func (cds *crdbDatastore) ReadWriteTx( return fmt.Errorf("error writing metadata: %w", err) } + reader := &crdbReader{ + schema: cds.schema, + query: querier, + executor: executor, + keyer: cds.writeOverlapKeyer, + overlapKeySet: cds.overlapKeyInit(ctx), + filterMaximumIDCount: cds.filterMaximumIDCount, + withIntegrity: cds.supportsIntegrity, + atSpecificRevision: "", // No AS OF SYSTEM TIME for writes + } + rwt := &crdbReadWriteTXN{ - &crdbReader{ - querier, - executor, - cds.writeOverlapKeyer, - cds.overlapKeyInit(ctx), - func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr) - }, - cds.filterMaximumIDCount, - cds.tableTupleName(), - cds.supportsIntegrity, - }, + reader, tx, 0, } @@ -526,14 +559,6 @@ func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, er return features, err } -func (cds *crdbDatastore) tableTupleName() string { - if cds.supportsIntegrity { - return tableTupleWithIntegrity - } - - return tableTuple -} - func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, error) { features := datastore.Features{ ContinuousCheckpointing: datastore.Feature{ @@ -567,7 +592,7 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er features.Watch.Reason = fmt.Sprintf("Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: %s", err.Error()) } return nil - }, fmt.Sprintf(cds.beginChangefeedQuery, cds.tableTupleName(), head, "1s")) + }, fmt.Sprintf(cds.beginChangefeedQuery, cds.schema.RelationshipTableName, head, "1s")) <-streamCtx.Done() diff --git a/internal/datastore/crdb/options.go b/internal/datastore/crdb/options.go index 67d8933638..7c62227085 100644 --- a/internal/datastore/crdb/options.go +++ b/internal/datastore/crdb/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -27,8 +28,10 @@ type crdbOptions struct { filterMaximumIDCount uint16 enablePrometheusStats bool withIntegrity bool - includeQueryParametersInTraces bool allowedMigrations []string + columnOptimizationOption common.ColumnOptimizationOption + includeQueryParametersInTraces bool + expirationDisabled bool } const ( @@ -56,7 +59,9 @@ const ( defaultConnectRate = 100 * time.Millisecond defaultFilterMaximumIDCount = 100 defaultWithIntegrity = false + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the CRDB @@ -80,7 +85,9 @@ func generateConfig(options []Option) (crdbOptions, error) { connectRate: defaultConnectRate, filterMaximumIDCount: defaultFilterMaximumIDCount, withIntegrity: defaultWithIntegrity, + columnOptimizationOption: defaultColumnOptimizationOption, includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -353,3 +360,19 @@ func AllowedMigrations(allowedMigrations []string) Option { func IncludeQueryParametersInTraces(includeQueryParametersInTraces bool) Option { return func(po *crdbOptions) { po.includeQueryParametersInTraces = includeQueryParametersInTraces } } + +// WithColumnOptimization configures the column optimization option for the datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *crdbOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} + +// WithExpirationDisabled configures the datastore to disable relationship expiration. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *crdbOptions) { po.expirationDisabled = isDisabled } +} diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index 70002a8965..db44aaadd6 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -16,6 +17,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" ) const ( @@ -29,19 +31,6 @@ var ( countRels = psql.Select("count(*)") - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colExpiration, - common.ExpandedLogicComparison, - "NOW", - ) - queryCounters = psql.Select( colCounterName, colCounterSerializedFilter, @@ -51,14 +40,42 @@ var ( ) type crdbReader struct { + schema common.SchemaInformation query pgxcommon.DBFuncQuerier - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor keyer overlapKeyer overlapKeySet keySet - fromBuilder func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder filterMaximumIDCount uint16 - tupleTableName string withIntegrity bool + atSpecificRevision string +} + +const asOfSystemTime = "AS OF SYSTEM TIME" + +func (cr *crdbReader) addFromToQuery(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + if cr.atSpecificRevision == "" { + return query.From(tableName) + } + + return query.From(tableName + " " + asOfSystemTime + " " + cr.atSpecificRevision) +} + +func (cr *crdbReader) fromSuffix() string { + if cr.atSpecificRevision == "" { + return "" + } + + return " " + asOfSystemTime + " " + cr.atSpecificRevision +} + +func (cr *crdbReader) assertHasExpectedAsOfSystemTime(sql string) { + spiceerrors.DebugAssert(func() bool { + if cr.atSpecificRevision == "" { + return !strings.Contains(sql, "AS OF SYSTEM TIME") + } else { + return strings.Contains(sql, "AS OF SYSTEM TIME") + } + }, "mismatch in AS OF SYSTEM TIME in query: %s", sql) } func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -76,8 +93,8 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return 0, err } - query := cr.fromBuilder(countRels, cr.tupleTableName) - builder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + query := cr.addFromToQuery(countRels, cr.schema.RelationshipTableName) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(cr.schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -86,6 +103,7 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, if err != nil { return 0, err } + cr.assertHasExpectedAsOfSystemTime(sql) var count int err = cr.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { @@ -105,8 +123,7 @@ func (cr *crdbReader) LookupCounters(ctx context.Context) ([]datastore.Relations } func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { - query := cr.fromBuilder(queryCounters, tableRelationshipCounter) - + query := cr.addFromToQuery(queryCounters, tableRelationshipCounter) if optionalFilterName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalFilterName}) } @@ -115,6 +132,7 @@ func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName str if err != nil { return nil, err } + cr.assertHasExpectedAsOfSystemTime(sql) var counters []datastore.RelationshipCounter err = cr.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { @@ -178,44 +196,14 @@ func (cr *crdbReader) ReadNamespaceByName( } func (cr *crdbReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromBuilder) + nsDefs, sql, err := loadAllNamespaces(ctx, cr.query, cr.addFromToQuery) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } + cr.assertHasExpectedAsOfSystemTime(sql) return nsDefs, nil } -func (cr *crdbReader) queryTuples() sq.SelectBuilder { - if cr.withIntegrity { - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - colIntegrityKeyID, - colIntegrityHash, - colTimestamp, - ) - } - - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - ) -} - func (cr *crdbReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { if len(nsNames) == 0 { return nil, nil @@ -232,12 +220,15 @@ func (cr *crdbReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount).WithFromSuffix(cr.fromSuffix()).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } + if spiceerrors.DebugAssertionsEnabled { + opts = append(opts, options.WithSQLAssertion(cr.assertHasExpectedAsOfSystemTime)) + } + return cr.executor.ExecuteQuery(ctx, qBuilder, opts...) } @@ -246,36 +237,44 @@ func (cr *crdbReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount). + WithFromSuffix(cr.fromSuffix()). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err } queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) - if queryOpts.ResRelation != nil { qBuilder = qBuilder. FilterToResourceType(queryOpts.ResRelation.Namespace). FilterToRelation(queryOpts.ResRelation.Relation) } + eopts := []options.QueryOptionsOption{ + options.WithLimit(queryOpts.LimitForReverse), + options.WithAfter(queryOpts.AfterForReverse), + options.WithSort(queryOpts.SortForReverse), + } + + if spiceerrors.DebugAssertionsEnabled { + eopts = append(eopts, options.WithSQLAssertion(cr.assertHasExpectedAsOfSystemTime)) + } + return cr.executor.ExecuteQuery( ctx, qBuilder, - options.WithLimit(queryOpts.LimitForReverse), - options.WithAfter(queryOpts.AfterForReverse), - options.WithSort(queryOpts.SortForReverse)) + eopts..., + ) } func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsName string) (*core.NamespaceDefinition, time.Time, error) { - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) - + query := cr.addFromToQuery(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) sql, args, err := query.ToSql() if err != nil { return nil, time.Time{}, err } + cr.assertHasExpectedAsOfSystemTime(sql) var config []byte var timestamp time.Time @@ -304,12 +303,12 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(clause) - + query := cr.addFromToQuery(queryReadNamespace, tableNamespace).Where(clause) sql, args, err := query.ToSql() if err != nil { return nil, err } + cr.assertHasExpectedAsOfSystemTime(sql) var nsDefs []datastore.RevisionedNamespace @@ -344,12 +343,11 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu return nsDefs, nil } -func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) { +func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, string, error) { query := fromBuilder(queryReadNamespace, tableNamespace) - sql, args, err := query.ToSql() if err != nil { - return nil, err + return nil, sql, err } var nsDefs []datastore.RevisionedNamespace @@ -379,10 +377,10 @@ func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuil return nil }, sql, args...) if err != nil { - return nil, err + return nil, sql, err } - return nsDefs, nil + return nsDefs, sql, nil } func (cr *crdbReader) addOverlapKey(namespace string) { diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index 53a25cc586..acfaf97709 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -123,11 +123,11 @@ var ( ) func (rwt *crdbReadWriteTXN) insertQuery() sq.InsertBuilder { - return psql.Insert(rwt.tupleTableName) + return psql.Insert(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryDeleteTuples() sq.DeleteBuilder { - return psql.Delete(rwt.tupleTableName) + return psql.Delete(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryWriteTuple() sq.InsertBuilder { @@ -555,10 +555,10 @@ var copyColsWithIntegrity = []string{ func (rwt *crdbReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { if rwt.withIntegrity { - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyColsWithIntegrity, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyColsWithIntegrity, iter) } - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyCols, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyCols, iter) } var _ datastore.ReadWriteTransaction = &crdbReadWriteTXN{} diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index 2b66297d91..c468747b1a 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -44,8 +44,8 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) } - nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, fromStr string) squirrel.SelectBuilder { - return sb.From(fromStr) + nsDefs, _, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, tableName string) squirrel.SelectBuilder { + return sb.From(tableName) }) if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) @@ -57,7 +57,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if cds.analyzeBeforeStatistics { if err := cds.readPool.BeginTxFunc(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}, func(tx pgx.Tx) error { - if _, err := tx.Exec(ctx, "ANALYZE "+cds.tableTupleName()); err != nil { + if _, err := tx.Exec(ctx, "ANALYZE "+cds.schema.RelationshipTableName); err != nil { return fmt.Errorf("unable to analyze tuple table: %w", err) } @@ -131,7 +131,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro log.Warn().Bool("has-rows", hasRows).Msg("unable to find row count in statistics query result") return nil - }, "SHOW STATISTICS FOR TABLE "+cds.tableTupleName()); err != nil { + }, "SHOW STATISTICS FOR TABLE "+cds.schema.RelationshipTableName); err != nil { return datastore.Stats{}, fmt.Errorf("unable to query unique estimated row count: %w", err) } diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index de341112f7..2008a743ad 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -128,7 +128,7 @@ func (cds *crdbDatastore) watch( tableNames := make([]string, 0, 4) tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { - tableNames = append(tableNames, cds.tableTupleName()) + tableNames = append(tableNames, cds.schema.RelationshipTableName) } if opts.Content&datastore.WatchSchema == datastore.WatchSchema { tableNames = append(tableNames, tableNamespace) @@ -433,7 +433,7 @@ func (cds *crdbDatastore) processChanges(ctx context.Context, changes pgx.Rows, } switch tableName { - case cds.tableTupleName(): + case cds.schema.RelationshipTableName: var caveatName string var caveatContext map[string]any if details.After != nil && details.After.RelationshipCaveatName != "" { diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go new file mode 100644 index 0000000000..b6fc7f5d9c --- /dev/null +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -0,0 +1,156 @@ +package dsfortesting + +import ( + "context" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewMemDBDatastoreForTesting creates a new in-memory datastore for testing. +// This is a convenience function that wraps the creation of a new MemDB datastore, +// and injects additional proxies for validation at test time. +// NOTE: These additional proxies are not performant for use in production (but then, +// neither is memdb) +func NewMemDBDatastoreForTesting( + watchBufferLength uint16, + revisionQuantization, + gcWindow time.Duration, +) (datastore.Datastore, error) { + ds, err := memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) + if err != nil { + return nil, err + } + + return validatingDatastore{ds}, nil +} + +type validatingDatastore struct { + datastore.Datastore +} + +func (vds validatingDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { + return validatingReader{vds.Datastore.SnapshotReader(rev)} +} + +type validatingReader struct { + datastore.Reader +} + +func (vr validatingReader) QueryRelationships( + ctx context.Context, + filter datastore.RelationshipsFilter, + options ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName("relationtuples"), + common.WithColNamespace("ns"), + common.WithColObjectID("object_id"), + common.WithColRelation("relation"), + common.WithColUsersetNamespace("subject_ns"), + common.WithColUsersetObjectID("subject_object_id"), + common.WithColUsersetRelation("subject_relation"), + common.WithColCaveatName("caveat"), + common.WithColCaveatContext("caveat_context"), + common.WithColExpiration("expiration"), + common.WithPlaceholderFormat(sq.Question), + common.WithPaginationFilterType(common.TupleComparison), + common.WithColumnOptimization(common.ColumnOptimizationOptionStaticValues), + common.WithNowFunction("NOW"), + ) + + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100). + FilterWithRelationshipsFilter(filter) + if err != nil { + return nil, err + } + + // Run the filter through the common SQL ellison system and ensure that any + // relationships return have values matching the static fields, if applicable. + var builder *common.RelationshipsQueryBuilder + executor := common.QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, b common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + builder = &b + return nil, nil + }, + } + + _, _ = executor.ExecuteQuery(ctx, qBuilder, options...) + if builder == nil { + return nil, fmt.Errorf("no builder returned") + } + + checkStaticField := func(returnedValue string, fieldName string) error { + if found, ok := builder.FilteringValuesForTesting()[fieldName]; ok && found.SingleValue != nil { + if returnedValue != *found.SingleValue { + return fmt.Errorf("static field `%s` does not match expected value `%s`: `%s", fieldName, returnedValue, *found.SingleValue) + } + } + + return nil + } + + // Run the actual query on the memdb instance. + iter, err := vr.Reader.QueryRelationships(ctx, filter, options...) + if err != nil { + return nil, err + } + + return func(yield func(tuple.Relationship, error) bool) { + for rel, err := range iter { + if err != nil { + if !yield(rel, err) { + return + } + continue + } + + if err := checkStaticField(rel.Resource.ObjectType, "ns"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Resource.ObjectID, "object_id"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Resource.Relation, "relation"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.ObjectType, "subject_ns"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.ObjectID, "subject_object_id"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.Relation, "subject_relation"); err != nil { + if !yield(rel, err) { + return + } + } + + if !yield(rel, err) { + return + } + } + }, nil +} diff --git a/internal/datastore/memdb/readonly.go b/internal/datastore/memdb/readonly.go index 87ef405c1f..e348e2d773 100644 --- a/internal/datastore/memdb/readonly.go +++ b/internal/datastore/memdb/readonly.go @@ -151,11 +151,11 @@ func (r *memdbReader) QueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) @@ -214,11 +214,11 @@ func (r *memdbReader) ReverseQueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse) @@ -476,7 +476,7 @@ func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl return noopCursorFilter } -func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64) (datastore.RelationshipIterator, error) { +func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) { results := make([]tuple.Relationship, 0) // Coalesce all of the results into memory @@ -490,6 +490,14 @@ func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uin continue } + if skipCaveats && rt.OptionalCaveat != nil { + return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt) + } + + if skipExpiration && rt.OptionalExpiration != nil { + return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt) + } + results = append(results, rt) } @@ -526,7 +534,7 @@ func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelati return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation } -func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64) datastore.RelationshipIterator { +func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator { var count uint64 return func(yield func(tuple.Relationship, error) bool) { for { @@ -547,6 +555,16 @@ func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64 continue } + if skipCaveats && rt.OptionalCaveat != nil { + yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt)) + return + } + + if skipExpiration && rt.OptionalExpiration != nil { + yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt)) + return + } + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { continue } diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index 84283a3bb6..6cb7edafab 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -22,7 +22,7 @@ const ( ) func (mr *mysqlReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := mr.filterer(mr.ReadCaveatQuery) + filteredReadCaveat := mr.aliveFilter(mr.ReadCaveatQuery) sqlStatement, args, err := filteredReadCaveat.Where(sq.Eq{colName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, err @@ -68,7 +68,7 @@ func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) caveatsWithNames = caveatsWithNames.Where(sq.Eq{colName: caveatNames}) } - filteredListCaveat := mr.filterer(caveatsWithNames) + filteredListCaveat := mr.aliveFilter(caveatsWithNames) listSQL, listArgs, err := filteredListCaveat.ToSql() if err != nil { return nil, err diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index e70ae293f7..3ccd1b8327 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -18,8 +18,6 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" datastoreinternal "github.com/authzed/spicedb/internal/datastore" @@ -29,7 +27,6 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" - "github.com/authzed/spicedb/pkg/tuple" ) const ( @@ -246,6 +243,24 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option -1*config.gcWindow.Seconds(), ) + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(driver.RelationTuple()), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Question), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), + common.WithExpirationDisabled(config.expirationDisabled), + ) + store := &Datastore{ MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), db: db, @@ -267,6 +282,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, maxRetries: config.maxRetries, analyzeBeforeStats: config.analyzeBeforeStats, + schema: *schema, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), @@ -316,7 +332,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { return tx, tx.Rollback, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(mds.db), } @@ -326,6 +342,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { executor, buildLivingObjectFilterForRevision(rev), mds.filterMaximumIDCount, + mds.schema, } } @@ -358,7 +375,7 @@ func (mds *Datastore) ReadWriteTx( return tx, noCleanup, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(tx), } @@ -369,6 +386,7 @@ func (mds *Datastore) ReadWriteTx( executor, currentlyLivingObjects, mds.filterMaximumIDCount, + mds.schema, }, mds.driver.RelationTuple(), tx, @@ -417,7 +435,24 @@ type querier interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) } -func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { +type asQueryableTx struct { + tx querier +} + +func (aqt asQueryableTx) QueryFunc(ctx context.Context, f func(context.Context, common.Rows) error, sql string, args ...any) error { + rows, err := aqt.tx.QueryContext(ctx, sql, args...) + if err != nil { + return err + } + + if err := rows.Err(); err != nil { + return err + } + + return f(ctx, rows) +} + +func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // This implementation does not create a transaction because it's redundant for single statements, and it avoids // the network overhead and reduce contention on the connection pool. From MySQL docs: // @@ -433,82 +468,8 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { // // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) - return func(ctx context.Context, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - span := trace.SpanFromContext(ctx) - - rows, err := tx.QueryContext(ctx, sqlQuery, args...) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - defer common.LogOnError(ctx, rows.Close) - - span.AddEvent("Query issued to database") - - relCount := 0 - - defer func() { - span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - }() - - for rows.Next() { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName string - var caveatContext structpbWrapper - var expiration *time.Time - err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatContext, - &expiration, - ) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - caveat, err := common.ContextualizedCaveatFrom(caveatName, caveatContext) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: relation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - OptionalExpiration: expiration, - }, nil) { - return - } - } - if err := rows.Err(); err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - }, nil + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, builder, asQueryableTx{tx}) } } @@ -529,6 +490,7 @@ type Datastore struct { watchBufferWriteTimeout time.Duration maxRetries uint8 filterMaximumIDCount uint16 + schema common.SchemaInformation optimizedRevisionQuery string validTransactionQuery string diff --git a/internal/datastore/mysql/gc.go b/internal/datastore/mysql/gc.go index 2137fe427e..44ab61d7c5 100644 --- a/internal/datastore/mysql/gc.go +++ b/internal/datastore/mysql/gc.go @@ -108,6 +108,10 @@ func (mds *Datastore) DeleteBeforeTx( } func (mds *Datastore) DeleteExpiredRels(ctx context.Context) (int64, error) { + if mds.schema.ExpirationDisabled { + return 0, nil + } + now, err := mds.Now(ctx) if err != nil { return 0, err diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index 4a48e44fad..4405dbcaba 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -25,6 +26,8 @@ const ( defaultGCEnabled = true defaultCredentialsProviderName = "" defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone + defaultExpirationDisabled = false ) type mysqlOptions struct { @@ -47,6 +50,8 @@ type mysqlOptions struct { credentialsProviderName string filterMaximumIDCount uint16 allowedMigrations []string + columnOptimizationOption common.ColumnOptimizationOption + expirationDisabled bool } // Option provides the facility to configure how clients within the @@ -70,6 +75,8 @@ func generateConfig(options []Option) (mysqlOptions, error) { gcEnabled: defaultGCEnabled, credentialsProviderName: defaultCredentialsProviderName, filterMaximumIDCount: defaultFilterMaximumIDCount, + columnOptimizationOption: defaultColumnOptimizationOption, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -269,3 +276,21 @@ func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { func AllowedMigrations(allowedMigrations []string) Option { return func(mo *mysqlOptions) { mo.allowedMigrations = allowedMigrations } } + +// WithColumnOptimization configures the column optimization strategy for the MySQL datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(mo *mysqlOptions) { + if isEnabled { + mo.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + mo.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} + +// WithExpirationDisabled disables the expiration of relationships in the MySQL datastore. +func WithExpirationDisabled(isDisabled bool) Option { + return func(mo *mysqlOptions) { + mo.expirationDisabled = isDisabled + } +} diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 55eaaf4ea7..c963fb808f 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -23,9 +23,10 @@ type mysqlReader struct { *QueryBuilder txSource txFactory - executor common.QueryExecutor - filterer queryFilterer + executor common.QueryRelationshipsExecutor + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder @@ -39,19 +40,6 @@ const ( errUnableToReadCount = "unable to read count: %w" ) -var schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colExpiration, - common.ExpandedLogicComparison, - "NOW", -) - func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int, error) { // Ensure the counter is registered. counters, err := mr.lookupCounters(ctx, name) @@ -68,7 +56,7 @@ func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(mr.schema, mr.aliveFilter(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -116,7 +104,7 @@ func (mr *mysqlReader) LookupCounters(ctx context.Context) ([]datastore.Relation } func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) ([]datastore.RelationshipCounter, error) { - query := mr.filterer(mr.ReadCounterQuery) + query := mr.aliveFilter(mr.ReadCounterQuery) if optionalName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalName}) } @@ -177,7 +165,9 @@ func (mr *mysqlReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -190,7 +180,8 @@ func (mr *mysqlReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -220,7 +211,7 @@ func (mr *mysqlReader) ReadNamespaceByName(ctx context.Context, nsName string) ( } defer common.LogOnError(ctx, txCleanup) - loaded, version, err := loadNamespace(ctx, nsName, tx, mr.filterer(mr.ReadNamespaceQuery)) + loaded, version, err := loadNamespace(ctx, nsName, tx, mr.aliveFilter(mr.ReadNamespaceQuery)) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return nil, datastore.NoRevision, err @@ -265,7 +256,7 @@ func (mr *mysqlReader) ListAllNamespaces(ctx context.Context) ([]datastore.Revis } defer common.LogOnError(ctx, txCleanup) - query := mr.filterer(mr.ReadNamespaceQuery) + query := mr.aliveFilter(mr.ReadNamespaceQuery) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { @@ -291,7 +282,7 @@ func (mr *mysqlReader) LookupNamespacesWithNames(ctx context.Context, nsNames [] clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := mr.filterer(mr.ReadNamespaceQuery.Where(clause)) + query := mr.aliveFilter(mr.ReadNamespaceQuery.Where(clause)) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index cdc93c454e..d8b5b40647 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -17,6 +17,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/go-sql-driver/mysql" "github.com/jzelinskie/stringz" + "golang.org/x/exp/maps" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -59,6 +60,8 @@ func (cc *structpbWrapper) Scan(val any) error { if !ok { return fmt.Errorf("unsupported type: %T", v) } + + maps.Clear(*cc) return json.Unmarshal(v, &cc) } diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index d6e77f708b..9c483c3b93 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -34,7 +34,7 @@ const ( ) func (r *pgReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := r.filterer(readCaveat) + filteredReadCaveat := r.aliveFilter(readCaveat) sql, args, err := filteredReadCaveat.Where(sq.Eq{colCaveatName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) @@ -79,7 +79,7 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } - filteredListCaveat := r.filterer(caveatsWithNames) + filteredListCaveat := r.aliveFilter(caveatsWithNames) sql, args, err := filteredListCaveat.ToSql() if err != nil { return nil, fmt.Errorf(errListCaveats, err) diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 4187d96334..f0530cb754 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -2,9 +2,7 @@ package common import ( "context" - "database/sql" "errors" - "fmt" "time" "github.com/ccoveille/go-safecast" @@ -16,150 +14,19 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" - corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/tuple" ) -const errUnableToQueryTuples = "unable to query tuples: %w" - -// NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. -func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, false) +// NewPGXQueryRelationshipsExecutor creates an executor that uses the pgx library to make the specified queries. +func NewPGXQueryRelationshipsExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, builder, querier) } } -func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, withIntegrity) - } -} - -// queryRels queries relationships for the given query and transaction. -func queryRels(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier, withIntegrity bool) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - err := tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { - span.AddEvent("Query issued to database") - - var resourceObjectType string - var resourceObjectID string - var resourceRelation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName sql.NullString - var caveatCtx map[string]any - var expiration *time.Time - - relCount := 0 - for rows.Next() { - var integrity *corev1.RelationshipIntegrity - - if withIntegrity { - var integrityKeyID string - var integrityHash []byte - var timestamp time.Time - - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expiration, - &integrityKeyID, - &integrityHash, - ×tamp, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - - integrity = &corev1.RelationshipIntegrity{ - KeyId: integrityKeyID, - Hash: integrityHash, - HashedAt: timestamppb.New(timestamp), - } - } else { - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expiration, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - } - - var caveat *corev1.ContextualizedCaveat - if caveatName.Valid { - var err error - caveat, err = common.ContextualizedCaveatFrom(caveatName.String, caveatCtx) - if err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("unable to fetch caveat context: %w", err)) - } - } - - if expiration != nil { - // Ensure the returned expiration is always in UTC, as some datastores (like CRDB) - // convert to the local timezone when reading. - utc := expiration.UTC() - expiration = &utc - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: resourceRelation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - OptionalIntegrity: integrity, - OptionalExpiration: expiration, - }, nil) { - return nil - } - } - - if err := rows.Err(); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("rows err: %w", err)) - } - - span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - return nil - }, sqlStatement, args...) - if err != nil { - if !yield(tuple.Relationship{}, err) { - return - } - } - }, nil -} - // ParseConfigWithInstrumentation returns a pgx.ConnConfig that has been instrumented for observability func ParseConfigWithInstrumentation(url string) (*pgx.ConnConfig, error) { connConfig, err := pgx.ParseConfig(url) diff --git a/internal/datastore/postgres/gc.go b/internal/datastore/postgres/gc.go index 7252876d94..5b478cb0e1 100644 --- a/internal/datastore/postgres/gc.go +++ b/internal/datastore/postgres/gc.go @@ -77,6 +77,10 @@ func (pgd *pgDatastore) TxIDBefore(ctx context.Context, before time.Time) (datas } func (pgd *pgDatastore) DeleteExpiredRels(ctx context.Context) (int64, error) { + if pgd.schema.ExpirationDisabled { + return 0, nil + } + now, err := pgd.Now(ctx) if err != nil { return 0, err diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index be79c03d9a..82a42de6af 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -28,6 +29,8 @@ type postgresOptions struct { analyzeBeforeStatistics bool gcEnabled bool readStrictMode bool + expirationDisabled bool + columnOptimizationOption common.ColumnOptimizationOption includeQueryParametersInTraces bool migrationPhase string @@ -68,7 +71,9 @@ const ( defaultCredentialsProviderName = "" defaultReadStrictMode = false defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the @@ -91,7 +96,9 @@ func generateConfig(options []Option) (postgresOptions, error) { readStrictMode: defaultReadStrictMode, queryInterceptor: nil, filterMaximumIDCount: defaultFilterMaximumIDCount, + columnOptimizationOption: defaultColumnOptimizationOption, includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -385,3 +392,19 @@ func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { func IncludeQueryParametersInTraces(includeQueryParametersInTraces bool) Option { return func(po *postgresOptions) { po.includeQueryParametersInTraces = includeQueryParametersInTraces } } + +// WithColumnOptimization sets the column optimization option for the datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *postgresOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} + +// WithExpirationDisabled disables support for relationship expiration. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *postgresOptions) { po.expirationDisabled = isDisabled } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index bc1aae6303..8f5d66b068 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -316,6 +316,24 @@ func newPostgresDatastore( maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableTuple), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.TupleComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), + common.WithExpirationDisabled(config.expirationDisabled), + ) + datastore := &pgDatastore{ CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, @@ -341,6 +359,7 @@ func newPostgresDatastore( isPrimary: isPrimary, inStrictReadMode: config.readStrictMode, filterMaximumIDCount: config.filterMaximumIDCount, + schema: *schema, } if isPrimary && config.readStrictMode { @@ -393,6 +412,7 @@ type pgDatastore struct { watchEnabled bool isPrimary bool inStrictReadMode bool + schema common.SchemaInformation includeQueryParametersInTraces bool credentialsProvider datastore.CredentialsProvider @@ -416,8 +436,8 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read queryFuncs = strictReaderQueryFuncs{wrapped: queryFuncs, revision: rev} } - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } return &pgReader{ @@ -425,6 +445,7 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read executor, buildLivingObjectFilterForRevision(rev), pgd.filterMaximumIDCount, + pgd.schema, } } @@ -458,8 +479,8 @@ func (pgd *pgDatastore) ReadWriteTx( } queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } rwt := &pgReadWriteTXN{ @@ -468,6 +489,7 @@ func (pgd *pgDatastore) ReadWriteTx( executor, currentlyLivingObjects, pgd.filterMaximumIDCount, + pgd.schema, }, tx, newXID, diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 14a1b99a0e..ed0ad792d1 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -17,41 +17,17 @@ import ( type pgReader struct { query pgxcommon.DBFuncQuerier - executor common.QueryExecutor - filterer queryFilterer + executor common.QueryRelationshipsExecutor + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder var ( - queryTuples = psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - ).From(tableTuple) - countRels = psql.Select("COUNT(*)").From(tableTuple) - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colExpiration, - common.TupleComparison, - "NOW", - ) - readNamespace = psql. Select(colConfig, colCreatedXid). From(tableNamespace) @@ -85,7 +61,7 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(r.schema, r.aliveFilter(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -125,7 +101,7 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d query = query.Where(sq.Eq{colCounterName: optionalName}) } - sql, args, err := r.filterer(query).ToSql() + sql, args, err := r.aliveFilter(query).ToSql() if err != nil { return nil, fmt.Errorf("unable to lookup counters: %w", err) } @@ -173,7 +149,9 @@ func (r *pgReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -186,7 +164,8 @@ func (r *pgReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -209,7 +188,7 @@ func (r *pgReader) ReverseQueryRelationships( } func (r *pgReader) ReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { - loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.filterer) + loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.aliveFilter) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return nil, datastore.NoRevision, err @@ -239,7 +218,7 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco } func (r *pgReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -258,7 +237,7 @@ func (r *pgReader) LookupNamespacesWithNames(ctx context.Context, nsNames []stri } nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, func(original sq.SelectBuilder) sq.SelectBuilder { - return r.filterer(original).Where(clause) + return r.aliveFilter(original).Where(clause) }) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index 6bc3ead8c4..e135b158ad 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -585,7 +585,7 @@ func (rwt *pgReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*c } func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...string) error { - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -593,7 +593,7 @@ func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...stri tplClauses := make([]sq.Sqlizer, 0, len(nsNames)) querier := pgxcommon.QuerierFuncsFor(rwt.tx) for _, nsName := range nsNames { - _, _, err := rwt.loadNamespace(ctx, nsName, querier, filterer) + _, _, err := rwt.loadNamespace(ctx, nsName, querier, aliveFilter) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return err diff --git a/internal/datastore/postgres/stats.go b/internal/datastore/postgres/stats.go index b428cd4657..0e0bea63f6 100644 --- a/internal/datastore/postgres/stats.go +++ b/internal/datastore/postgres/stats.go @@ -51,7 +51,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return datastore.Stats{}, fmt.Errorf("unable to prepare row count sql: %w", err) } - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -69,7 +69,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return fmt.Errorf("unable to query unique ID: %w", err) } - nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), aliveFilter) if err != nil { return fmt.Errorf("unable to load namespaces: %w", err) } diff --git a/internal/datastore/proxy/observable_test.go b/internal/datastore/proxy/observable_test.go index 63a388957c..874be163aa 100644 --- a/internal/datastore/proxy/observable_test.go +++ b/internal/datastore/proxy/observable_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/test" @@ -12,7 +13,7 @@ import ( type observableTest struct{} func (obs observableTest) New(revisionQuantization, _, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { - db, err := memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) + db, err := dsfortesting.NewMemDBDatastoreForTesting(watchBufferLength, revisionQuantization, gcWindow) if err != nil { return nil, err } diff --git a/internal/datastore/proxy/relationshipintegrity_test.go b/internal/datastore/proxy/relationshipintegrity_test.go index f8f03ecc54..a591720dde 100644 --- a/internal/datastore/proxy/relationshipintegrity_test.go +++ b/internal/datastore/proxy/relationshipintegrity_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -58,7 +58,7 @@ var expiredKeyForTesting = KeyConfig{ } func TestWriteWithPredefinedIntegrity(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -76,7 +76,7 @@ func TestWriteWithPredefinedIntegrity(t *testing.T) { } func TestReadWithMissingIntegrity(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) // Write a relationship to the underlying datastore without integrity information. @@ -108,7 +108,7 @@ func TestReadWithMissingIntegrity(t *testing.T) { } func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -157,7 +157,7 @@ func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { } func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -206,7 +206,7 @@ func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { } func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) // Create a proxy with the to-be-expired key and write some relationships. @@ -245,7 +245,7 @@ func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { } func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) headRev, err := ds.HeadRevision(context.Background()) @@ -289,7 +289,7 @@ func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { func BenchmarkQueryRelsWithIntegrity(b *testing.B) { for _, withIntegrity := range []bool{true, false} { b.Run(fmt.Sprintf("withIntegrity=%t", withIntegrity), func(b *testing.B) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(b, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) diff --git a/internal/datastore/proxy/schemacaching/estimatedsize_test.go b/internal/datastore/proxy/schemacaching/estimatedsize_test.go index 275f544bf7..f0a6db168f 100644 --- a/internal/datastore/proxy/schemacaching/estimatedsize_test.go +++ b/internal/datastore/proxy/schemacaching/estimatedsize_test.go @@ -14,6 +14,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/validationfile" @@ -46,7 +47,7 @@ func TestEstimatedDefinitionSizes(t *testing.T) { filePath := filePath t.Run(path.Base(filePath), func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, memdb.DisableGC) require.NoError(err) fullyResolved, _, err := validationfile.PopulateFromFiles(context.Background(), ds, []string{filePath}) diff --git a/internal/datastore/proxy/schemacaching/standardcaching_test.go b/internal/datastore/proxy/schemacaching/standardcaching_test.go index 2bc6e20ec1..5cdf164ceb 100644 --- a/internal/datastore/proxy/schemacaching/standardcaching_test.go +++ b/internal/datastore/proxy/schemacaching/standardcaching_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/datastore/proxy/proxy_test" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -366,7 +367,7 @@ func TestSnapshotCachingRealDatastore(t *testing.T) { for _, tc := range tcs { tc := tc t.Run(tc.name, func(t *testing.T) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ctx := context.Background() diff --git a/internal/datastore/spanner/options.go b/internal/datastore/spanner/options.go index 29d9617428..75f84cf4b9 100644 --- a/internal/datastore/spanner/options.go +++ b/internal/datastore/spanner/options.go @@ -6,6 +6,7 @@ import ( "runtime" "time" + "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -26,6 +27,8 @@ type spannerOptions struct { migrationPhase string allowedMigrations []string filterMaximumIDCount uint16 + columnOptimizationOption common.ColumnOptimizationOption + expirationDisabled bool } type migrationPhase uint8 @@ -49,6 +52,8 @@ const ( defaultDisableStats = false maxRevisionQuantization = 24 * time.Hour defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the Spanner @@ -72,6 +77,8 @@ func generateConfig(options []Option) (spannerOptions, error) { maxSessions: 400, migrationPhase: "", // no migration filterMaximumIDCount: defaultFilterMaximumIDCount, + columnOptimizationOption: defaultColumnOptimizationOption, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -224,3 +231,22 @@ func AllowedMigrations(allowedMigrations []string) Option { func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { return func(po *spannerOptions) { po.filterMaximumIDCount = filterMaximumIDCount } } + +// WithColumnOptimization configures the Spanner driver to optimize the columns +// in the underlying tables. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *spannerOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} + +// WithExpirationDisabled disables relationship expiration support in the Spanner. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *spannerOptions) { + po.expirationDisabled = isDisabled + } +} diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index c165dc7699..2dc8ee5022 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -31,9 +31,10 @@ type readTX interface { type txFactory func() readTX type spannerReader struct { - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor txSource txFactory filterMaximumIDCount uint16 + schema common.SchemaInformation } func (sr spannerReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -54,7 +55,7 @@ func (sr spannerReader) CountRelationships(ctx context.Context, name string) (in return 0, err } - builder, err := common.NewSchemaQueryFilterer(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(sr.schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -134,7 +135,7 @@ func (sr spannerReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -147,7 +148,7 @@ func (sr spannerReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -171,11 +172,18 @@ func (sr spannerReader) ReverseQueryRelationships( var errStopIterator = fmt.Errorf("stop iteration") -func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { return func(yield func(tuple.Relationship, error) bool) { span := trace.SpanFromContext(ctx) span.AddEvent("Query issued to database") + + sql, args, err := builder.SelectSQL() + if err != nil { + yield(tuple.Relationship{}, err) + return + } + iter := txSource().Query(ctx, statementFromSQL(sql, args)) defer iter.Stop() @@ -185,27 +193,42 @@ func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { relCount := 0 defer span.SetAttributes(attribute.Int("count", relCount)) + var resourceObjectType string + var resourceObjectID string + var relation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName spanner.NullString + var caveatCtx spanner.NullJSON + var expirationOrNull spanner.NullTime + + // NOTE: these are unused in Spanner, but necessary for the ColumnsToSelect call. + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := common.ColumnsToSelect(builder, + &resourceObjectType, + &resourceObjectID, + &relation, + &subjectObjectType, + &subjectObjectID, + &subjectRelation, + &caveatName, + &caveatCtx, + &expirationOrNull, + &integrityKeyID, + &integrityHash, + ×tamp, + ) + if err != nil { + yield(tuple.Relationship{}, err) + return + } + if err := iter.Do(func(row *spanner.Row) error { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName spanner.NullString - var caveatCtx spanner.NullJSON - var expirationOrNull spanner.NullTime - err := row.Columns( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expirationOrNull, - ) + err := row.Columns(colsToSelect...) if err != nil { return err } @@ -355,18 +378,6 @@ func readAllNamespaces(iter *spanner.RowIterator, span trace.Span) ([]datastore. return allNamespaces, nil } -var queryTuples = sql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, -).From(tableRelationship) - var countRels = sql.Select("COUNT(*)").From(tableRelationship) var queryTuplesForDelete = sql.Select( @@ -378,17 +389,4 @@ var queryTuplesForDelete = sql.Select( colUsersetRelation, ).From(tableRelationship) -var schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colExpiration, - common.ExpandedLogicComparison, - "CURRENT_TIMESTAMP", -) - var _ datastore.Reader = spannerReader{} diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index a0304a07e2..e894de1cee 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -90,6 +90,7 @@ type spannerDatastore struct { client *spanner.Client config spannerOptions database string + schema common.SchemaInformation cachedEstimatedBytesPerRelationshipLock sync.RWMutex cachedEstimatedBytesPerRelationship uint64 @@ -180,6 +181,23 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( return nil, fmt.Errorf("invalid head migration found for spanner: %w", err) } + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableRelationship), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.AtP), + common.WithNowFunction("CURRENT_TIMESTAMP"), + common.WithColumnOptimization(config.columnOptimizationOption), + ) + ds := &spannerDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( defaultChangeStreamRetention, @@ -200,6 +218,7 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( cachedEstimatedBytesPerRelationshipLock: sync.RWMutex{}, tableSizesStatsTable: tableSizesStatsTable, filterMaximumIDCount: config.filterMaximumIDCount, + schema: *schema, } // Optimized revision and revision checking use a stale read for the // current timestamp. @@ -247,8 +266,8 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas txSource := func() readTX { return &traceableRTX{delegate: sd.client.Single().WithTimestampBound(spanner.ReadTimestamp(r.Time()))} } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} - return spannerReader{executor, txSource, sd.filterMaximumIDCount} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} + return spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema} } func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transactionTag string) (map[string]any, error) { @@ -295,9 +314,9 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ - spannerReader{executor, txSource, sd.filterMaximumIDCount}, + spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema}, spannerRWT, } err := func() error { diff --git a/internal/dispatch/combined/combined_test.go b/internal/dispatch/combined/combined_test.go index a15e3196ff..79663e23f0 100644 --- a/internal/dispatch/combined/combined_test.go +++ b/internal/dispatch/combined/combined_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/testfixtures" @@ -22,7 +23,7 @@ func TestCombinedRecursiveCall(t *testing.T) { ctx := datastoremw.ContextWithHandle(context.Background()) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index 83453c7eb5..5c75adc40a 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" @@ -153,7 +154,7 @@ func TestMaxDepth(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) @@ -1251,7 +1252,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { definition user {} definition role { - relation member: user + relation member: user with somecaveat } definition resource { @@ -1287,7 +1288,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { definition user {} definition role { - relation member: user + relation member: user with somecaveat } definition resource { @@ -1322,7 +1323,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1823,7 +1824,7 @@ func TestCheckWithHints(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1863,7 +1864,7 @@ func TestCheckHintsPartialApplication(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, ` @@ -1909,7 +1910,7 @@ func TestCheckHintsPartialApplicationOverArrow(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, ` @@ -1955,7 +1956,7 @@ func TestCheckHintsPartialApplicationOverArrow(t *testing.T) { } func newLocalDispatcherWithConcurrencyLimit(t testing.TB, concurrencyLimit uint16) (context.Context, dispatch.Dispatcher, datastore.Revision) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require.New(t)) @@ -1977,7 +1978,7 @@ func newLocalDispatcher(t testing.TB) (context.Context, dispatch.Dispatcher, dat } func newLocalDispatcherWithSchemaAndRels(t testing.TB, schema string, rels []tuple.Relationship) (context.Context, dispatch.Dispatcher, datastore.Revision) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, schema, rels, require.New(t)) diff --git a/internal/dispatch/graph/expand_test.go b/internal/dispatch/graph/expand_test.go index 051199bacb..d48adaf554 100644 --- a/internal/dispatch/graph/expand_test.go +++ b/internal/dispatch/graph/expand_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/testing/protocmp" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" expand "github.com/authzed/spicedb/internal/graph" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -280,7 +281,7 @@ func TestMaxDepthExpand(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) diff --git a/internal/dispatch/graph/lookupresources2_test.go b/internal/dispatch/graph/lookupresources2_test.go index 7d7a1851d9..1a3b702209 100644 --- a/internal/dispatch/graph/lookupresources2_test.go +++ b/internal/dispatch/graph/lookupresources2_test.go @@ -13,6 +13,7 @@ import ( "go.uber.org/goleak" "google.golang.org/protobuf/types/known/structpb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -310,7 +311,7 @@ func TestMaxDepthLookup2(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -754,7 +755,7 @@ func TestLookupResources2OverSchemaWithCursors(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -830,7 +831,7 @@ func TestLookupResources2ImmediateTimeout(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -865,7 +866,7 @@ func TestLookupResources2WithError(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -1341,7 +1342,7 @@ func TestLookupResources2EnsureCheckHints(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, tc.relationships, require) diff --git a/internal/dispatch/graph/lookupresources_test.go b/internal/dispatch/graph/lookupresources_test.go index ec0f731c23..9766cbedde 100644 --- a/internal/dispatch/graph/lookupresources_test.go +++ b/internal/dispatch/graph/lookupresources_test.go @@ -10,6 +10,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -302,7 +303,7 @@ func TestMaxDepthLookup(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -606,7 +607,7 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -664,7 +665,7 @@ func TestLookupResourcesImmediateTimeout(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -697,7 +698,7 @@ func TestLookupResourcesWithError(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) diff --git a/internal/dispatch/graph/lookupsubjects_test.go b/internal/dispatch/graph/lookupsubjects_test.go index d441749a12..23c1d84038 100644 --- a/internal/dispatch/graph/lookupsubjects_test.go +++ b/internal/dispatch/graph/lookupsubjects_test.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/internal/caveats" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" log "github.com/authzed/spicedb/internal/logging" @@ -194,7 +195,7 @@ func TestLookupSubjectsMaxDepth(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) @@ -997,7 +998,7 @@ func TestLookupSubjectsOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) diff --git a/internal/dispatch/graph/reachableresources_test.go b/internal/dispatch/graph/reachableresources_test.go index e4442b2a29..759022cb6c 100644 --- a/internal/dispatch/graph/reachableresources_test.go +++ b/internal/dispatch/graph/reachableresources_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" @@ -257,7 +258,7 @@ func BenchmarkReachableResources(b *testing.B) { ) require := require.New(b) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -567,7 +568,7 @@ func TestCaveatedReachableResources(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -636,7 +637,7 @@ func TestReachableResourcesWithConsistencyLimitOf1(t *testing.T) { func TestReachableResourcesMultipleEntrypointEarlyCancel(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -712,7 +713,7 @@ func TestReachableResourcesMultipleEntrypointEarlyCancel(t *testing.T) { func TestReachableResourcesCursors(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -828,7 +829,7 @@ func TestReachableResourcesCursors(t *testing.T) { func TestReachableResourcesPaginationWithLimit(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -909,7 +910,7 @@ func TestReachableResourcesPaginationWithLimit(t *testing.T) { func TestReachableResourcesWithQueryError(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1209,7 +1210,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1269,7 +1270,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { func TestReachableResourcesWithPreCancelation(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1323,7 +1324,7 @@ func TestReachableResourcesWithPreCancelation(t *testing.T) { func TestReachableResourcesWithUnexpectedContextCancelation(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1407,7 +1408,7 @@ func (cr *cancelingReader) ReverseQueryRelationships( func TestReachableResourcesWithCachingInParallelTest(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) diff --git a/internal/graph/check.go b/internal/graph/check.go index 47bd8b7d5d..0acb216a30 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -3,6 +3,7 @@ package graph import ( "context" "errors" + "fmt" "time" "github.com/google/uuid" @@ -19,6 +20,7 @@ import ( "github.com/authzed/spicedb/internal/namespace" "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/middleware/nodeid" nspkg "github.com/authzed/spicedb/pkg/namespace" @@ -321,14 +323,14 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // 2) the wildcard form of the target subject, if a wildcard is allowed on this relation // 3) Otherwise, any non-terminal (non-`...`) subjects, if allowed on this relation, to be // redispatched outward - hasNonTerminals := false - hasDirectSubject := false - hasWildcardSubject := false + totalNonTerminals := 0 + totalDirectSubjects := 0 + totalWildcardSubjects := 0 defer func() { - if hasNonTerminals { + if totalNonTerminals > 0 { span.SetName("non terminal") - } else if hasDirectSubject { + } else if totalDirectSubjects > 0 { span.SetName("terminal") } else { span.SetName("wildcard subject") @@ -337,6 +339,11 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + directSubjectsAndWildcardsWithoutCaveats := 0 + directSubjectsAndWildcardsWithoutExpiration := 0 + nonTerminalsWithoutCaveats := 0 + nonTerminalsWithoutExpiration := 0 + for _, allowedDirectRelation := range relation.GetTypeInformation().GetAllowedDirectRelations() { // If the namespace of the allowed direct relation matches the subject type, there are two // cases to optimize: @@ -344,9 +351,17 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // 2) Finding a wildcard for the subject type+relation if allowedDirectRelation.GetNamespace() == crc.parentReq.Subject.Namespace { if allowedDirectRelation.GetPublicWildcard() != nil { - hasWildcardSubject = true + totalWildcardSubjects++ } else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation { - hasDirectSubject = true + totalDirectSubjects++ + } + + if allowedDirectRelation.RequiredCaveat == nil { + directSubjectsAndWildcardsWithoutCaveats++ + } + + if allowedDirectRelation.RequiredExpiration == nil { + directSubjectsAndWildcardsWithoutExpiration++ } } @@ -356,10 +371,20 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // TODO(jschorr): Use type information to *further* optimize this query around which nested // relations can reach the target subject type. if allowedDirectRelation.GetRelation() != tuple.Ellipsis { - hasNonTerminals = true + totalNonTerminals++ + if allowedDirectRelation.RequiredCaveat == nil { + nonTerminalsWithoutCaveats++ + } + if allowedDirectRelation.RequiredExpiration == nil { + nonTerminalsWithoutExpiration++ + } } } + nonTerminalsCanHaveCaveats := totalNonTerminals != nonTerminalsWithoutCaveats + nonTerminalsCanHaveExpiration := totalNonTerminals != nonTerminalsWithoutExpiration + hasNonTerminals := totalNonTerminals > 0 + foundResources := NewMembershipSet() // If the direct subject or a wildcard form can be found, issue a query for just that @@ -369,7 +394,12 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest directDispatchQueryHistogram.Observe(queryCount) }() + hasDirectSubject := totalDirectSubjects > 0 + hasWildcardSubject := totalWildcardSubjects > 0 if hasDirectSubject || hasWildcardSubject { + directSubjectOrWildcardCanHaveCaveats := directSubjectsAndWildcardsWithoutCaveats != (totalDirectSubjects + totalWildcardSubjects) + directSubjectOrWildcardCanHaveExpiration := directSubjectsAndWildcardsWithoutExpiration != (totalDirectSubjects + totalWildcardSubjects) + subjectSelectors := []datastore.SubjectsSelector{} if hasDirectSubject { @@ -395,7 +425,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest OptionalSubjectsSelectors: subjectSelectors, } - it, err := ds.QueryRelationships(ctx, filter) + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats), + options.WithSkipExpiration(!directSubjectOrWildcardCanHaveExpiration), + ) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -444,7 +477,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest }, } - it, err := ds.QueryRelationships(ctx, filter) + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!nonTerminalsCanHaveCaveats), + options.WithSkipExpiration(!nonTerminalsCanHaveExpiration), + ) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -638,6 +674,73 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre return combineResultWithFoundResources(result, membershipSet) } +type Traits struct { + HasCaveats bool + HasExpiration bool +} + +// TraitsForArrowRelation returns traits such as HasCaveats and HasExpiration if *any* of the subject +// types of the given relation support caveats or expiration. +func TraitsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) (Traits, error) { + // TODO(jschorr): Change to use the type system once we wire it through Check dispatch. + nsDefs, err := reader.LookupNamespacesWithNames(ctx, []string{namespaceName}) + if err != nil { + return Traits{}, err + } + + if len(nsDefs) != 1 { + return Traits{}, fmt.Errorf("namespace %q not found", namespaceName) + } + + var relation *core.Relation + for _, rel := range nsDefs[0].Definition.Relation { + if rel.Name == relationName { + relation = rel + break + } + } + + if relation == nil || relation.TypeInformation == nil { + return Traits{}, fmt.Errorf("relation %q not found", relationName) + } + + hasCaveats := false + hasExpiration := false + + for _, allowedDirectRelation := range relation.TypeInformation.GetAllowedDirectRelations() { + if allowedDirectRelation.RequiredCaveat != nil { + hasCaveats = true + } + + if allowedDirectRelation.RequiredExpiration != nil { + hasExpiration = true + } + } + + return Traits{ + HasCaveats: hasCaveats, + HasExpiration: hasExpiration, + }, nil +} + +func queryOptionsForArrowRelation(ctx context.Context, ds datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { + traits, err := TraitsForArrowRelation(ctx, ds, namespaceName, relationName) + if err != nil { + return nil, err + } + + opts := []options.QueryOptionsOption{} + if !traits.HasCaveats { + opts = append(opts, options.WithSkipCaveats(true)) + } + + if !traits.HasExpiration { + opts = append(opts, options.WithSkipExpiration(true)) + } + + return opts, nil +} + func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) { if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation { return nil, resourceIds @@ -688,11 +791,16 @@ func checkIntersectionTupleToUserset( // Query for the subjects over which to walk the TTU. log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, OptionalResourceIds: crc.filteredResourceIDs, OptionalResourceRelation: ttu.GetTupleset().GetRelation(), - }) + }, queryOpts...) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -849,11 +957,17 @@ func checkTupleToUserset[T relation]( log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, OptionalResourceIds: filteredResourceIDs, OptionalResourceRelation: ttu.GetTupleset().GetRelation(), - }) + }, queryOpts...) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } diff --git a/internal/graph/check_isolated_test.go b/internal/graph/check_isolated_test.go new file mode 100644 index 0000000000..0957026a20 --- /dev/null +++ b/internal/graph/check_isolated_test.go @@ -0,0 +1,148 @@ +package graph_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/dsfortesting" + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/testfixtures" +) + +func TestTraitsForArrowRelation(t *testing.T) { + tcs := []struct { + name string + schema string + namespaceName string + relationName string + expectedTraits graph.Traits + expectedError string + }{ + { + name: "unknown namespace", + schema: `definition user {}`, + namespaceName: "unknown", + relationName: "unknown", + expectedTraits: graph.Traits{}, + expectedError: "not found", + }, + { + name: "unknown relation", + schema: `definition resource {}`, + namespaceName: "resource", + relationName: "unknown", + expectedTraits: graph.Traits{}, + expectedError: "not found", + }, + { + name: "known relation with all optimizations", + schema: ` + definition folder {} + + definition resource { + relation folder: folder + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{}, + }, + { + name: "known relation with caveats", + schema: ` + definition folder {} + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition resource { + relation folder: folder with somecaveat + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasCaveats: true, + }, + }, + { + name: "known relation with expiration", + schema: ` + use expiration + + definition folder {} + + definition resource { + relation folder: folder with expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasExpiration: true, + }, + }, + { + name: "known relation with caveats and expiration", + schema: ` + use expiration + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition folder {} + + definition resource { + relation folder: folder with somecaveat and expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasCaveats: true, + HasExpiration: true, + }, + }, + { + name: "different relation with caveats and expiration", + schema: ` + use expiration + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition folder {} + + definition resource { + relation folder: folder + relation folder2: folder with somecaveat and expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) + require.NoError(err) + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, require) + reader := ds.SnapshotReader(revision) + + traits, err := graph.TraitsForArrowRelation(context.Background(), reader, tc.namespaceName, tc.relationName) + if tc.expectedError != "" { + require.ErrorContains(err, tc.expectedError) + return + } + + require.NoError(err) + require.Equal(tc.expectedTraits, traits) + }) + } +} diff --git a/internal/graph/computed/computecheck_test.go b/internal/graph/computed/computecheck_test.go index c00102fc34..0eea7e9bee 100644 --- a/internal/graph/computed/computecheck_test.go +++ b/internal/graph/computed/computecheck_test.go @@ -8,6 +8,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch/graph" "github.com/authzed/spicedb/internal/graph/computed" @@ -805,7 +806,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { for _, tt := range testCases { tt := tt t.Run(tt.name, func(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) @@ -855,7 +856,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { } func TestComputeCheckError(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) @@ -878,7 +879,7 @@ func TestComputeCheckError(t *testing.T) { } func TestComputeBulkCheck(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) diff --git a/internal/graph/hints/checkhints_test.go b/internal/graph/hints/checkhints_test.go index 1513fc0f01..bcce7a3b4e 100644 --- a/internal/graph/hints/checkhints_test.go +++ b/internal/graph/hints/checkhints_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/datastore" @@ -97,7 +98,7 @@ func TestHintForEntrypoint(t *testing.T) { func buildReachabilityGraph(t *testing.T, schema string) *typesystem.ReachabilityGraph { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/internal/namespace/aliasing_test.go b/internal/namespace/aliasing_test.go index ebe2649d2a..8483c9b5ef 100644 --- a/internal/namespace/aliasing_test.go +++ b/internal/namespace/aliasing_test.go @@ -9,6 +9,7 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/typesystem" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" ns "github.com/authzed/spicedb/pkg/namespace" ) @@ -195,7 +196,7 @@ func TestAliasing(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) lastRevision, err := ds.HeadRevision(context.Background()) diff --git a/internal/namespace/annotate_test.go b/internal/namespace/annotate_test.go index e1deb67b28..3ada546a37 100644 --- a/internal/namespace/annotate_test.go +++ b/internal/namespace/annotate_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -15,7 +16,7 @@ import ( func TestAnnotateNamespace(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) compiled, err := compiler.Compile(compiler.InputSchema{ diff --git a/internal/namespace/canonicalization_test.go b/internal/namespace/canonicalization_test.go index 3250faf2e1..c8c313966a 100644 --- a/internal/namespace/canonicalization_test.go +++ b/internal/namespace/canonicalization_test.go @@ -10,6 +10,7 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/typesystem" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" ns "github.com/authzed/spicedb/pkg/namespace" "github.com/authzed/spicedb/pkg/schemadsl/compiler" @@ -425,7 +426,7 @@ func TestCanonicalization(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() @@ -552,7 +553,7 @@ func TestCanonicalizationComparison(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() diff --git a/internal/namespace/util_test.go b/internal/namespace/util_test.go index 7865860244..82a69d3a4b 100644 --- a/internal/namespace/util_test.go +++ b/internal/namespace/util_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/namespace" "github.com/authzed/spicedb/internal/testfixtures" @@ -162,7 +163,7 @@ func TestCheckNamespaceAndRelations(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, req) diff --git a/internal/relationships/validation_test.go b/internal/relationships/validation_test.go index ef645c50fe..6fc5c3a9ab 100644 --- a/internal/relationships/validation_test.go +++ b/internal/relationships/validation_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -311,7 +312,7 @@ func TestValidateRelationshipOperations(t *testing.T) { t.Parallel() req := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) uds, rev := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, nil, req) diff --git a/internal/services/integrationtesting/cert_test.go b/internal/services/integrationtesting/cert_test.go index 329d6e114a..465122a34e 100644 --- a/internal/services/integrationtesting/cert_test.go +++ b/internal/services/integrationtesting/cert_test.go @@ -23,7 +23,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/backoff" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/dispatch/graph" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/middleware/servicespecific" @@ -115,7 +115,7 @@ func TestCertRotation(t *testing.T) { require.NoError(t, certFile.Close()) // start a server with an initial set of certs - emptyDS, err := memdb.NewMemdbDatastore(0, 10, time.Duration(90_000_000_000_000)) + emptyDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 10, time.Duration(90_000_000_000_000)) require.NoError(t, err) ds, revision := tf.StandardDatastoreWithData(emptyDS, require.New(t)) ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go index cb38dbf365..d273e7bd86 100644 --- a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go +++ b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch/caching" "github.com/authzed/spicedb/internal/dispatch/graph" @@ -38,7 +39,7 @@ type ConsistencyClusterAndData struct { func LoadDataAndCreateClusterForTesting(t *testing.T, consistencyTestFilePath string, revisionDelta time.Duration, additionalServerOptions ...server.ConfigOption) ConsistencyClusterAndData { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, revisionDelta, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, revisionDelta, memdb.DisableGC) require.NoError(err) return BuildDataAndCreateClusterForTesting(t, consistencyTestFilePath, ds, additionalServerOptions...) diff --git a/internal/services/shared/schema_test.go b/internal/services/shared/schema_test.go index d2efcad058..983fd35895 100644 --- a/internal/services/shared/schema_test.go +++ b/internal/services/shared/schema_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -292,7 +293,7 @@ func TestApplySchemaChanges(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) // Write the initial schema. diff --git a/internal/services/steelthreadtesting/steelthread_test.go b/internal/services/steelthreadtesting/steelthread_test.go index 79fa8211ef..39733522a1 100644 --- a/internal/services/steelthreadtesting/steelthread_test.go +++ b/internal/services/steelthreadtesting/steelthread_test.go @@ -17,7 +17,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/testserver" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" "github.com/authzed/spicedb/internal/testserver/datastore/config" @@ -31,7 +31,7 @@ const defaultConnBufferSize = humanize.MiByte func TestMemdbSteelThreads(t *testing.T) { for _, tc := range steelThreadTestCases { t.Run(tc.name, func(t *testing.T) { - emptyDS, err := memdb.NewMemdbDatastore(0, 5*time.Second, 2*time.Hour) + emptyDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 2*time.Hour) require.NoError(t, err) runSteelThreadTest(t, tc, emptyDS) diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index 595878c3ea..dd84ebda37 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -10,7 +10,6 @@ import ( "strconv" "testing" - "github.com/authzed/authzed-go/pkg/responsemeta" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/grpcutil" "github.com/ccoveille/go-safecast" @@ -435,10 +434,9 @@ func TestBulkCheckPermission(t *testing.T) { defer cleanup() testCases := []struct { - name string - requests []string - response []bulkCheckTest - expectedDispatchCount int + name string + requests []string + response []bulkCheckTest }{ { name: "same resource and permission, different subjects", @@ -461,7 +459,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 49, }, { name: "different resources, same permission and subject", @@ -484,7 +481,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 18, }, { name: "some items fail", @@ -507,36 +503,6 @@ func TestBulkCheckPermission(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("superfake"), }, }, - expectedDispatchCount: 17, - }, - { - name: "different caveat context is not clustered", - requests: []string{ - `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - `document:masterplan#view@user:eng_lead`, - }, - response: []bulkCheckTest{ - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, - }, - { - req: `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION, - partial: []string{"secret"}, - }, - }, - expectedDispatchCount: 50, }, { name: "namespace validation", @@ -554,7 +520,6 @@ func TestBulkCheckPermission(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("fake"), }, }, - expectedDispatchCount: 1, }, { name: "chunking test", @@ -577,7 +542,6 @@ func TestBulkCheckPermission(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "chunking test with errors", @@ -607,7 +571,6 @@ func TestBulkCheckPermission(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "same resource and permission with same subject, repeated", @@ -625,7 +588,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, }, - expectedDispatchCount: 17, }, } @@ -694,10 +656,6 @@ func TestBulkCheckPermission(t *testing.T) { actual, err := client.BulkCheckPermission(context.Background(), &req, grpc.Trailer(&trailer)) require.NoError(t, err) - dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(trailer, responsemeta.DispatchedOperationsCount) - require.NoError(t, err) - require.Equal(t, tt.expectedDispatchCount, dispatchCount) - testutil.RequireProtoSlicesEqual(t, expected, actual.Pairs, nil, "response bulk check pairs did not match") }) } diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 1e03fe09c8..ff0bd11f59 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -1028,9 +1028,9 @@ func TestCheckWithCaveats(t *testing.T) { AtLeastAsFresh: zedtoken.MustNewFromRevision(revision), }, }, - Resource: obj("document", "companyplan"), - Permission: "view", - Subject: sub("user", "owner", ""), + Resource: obj("document", "caveatedplan"), + Permission: "caveated_viewer", + Subject: sub("user", "caveatedguy", ""), } // caveat evaluated and returned false @@ -1775,10 +1775,9 @@ func TestCheckBulkPermissions(t *testing.T) { defer cleanup() testCases := []struct { - name string - requests []string - response []bulkCheckTest - expectedDispatchCount int + name string + requests []string + response []bulkCheckTest }{ { name: "same resource and permission, different subjects", @@ -1801,7 +1800,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 49, }, { name: "different resources, same permission and subject", @@ -1824,7 +1822,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 18, }, { name: "some items fail", @@ -1847,36 +1844,29 @@ func TestCheckBulkPermissions(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("superfake"), }, }, - expectedDispatchCount: 17, }, { name: "different caveat context is not clustered", requests: []string{ - `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - `document:masterplan#view@user:eng_lead`, + `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "1234"}]`, + `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "4321"}]`, + `document:caveatedplan#caveated_viewer@user:caveatedguy`, }, response: []bulkCheckTest{ { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "1234"}]`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, { - req: `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "4321"}]`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, { - req: `document:masterplan#view@user:eng_lead`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION, partial: []string{"secret"}, }, }, - expectedDispatchCount: 50, }, { name: "namespace validation", @@ -1894,7 +1884,6 @@ func TestCheckBulkPermissions(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("fake"), }, }, - expectedDispatchCount: 1, }, { name: "chunking test", @@ -1917,7 +1906,6 @@ func TestCheckBulkPermissions(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "chunking test with errors", @@ -1947,7 +1935,6 @@ func TestCheckBulkPermissions(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "same resource and permission with same subject, repeated", @@ -1965,7 +1952,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, }, - expectedDispatchCount: 17, }, } @@ -2028,10 +2014,6 @@ func TestCheckBulkPermissions(t *testing.T) { actual, err := client.CheckBulkPermissions(context.Background(), &req, grpc.Trailer(&trailer)) require.NoError(t, err) - dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(trailer, responsemeta.DispatchedOperationsCount) - require.NoError(t, err) - require.Equal(t, tt.expectedDispatchCount, dispatchCount) - if withTracing { for index, pair := range actual.Pairs { if pair.GetItem() != nil { diff --git a/internal/services/v1/preconditions_test.go b/internal/services/v1/preconditions_test.go index 88ab71b7a0..3ac10366ce 100644 --- a/internal/services/v1/preconditions_test.go +++ b/internal/services/v1/preconditions_test.go @@ -7,6 +7,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -32,7 +33,7 @@ var prefixNoMatch = &v1.RelationshipFilter{ func TestPreconditions(t *testing.T) { require := require.New(t) - uninitialized, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + uninitialized, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithData(uninitialized, require) diff --git a/internal/testfixtures/datastore.go b/internal/testfixtures/datastore.go index 53d0b2e8d4..cc448ed5e3 100644 --- a/internal/testfixtures/datastore.go +++ b/internal/testfixtures/datastore.go @@ -4,7 +4,6 @@ import ( "context" "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/namespace" @@ -37,6 +36,7 @@ var DocumentNS = ns.Namespace( ns.MustRelation("owner", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("editor", nil, @@ -45,6 +45,7 @@ var DocumentNS = ns.Namespace( ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("viewer_and_editor", nil, @@ -85,6 +86,7 @@ var FolderNS = ns.Namespace( ns.MustRelation("owner", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("editor", nil, @@ -94,6 +96,7 @@ var FolderNS = ns.Namespace( nil, ns.AllowedRelation("user", "..."), ns.AllowedRelation("folder", "viewer"), + ns.AllowedRelationWithCaveat("folder", "viewer", ns.AllowedCaveat("test")), ), ns.MustRelation("parent", nil, ns.AllowedRelation("folder", "...")), ns.MustRelation("edit", @@ -137,6 +140,10 @@ var StandardRelationships = []string{ "document:ownerplan#viewer@user:owner#...", } +var StandardCaveatedRelationships = []string{ + "document:caveatedplan#caveated_viewer@user:caveatedguy#...[test:{\"expectedSecret\":\"1234\"}]", +} + // EmptyDatastore returns an empty datastore for testing. func EmptyDatastore(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) { rev, err := ds.HeadRevision(context.Background()) @@ -181,16 +188,17 @@ func StandardDatastoreWithCaveatedData(ds datastore.Datastore, require *require. }) require.NoError(err) - rels := make([]tuple.Relationship, 0, len(StandardRelationships)) + rels := make([]tuple.Relationship, 0, len(StandardRelationships)+len(StandardCaveatedRelationships)) for _, tupleStr := range StandardRelationships { rel, err := tuple.Parse(tupleStr) require.NoError(err) require.NotNil(rel) - - rel.OptionalCaveat = &core.ContextualizedCaveat{ - CaveatName: "test", - Context: mustProtoStruct(map[string]any{"expectedSecret": "1234"}), - } + rels = append(rels, rel) + } + for _, tupleStr := range StandardCaveatedRelationships { + rel, err := tuple.Parse(tupleStr) + require.NoError(err) + require.NotNil(rel) rels = append(rels, rel) } @@ -356,11 +364,3 @@ func (tc RelationshipChecker) NoRelationshipExists(ctx context.Context, rel tupl iter := tc.ExactRelationshipIterator(ctx, rel, rev) tc.VerifyIteratorResults(iter) } - -func mustProtoStruct(in map[string]any) *structpb.Struct { - out, err := structpb.NewStruct(in) - if err != nil { - panic(err) - } - return out -} diff --git a/internal/testserver/datastore/config/config.go b/internal/testserver/datastore/config/config.go index 960a22cadf..f745a73cdd 100644 --- a/internal/testserver/datastore/config/config.go +++ b/internal/testserver/datastore/config/config.go @@ -24,6 +24,7 @@ func DatastoreConfigInitFunc(t testing.TB, options ...dsconfig.ConfigOption) tes append(options, dsconfig.WithEngine(engine), dsconfig.WithEnableDatastoreMetrics(false), + dsconfig.WithEnableExperimentalRelationshipExpiration(true), dsconfig.WithURI(uri), )...) require.NoError(t, err) diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 8b7da6726b..08ed140fc9 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -166,6 +166,10 @@ type Config struct { // Migrations MigrationPhase string `debugmap:"visible"` AllowedMigrations []string `debugmap:"visible"` + + // Expermimental + ExperimentalColumnOptimization bool `debugmap:"visible"` + EnableExperimentalRelationshipExpiration bool `debugmap:"visible"` } //go:generate go run github.com/ecordell/optgen -sensitive-field-name-matches uri,secure -output zz_generated.relintegritykey.options.go . RelIntegrityKey @@ -271,53 +275,57 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt return fmt.Errorf("failed to mark flag as hidden: %w", err) } + flagSet.BoolVar(&opts.ExperimentalColumnOptimization, flagName("datastore-experimental-column-optimization"), false, "enable experimental column optimization") + return nil } func DefaultDatastoreConfig() *Config { return &Config{ - Engine: MemoryEngine, - GCWindow: 24 * time.Hour, - LegacyFuzzing: -1, - RevisionQuantization: 5 * time.Second, - MaxRevisionStalenessPercent: .1, // 10% - ReadConnPool: *DefaultReadConnPool(), - WriteConnPool: *DefaultWriteConnPool(), - ReadReplicaConnPool: *DefaultReadConnPool(), - ReadReplicaURIs: []string{}, - ReadOnly: false, - MaxRetries: 10, - OverlapKey: "key", - OverlapStrategy: "static", - ConnectRate: 100 * time.Millisecond, - EnableConnectionBalancing: true, - GCInterval: 3 * time.Minute, - GCMaxOperationTime: 1 * time.Minute, - WatchBufferLength: 1024, - WatchBufferWriteTimeout: 1 * time.Second, - WatchConnectTimeout: 1 * time.Second, - EnableDatastoreMetrics: true, - DisableStats: false, - BootstrapFiles: []string{}, - BootstrapTimeout: 10 * time.Second, - BootstrapOverwrite: false, - RequestHedgingEnabled: false, - RequestHedgingInitialSlowValue: 10000000, - RequestHedgingMaxRequests: 1_000_000, - RequestHedgingQuantile: 0.95, - SpannerCredentialsFile: "", - SpannerEmulatorHost: "", - TablePrefix: "", - MigrationPhase: "", - FollowerReadDelay: 4_800 * time.Millisecond, - SpannerMinSessions: 100, - SpannerMaxSessions: 400, - FilterMaximumIDCount: 100, - RelationshipIntegrityEnabled: false, - RelationshipIntegrityCurrentKey: RelIntegrityKey{}, - RelationshipIntegrityExpiredKeys: []string{}, - AllowedMigrations: []string{}, - IncludeQueryParametersInTraces: false, + Engine: MemoryEngine, + GCWindow: 24 * time.Hour, + LegacyFuzzing: -1, + RevisionQuantization: 5 * time.Second, + MaxRevisionStalenessPercent: .1, // 10% + ReadConnPool: *DefaultReadConnPool(), + WriteConnPool: *DefaultWriteConnPool(), + ReadReplicaConnPool: *DefaultReadConnPool(), + ReadReplicaURIs: []string{}, + ReadOnly: false, + MaxRetries: 10, + OverlapKey: "key", + OverlapStrategy: "static", + ConnectRate: 100 * time.Millisecond, + EnableConnectionBalancing: true, + GCInterval: 3 * time.Minute, + GCMaxOperationTime: 1 * time.Minute, + WatchBufferLength: 1024, + WatchBufferWriteTimeout: 1 * time.Second, + WatchConnectTimeout: 1 * time.Second, + EnableDatastoreMetrics: true, + DisableStats: false, + BootstrapFiles: []string{}, + BootstrapTimeout: 10 * time.Second, + BootstrapOverwrite: false, + RequestHedgingEnabled: false, + RequestHedgingInitialSlowValue: 10000000, + RequestHedgingMaxRequests: 1_000_000, + RequestHedgingQuantile: 0.95, + SpannerCredentialsFile: "", + SpannerEmulatorHost: "", + TablePrefix: "", + MigrationPhase: "", + FollowerReadDelay: 4_800 * time.Millisecond, + SpannerMinSessions: 100, + SpannerMaxSessions: 400, + FilterMaximumIDCount: 100, + RelationshipIntegrityEnabled: false, + RelationshipIntegrityCurrentKey: RelIntegrityKey{}, + RelationshipIntegrityExpiredKeys: []string{}, + AllowedMigrations: []string{}, + ExperimentalColumnOptimization: false, + IncludeQueryParametersInTraces: false, + EnableExperimentalRelationshipExpiration: false, } } @@ -512,7 +520,9 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.FilterMaximumIDCount(opts.FilterMaximumIDCount), crdb.WithIntegrity(opts.RelationshipIntegrityEnabled), crdb.AllowedMigrations(opts.AllowedMigrations), + crdb.WithColumnOptimization(opts.ExperimentalColumnOptimization), crdb.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), + crdb.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), ) } @@ -553,7 +563,9 @@ func commonPostgresDatastoreOptions(opts Config) ([]postgres.Option, error) { postgres.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), postgres.MaxRetries(maxRetries), postgres.FilterMaximumIDCount(opts.FilterMaximumIDCount), + postgres.WithColumnOptimization(opts.ExperimentalColumnOptimization), postgres.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), + postgres.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), }, nil } @@ -636,6 +648,8 @@ func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, spanner.MigrationPhase(opts.MigrationPhase), spanner.AllowedMigrations(opts.AllowedMigrations), spanner.FilterMaximumIDCount(opts.FilterMaximumIDCount), + spanner.WithColumnOptimization(opts.ExperimentalColumnOptimization), + spanner.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), ) } @@ -680,6 +694,8 @@ func commonMySQLDatastoreOptions(opts Config) ([]mysql.Option, error) { mysql.RevisionQuantization(opts.RevisionQuantization), mysql.FilterMaximumIDCount(opts.FilterMaximumIDCount), mysql.AllowedMigrations(opts.AllowedMigrations), + mysql.WithColumnOptimization(opts.ExperimentalColumnOptimization), + mysql.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), }, nil } diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go index 4bf6a58d30..4e0c537307 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -78,6 +78,8 @@ func (c *Config) ToOption() ConfigOption { to.WatchConnectTimeout = c.WatchConnectTimeout to.MigrationPhase = c.MigrationPhase to.AllowedMigrations = c.AllowedMigrations + to.ExperimentalColumnOptimization = c.ExperimentalColumnOptimization + to.EnableExperimentalRelationshipExpiration = c.EnableExperimentalRelationshipExpiration } } @@ -130,6 +132,8 @@ func (c Config) DebugMap() map[string]any { debugMap["WatchConnectTimeout"] = helpers.DebugValue(c.WatchConnectTimeout, false) debugMap["MigrationPhase"] = helpers.DebugValue(c.MigrationPhase, false) debugMap["AllowedMigrations"] = helpers.DebugValue(c.AllowedMigrations, false) + debugMap["ExperimentalColumnOptimization"] = helpers.DebugValue(c.ExperimentalColumnOptimization, false) + debugMap["EnableExperimentalRelationshipExpiration"] = helpers.DebugValue(c.EnableExperimentalRelationshipExpiration, false) return debugMap } @@ -519,3 +523,17 @@ func SetAllowedMigrations(allowedMigrations []string) ConfigOption { c.AllowedMigrations = allowedMigrations } } + +// WithExperimentalColumnOptimization returns an option that can set ExperimentalColumnOptimization on a Config +func WithExperimentalColumnOptimization(experimentalColumnOptimization bool) ConfigOption { + return func(c *Config) { + c.ExperimentalColumnOptimization = experimentalColumnOptimization + } +} + +// WithEnableExperimentalRelationshipExpiration returns an option that can set EnableExperimentalRelationshipExpiration on a Config +func WithEnableExperimentalRelationshipExpiration(enableExperimentalRelationshipExpiration bool) ConfigOption { + return func(c *Config) { + c.EnableExperimentalRelationshipExpiration = enableExperimentalRelationshipExpiration + } +} diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 5f2ca4a37b..3ded6efb7a 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -226,7 +226,9 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { ds, err = datastorecfg.NewDatastore(context.Background(), c.DatastoreConfig.ToOption(), // Datastore's filter maximum ID count is set to the max size, since the number of elements to be dispatched // are at most the number of elements returned from a datastore query - datastorecfg.WithFilterMaximumIDCount(c.DispatchChunkSize)) + datastorecfg.WithFilterMaximumIDCount(c.DispatchChunkSize), + datastorecfg.WithEnableExperimentalRelationshipExpiration(c.EnableExperimentalRelationshipExpiration), + ) if err != nil { return nil, spiceerrors.NewTerminationErrorBuilder(fmt.Errorf("failed to create datastore: %w", err)). Component("datastore"). diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index dadcee86ad..b36e405da5 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" @@ -26,7 +26,7 @@ func TestServerGracefulTermination(t *testing.T) { defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, 10*time.Second) require.NoError(t, err) c := ConfigWithOptions( @@ -164,7 +164,7 @@ func TestServerGracefulTerminationOnError(t *testing.T) { defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...) ctx, cancel := context.WithCancel(context.Background()) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, 10*time.Second) require.NoError(t, err) c := ConfigWithOptions(&Config{ diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index f8fe66f4f4..5bf6bbd40e 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -41,11 +41,16 @@ func ToRelationship(c Cursor) *tuple.Relationship { return (*tuple.Relationship)(c) } +type Assertion func(sql string) + // QueryOptions are the options that can affect the results of a normal forward query. type QueryOptions struct { - Limit *uint64 `debugmap:"visible"` - Sort SortOrder `debugmap:"visible"` - After Cursor `debugmap:"visible"` + Limit *uint64 `debugmap:"visible"` + Sort SortOrder `debugmap:"visible"` + After Cursor `debugmap:"visible"` + SkipCaveats bool `debugmap:"visible"` + SkipExpiration bool `debugmap:"visible"` + SQLAssertion Assertion `debugmap:"visible"` } // ReverseQueryOptions are the options that can affect the results of a reverse query. diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index f761b06b66..348c53639f 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -34,6 +34,9 @@ func (q *QueryOptions) ToOption() QueryOptionsOption { to.Limit = q.Limit to.Sort = q.Sort to.After = q.After + to.SkipCaveats = q.SkipCaveats + to.SkipExpiration = q.SkipExpiration + to.SQLAssertion = q.SQLAssertion } } @@ -43,6 +46,9 @@ func (q QueryOptions) DebugMap() map[string]any { debugMap["Limit"] = helpers.DebugValue(q.Limit, false) debugMap["Sort"] = helpers.DebugValue(q.Sort, false) debugMap["After"] = helpers.DebugValue(q.After, false) + debugMap["SkipCaveats"] = helpers.DebugValue(q.SkipCaveats, false) + debugMap["SkipExpiration"] = helpers.DebugValue(q.SkipExpiration, false) + debugMap["SQLAssertion"] = helpers.DebugValue(q.SQLAssertion, false) return debugMap } @@ -83,6 +89,27 @@ func WithAfter(after Cursor) QueryOptionsOption { } } +// WithSkipCaveats returns an option that can set SkipCaveats on a QueryOptions +func WithSkipCaveats(skipCaveats bool) QueryOptionsOption { + return func(q *QueryOptions) { + q.SkipCaveats = skipCaveats + } +} + +// WithSkipExpiration returns an option that can set SkipExpiration on a QueryOptions +func WithSkipExpiration(skipExpiration bool) QueryOptionsOption { + return func(q *QueryOptions) { + q.SkipExpiration = skipExpiration + } +} + +// WithSQLAssertion returns an option that can set SQLAssertion on a QueryOptions +func WithSQLAssertion(sQLAssertion Assertion) QueryOptionsOption { + return func(q *QueryOptions) { + q.SQLAssertion = sQLAssertion + } +} + type ReverseQueryOptionsOption func(r *ReverseQueryOptions) // NewReverseQueryOptionsWithOptions creates a new ReverseQueryOptions with the passed in options set diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 8617c12248..034e265a43 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -147,6 +147,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories, t.Run("TestOrderedLimit", runner(tester, OrderedLimitTest)) t.Run("TestResume", runner(tester, ResumeTest)) t.Run("TestReverseQueryCursor", runner(tester, ReverseQueryCursorTest)) + t.Run("TestReverseQueryFilteredCursor", runner(tester, ReverseQueryFilteredOverMultipleValuesCursorTest)) t.Run("TestRevisionQuantization", runner(tester, RevisionQuantizationTest)) t.Run("TestRevisionSerialization", runner(tester, RevisionSerializationTest)) diff --git a/pkg/datastore/test/pagination.go b/pkg/datastore/test/pagination.go index fd7b09edbd..bfe8b9399a 100644 --- a/pkg/datastore/test/pagination.go +++ b/pkg/datastore/test/pagination.go @@ -246,6 +246,65 @@ func ResumeTest(t *testing.T, tester DatastoreTester) { } } +func ReverseQueryFilteredOverMultipleValuesCursorTest(t *testing.T, tester DatastoreTester) { + rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(t, err) + + // Create a datastore with the standard schema but no data. + ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require.New(t)) + + // Add test relationships. + rev, err := ds.ReadWriteTx(context.Background(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:alice")), + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:tom")), + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:fred")), + tuple.Create(tuple.MustParse("document:seconddoc#viewer@user:alice")), + tuple.Create(tuple.MustParse("document:seconddoc#viewer@user:*")), + tuple.Create(tuple.MustParse("document:thirddoc#viewer@user:*")), + }) + }) + require.NoError(t, err) + + // Issue a reverse query call with a limit. + for _, sortBy := range []options.SortOrder{options.ByResource, options.BySubject} { + t.Run(fmt.Sprintf("SortBy-%d", sortBy), func(t *testing.T) { + reader := ds.SnapshotReader(rev) + + var limit uint64 = 2 + var cursor options.Cursor + + foundTuples := mapz.NewSet[string]() + + for i := 0; i < 5; i++ { + iter, err := reader.ReverseQueryRelationships(context.Background(), datastore.SubjectsFilter{ + SubjectType: testfixtures.UserNS.Name, + OptionalSubjectIds: []string{"alice", "tom", "fred", "*"}, + }, options.WithResRelation(&options.ResourceRelation{ + Namespace: "document", + Relation: "viewer", + }), options.WithSortForReverse(sortBy), options.WithLimitForReverse(&limit), options.WithAfterForReverse(cursor)) + require.NoError(t, err) + + encounteredTuples := mapz.NewSet[string]() + for rel, err := range iter { + require.NoError(t, err) + require.True(t, encounteredTuples.Add(tuple.MustString(rel))) + cursor = options.ToCursor(rel) + } + + require.LessOrEqual(t, encounteredTuples.Len(), 2) + foundTuples = foundTuples.Union(encounteredTuples) + if encounteredTuples.IsEmpty() { + break + } + } + + require.Equal(t, 6, foundTuples.Len()) + }) + } +} + func ReverseQueryCursorTest(t *testing.T, tester DatastoreTester) { rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) require.NoError(t, err) diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index cac514944a..d80cca3797 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -1113,10 +1113,12 @@ func RecreateRelationshipsAfterDeleteWithFilter(t *testing.T, tester DatastoreTe // QueryRelationshipsWithVariousFiltersTest tests various relationship filters for query relationships. func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTester) { tcs := []struct { - name string - filter datastore.RelationshipsFilter - relationships []string - expected []string + name string + filter datastore.RelationshipsFilter + withoutCaveats bool + withoutExpiration bool + relationships []string + expected []string }{ { name: "resource type", @@ -1475,6 +1477,37 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "folder:someotherfolder#viewer@user:tom", }, }, + { + name: "resource type with caveats", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom[firstcaveat]", + "document:second#viewer@user:tom[secondcaveat]", + "folder:secondfolder#viewer@user:tom", + "folder:someotherfolder#viewer@user:tom", + }, + expected: []string{"document:first#viewer@user:tom[firstcaveat]", "document:second#viewer@user:tom[secondcaveat]"}, + }, + { + name: "resource type with caveats and context", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom[firstcaveat:{\"foo\":\"bar\"}]", + "document:second#viewer@user:tom[secondcaveat]", + "document:third#viewer@user:tom[secondcaveat:{\"bar\":\"baz\"}]", + "folder:secondfolder#viewer@user:tom", + "folder:someotherfolder#viewer@user:tom", + }, + expected: []string{ + "document:first#viewer@user:tom[firstcaveat:{\"foo\":\"bar\"}]", + "document:second#viewer@user:tom[secondcaveat]", + "document:third#viewer@user:tom[secondcaveat:{\"bar\":\"baz\"}]", + }, + }, { name: "relationship expiration", filter: datastore.RelationshipsFilter{ @@ -1520,6 +1553,24 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "document:first#viewer@user:tom", }, }, + { + name: "no caveats and no expiration", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom", + "document:first#viewer@user:fred", + "document:first#viewer@user:sarah", + }, + expected: []string{ + "document:first#viewer@user:tom", + "document:first#viewer@user:fred", + "document:first#viewer@user:sarah", + }, + withoutCaveats: true, + withoutExpiration: true, + }, { name: "multiple subject IDs with subject type", filter: datastore.RelationshipsFilter{ @@ -1544,7 +1595,20 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest }, }, { - name: "multiple subject filters", + name: "relationships with expiration", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#expiring_viewer@user:tom[expiration:2020-01-01T00:00:00Z]", + "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", + }, + expected: []string{ + "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", + }, + }, + { + name: "multiple subject filters with multiple ids", filter: datastore.RelationshipsFilter{ OptionalSubjectsSelectors: []datastore.SubjectsSelector{ { @@ -1574,19 +1638,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "folder:secondfolder#viewer@anotheruser:jerry", }, }, - { - name: "relationships with expiration", - filter: datastore.RelationshipsFilter{ - OptionalResourceType: "document", - }, - relationships: []string{ - "document:first#expiring_viewer@user:tom[expiration:2020-01-01T00:00:00Z]", - "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", - }, - expected: []string{ - "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", - }, - }, } for _, tc := range tcs { @@ -1610,7 +1661,7 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest require.NoError(err) reader := ds.SnapshotReader(headRev) - iter, err := reader.QueryRelationships(ctx, tc.filter) + iter, err := reader.QueryRelationships(ctx, tc.filter, options.WithSkipCaveats(tc.withoutCaveats), options.WithSkipExpiration(tc.withoutExpiration)) require.NoError(err) var results []string @@ -1624,28 +1675,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest } } -// TypedTouchAlreadyExistingTest tests touching a relationship twice, when valid type information is provided. -func TypedTouchAlreadyExistingTest(t *testing.T, tester DatastoreTester) { - require := require.New(t) - - rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) - require.NoError(err) - - ds, _ := testfixtures.StandardDatastoreWithData(rawDS, require) - ctx := context.Background() - - tpl1, err := tuple.Parse("document:foo#viewer@user:tom") - require.NoError(err) - - _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) - require.NoError(err) - ensureRelationships(ctx, require, ds, tpl1) - - _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) - require.NoError(err) - ensureRelationships(ctx, require, ds, tpl1) -} - // RelationshipExpirationTest tests expiration on relationships. func RelationshipExpirationTest(t *testing.T, tester DatastoreTester) { require := require.New(t) @@ -1701,6 +1730,28 @@ func RelationshipExpirationTest(t *testing.T, tester DatastoreTester) { ensureReverseRelationships(ctx, require, ds, rel4) } +// TypedTouchAlreadyExistingTest tests touching a relationship twice, when valid type information is provided. +func TypedTouchAlreadyExistingTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ds, _ := testfixtures.StandardDatastoreWithData(rawDS, require) + ctx := context.Background() + + tpl1, err := tuple.Parse("document:foo#viewer@user:tom") + require.NoError(err) + + _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) + require.NoError(err) + ensureRelationships(ctx, require, ds, tpl1) + + _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) + require.NoError(err) + ensureRelationships(ctx, require, ds, tpl1) +} + // TypedTouchAlreadyExistingWithCaveatTest tests touching a relationship twice, when valid type information is provided. func TypedTouchAlreadyExistingWithCaveatTest(t *testing.T, tester DatastoreTester) { require := require.New(t) diff --git a/pkg/spiceerrors/assert_off.go b/pkg/spiceerrors/assert_off.go index fa0ac4731e..20ac7ef717 100644 --- a/pkg/spiceerrors/assert_off.go +++ b/pkg/spiceerrors/assert_off.go @@ -3,6 +3,8 @@ package spiceerrors +const DebugAssertionsEnabled = false + // DebugAssert is a no-op in non-CI builds func DebugAssert(condition func() bool, format string, args ...any) { // Do nothing on purpose diff --git a/pkg/spiceerrors/assert_on.go b/pkg/spiceerrors/assert_on.go index b71f8de614..a21414791e 100644 --- a/pkg/spiceerrors/assert_on.go +++ b/pkg/spiceerrors/assert_on.go @@ -8,6 +8,8 @@ import ( "runtime" ) +const DebugAssertionsEnabled = true + // DebugAssert panics if the condition is false in CI builds. func DebugAssert(condition func() bool, format string, args ...any) { if !condition() { diff --git a/pkg/typesystem/reachabilitygraph_test.go b/pkg/typesystem/reachabilitygraph_test.go index 8ca9386d09..02dfdf80ca 100644 --- a/pkg/typesystem/reachabilitygraph_test.go +++ b/pkg/typesystem/reachabilitygraph_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/datastore" @@ -206,7 +207,7 @@ func TestRelationsEncounteredForSubject(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) @@ -575,7 +576,7 @@ func TestRelationsEncounteredForResource(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) @@ -1186,7 +1187,7 @@ func TestReachabilityGraph(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/pkg/typesystem/typesystem_test.go b/pkg/typesystem/typesystem_test.go index 3f576b38da..b865d24386 100644 --- a/pkg/typesystem/typesystem_test.go +++ b/pkg/typesystem/typesystem_test.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/caveats" @@ -416,7 +417,7 @@ func TestTypeSystem(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() @@ -937,7 +938,7 @@ func TestTypeSystemAccessors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/pkg/validationfile/loader_test.go b/pkg/validationfile/loader_test.go index 09741be873..fc1a9e67cc 100644 --- a/pkg/validationfile/loader_test.go +++ b/pkg/validationfile/loader_test.go @@ -5,7 +5,7 @@ import ( "sort" "testing" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/proxy/proxy_test" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" @@ -127,7 +127,7 @@ func TestPopulateFromFiles(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, 0) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, 0) require.NoError(err) parsed, _, err := PopulateFromFiles(context.Background(), ds, tt.filePaths) @@ -153,7 +153,7 @@ func TestPopulateFromFiles(t *testing.T) { func TestPopulationChunking(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, 0) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, 0) require.NoError(err) cs := txCountingDatastore{delegate: ds}