Skip to content

Commit

Permalink
Merge pull request #2096 from josephschorr/rel-struct-sql
Browse files Browse the repository at this point in the history
Relationships selected in SQL-based datastores now elide columns that have static values
  • Loading branch information
vroldanbet authored Jan 10, 2025
2 parents aae00c4 + f71404b commit e658772
Show file tree
Hide file tree
Showing 76 changed files with 3,195 additions and 1,184 deletions.
7 changes: 4 additions & 3 deletions internal/caveats/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/authzed/spicedb/internal/caveats"
"github.com/authzed/spicedb/internal/datastore/dsfortesting"
"github.com/authzed/spicedb/internal/datastore/memdb"
"github.com/authzed/spicedb/internal/testfixtures"
"github.com/authzed/spicedb/pkg/datastore"
Expand Down Expand Up @@ -448,7 +449,7 @@ func TestRunCaveatExpressions(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := require.New(t)

rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC)
rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC)
req.NoError(err)

ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, `
Expand Down Expand Up @@ -507,7 +508,7 @@ func TestRunCaveatExpressions(t *testing.T) {
func TestRunCaveatWithMissingMap(t *testing.T) {
req := require.New(t)

rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC)
rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC)
req.NoError(err)

ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, `
Expand Down Expand Up @@ -536,7 +537,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) {
func TestRunCaveatWithEmptyMap(t *testing.T) {
req := require.New(t)

rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC)
rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC)
req.NoError(err)

ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, `
Expand Down
159 changes: 159 additions & 0 deletions internal/datastore/common/relationships.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package common

import (
"context"
"database/sql"
"fmt"
"time"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/authzed/spicedb/pkg/datastore"
corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
"github.com/authzed/spicedb/pkg/tuple"
)

const errUnableToQueryRels = "unable to query relationships: %w"

// Querier is an interface for querying the database.
type Querier[R Rows] interface {
QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error
}

// Rows is a common interface for database rows reading.
type Rows interface {
Scan(dest ...any) error
Next() bool
Err() error
}

type closeRowsWithError interface {
Rows
Close() error
}

type closeRows interface {
Rows
Close()
}

// QueryRelationships queries relationships for the given query and transaction.
func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R]) (datastore.RelationshipIterator, error) {
span := trace.SpanFromContext(ctx)
sqlString, args, err := builder.SelectSQL()
if err != nil {
return nil, fmt.Errorf(errUnableToQueryRels, err)
}

var resourceObjectType string
var resourceObjectID string
var resourceRelation string
var subjectObjectType string
var subjectObjectID string
var subjectRelation string
var caveatName sql.NullString
var caveatCtx C
var expiration *time.Time

var integrityKeyID string
var integrityHash []byte
var timestamp time.Time

span.AddEvent("Selecting columns")
colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, &timestamp)
if err != nil {
return nil, fmt.Errorf(errUnableToQueryRels, err)
}

span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect))))
return func(yield func(tuple.Relationship, error) bool) {
span.AddEvent("Issuing query to database")
err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
span.AddEvent("Query issued to database")

var r Rows = rows
if crwe, ok := r.(closeRowsWithError); ok {
defer LogOnError(ctx, crwe.Close)
} else if cr, ok := r.(closeRows); ok {
defer cr.Close()
}

relCount := 0
for rows.Next() {
if relCount == 0 {
span.AddEvent("First row returned")
}

if err := rows.Scan(colsToSelect...); err != nil {
return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err))
}

if relCount == 0 {
span.AddEvent("First row scanned")
}

var caveat *corev1.ContextualizedCaveat
if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
if caveatName.Valid {
var err error
caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx)
if err != nil {
return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err))
}
}
}

var integrity *corev1.RelationshipIntegrity
if integrityKeyID != "" {
integrity = &corev1.RelationshipIntegrity{
KeyId: integrityKeyID,
Hash: integrityHash,
HashedAt: timestamppb.New(timestamp),
}
}

if expiration != nil {
// Ensure the expiration is always read in UTC, since some datastores (like CRDB)
// will normalize to local time.
t := expiration.UTC()
expiration = &t
}

relCount++
if !yield(tuple.Relationship{
RelationshipReference: tuple.RelationshipReference{
Resource: tuple.ObjectAndRelation{
ObjectType: resourceObjectType,
ObjectID: resourceObjectID,
Relation: resourceRelation,
},
Subject: tuple.ObjectAndRelation{
ObjectType: subjectObjectType,
ObjectID: subjectObjectID,
Relation: subjectRelation,
},
},
OptionalCaveat: caveat,
OptionalExpiration: expiration,
OptionalIntegrity: integrity,
}, nil) {
return nil
}
}

span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount)))
if err := rows.Err(); err != nil {
return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err))
}

return nil
}, sqlString, args...)
if err != nil {
if !yield(tuple.Relationship{}, err) {
return
}
}
}, nil
}
134 changes: 134 additions & 0 deletions internal/datastore/common/schema.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package common

import (
sq "github.com/Masterminds/squirrel"

"github.com/authzed/spicedb/pkg/spiceerrors"
)

const (
relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation
relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext
relationshipExpirationColumnCount = 1 // ColExpiration
relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp
)

// SchemaInformation holds the schema information from the SQL datastore implementation.
//
//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation
type SchemaInformation struct {
RelationshipTableName string `debugmap:"visible"`

ColNamespace string `debugmap:"visible"`
ColObjectID string `debugmap:"visible"`
ColRelation string `debugmap:"visible"`
ColUsersetNamespace string `debugmap:"visible"`
ColUsersetObjectID string `debugmap:"visible"`
ColUsersetRelation string `debugmap:"visible"`

ColCaveatName string `debugmap:"visible"`
ColCaveatContext string `debugmap:"visible"`

ColExpiration string `debugmap:"visible"`

ColIntegrityKeyID string `debugmap:"visible"`
ColIntegrityHash string `debugmap:"visible"`
ColIntegrityTimestamp string `debugmap:"visible"`

// PaginationFilterType is the type of pagination filter to use for this schema.
PaginationFilterType PaginationFilterType `debugmap:"visible"`

// PlaceholderFormat is the format of placeholders to use for this schema.
PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"`

// NowFunction is the function to use to get the current time in the datastore.
NowFunction string `debugmap:"visible"`

// ColumnOptimization is the optimization to use for columns in the schema, if any.
ColumnOptimization ColumnOptimizationOption `debugmap:"visible"`

// IntegrityEnabled is a flag to indicate if the schema has integrity columns.
IntegrityEnabled bool `debugmap:"visible"`

// ExpirationDisabled is a flag to indicate whether expiration support is disabled.
ExpirationDisabled bool `debugmap:"visible"`
}

func (si SchemaInformation) debugValidate() {
spiceerrors.DebugAssert(func() bool {
si.mustValidate()
return true
}, "SchemaInformation failed to validate")
}

func (si SchemaInformation) mustValidate() {
if si.RelationshipTableName == "" {
panic("RelationshipTableName is required")
}

if si.ColNamespace == "" {
panic("ColNamespace is required")
}

if si.ColObjectID == "" {
panic("ColObjectID is required")
}

if si.ColRelation == "" {
panic("ColRelation is required")
}

if si.ColUsersetNamespace == "" {
panic("ColUsersetNamespace is required")
}

if si.ColUsersetObjectID == "" {
panic("ColUsersetObjectID is required")
}

if si.ColUsersetRelation == "" {
panic("ColUsersetRelation is required")
}

if si.ColCaveatName == "" {
panic("ColCaveatName is required")
}

if si.ColCaveatContext == "" {
panic("ColCaveatContext is required")
}

if si.ColExpiration == "" {
panic("ColExpiration is required")
}

if si.IntegrityEnabled {
if si.ColIntegrityKeyID == "" {
panic("ColIntegrityKeyID is required")
}

if si.ColIntegrityHash == "" {
panic("ColIntegrityHash is required")
}

if si.ColIntegrityTimestamp == "" {
panic("ColIntegrityTimestamp is required")
}
}

if si.NowFunction == "" {
panic("NowFunction is required")
}

if si.ColumnOptimization == ColumnOptimizationOptionUnknown {
panic("ColumnOptimization is required")
}

if si.PaginationFilterType == PaginationFilterTypeUnknown {
panic("PaginationFilterType is required")
}

if si.PlaceholderFormat == nil {
panic("PlaceholderFormat is required")
}
}
File renamed without changes.
Loading

0 comments on commit e658772

Please sign in to comment.