Skip to content

Commit

Permalink
Merge pull request #2147 from josephschorr/bulk-import-caveat-fix
Browse files Browse the repository at this point in the history
Ensure caveats are read in bulk import
  • Loading branch information
josephschorr authored Nov 27, 2024
2 parents 245c85a + 50f8562 commit bc1136f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 95 deletions.
8 changes: 8 additions & 0 deletions internal/services/v1/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) {
return nil, nil
}

if a.currentBatch[a.numSent].OptionalCaveat != nil {
a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName
a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context
a.current.OptionalCaveat = &a.caveat
} else {
a.current.OptionalCaveat = nil
}

if a.caveat.CaveatName != "" {
a.current.OptionalCaveat = &a.caveat
} else {
Expand Down
112 changes: 64 additions & 48 deletions internal/services/v1/experimental_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,63 +53,79 @@ func TestBulkImportRelationships(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require := require.New(t)
for _, withCaveats := range []bool{true, false} {
withCaveats := withCaveats
t.Run(fmt.Sprintf("withCaveats=%t", withCaveats), func(t *testing.T) {
require := require.New(t)

conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema)
client := v1.NewExperimentalServiceClient(conn)
t.Cleanup(cleanup)
conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema)
client := v1.NewExperimentalServiceClient(conn)
t.Cleanup(cleanup)

ctx := context.Background()
ctx := context.Background()

writer, err := client.BulkImportRelationships(ctx)
require.NoError(err)

var expectedTotal uint64
for batchNum := 0; batchNum < tc.numBatches; batchNum++ {
batchSize := tc.batchSize()
batch := make([]*v1.Relationship, 0, batchSize)

for i := uint64(0); i < batchSize; i++ {
batch = append(batch, rel(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
))
}
writer, err := client.BulkImportRelationships(ctx)
require.NoError(err)

err := writer.Send(&v1.BulkImportRelationshipsRequest{
Relationships: batch,
})
require.NoError(err)
var expectedTotal uint64
for batchNum := 0; batchNum < tc.numBatches; batchNum++ {
batchSize := tc.batchSize()
batch := make([]*v1.Relationship, 0, batchSize)

for i := uint64(0); i < batchSize; i++ {
if withCaveats {
batch = append(batch, relWithCaveat(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"caveated_viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
"test",
))
} else {
batch = append(batch, rel(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
))
}
}

err := writer.Send(&v1.BulkImportRelationshipsRequest{
Relationships: batch,
})
require.NoError(err)

expectedTotal += batchSize
}
expectedTotal += batchSize
}

resp, err := writer.CloseAndRecv()
require.NoError(err)
require.Equal(expectedTotal, resp.NumLoaded)
resp, err := writer.CloseAndRecv()
require.NoError(err)
require.Equal(expectedTotal, resp.NumLoaded)

readerClient := v1.NewPermissionsServiceClient(conn)
stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{
RelationshipFilter: &v1.RelationshipFilter{
ResourceType: tf.DocumentNS.Name,
},
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true},
},
})
require.NoError(err)
readerClient := v1.NewPermissionsServiceClient(conn)
stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{
RelationshipFilter: &v1.RelationshipFilter{
ResourceType: tf.DocumentNS.Name,
},
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true},
},
})
require.NoError(err)

var readBack uint64
for _, err = stream.Recv(); err == nil; _, err = stream.Recv() {
readBack++
var readBack uint64
for _, err = stream.Recv(); err == nil; _, err = stream.Recv() {
readBack++
}
require.ErrorIs(err, io.EOF)
require.Equal(expectedTotal, readBack)
})
}
require.ErrorIs(err, io.EOF)
require.Equal(expectedTotal, readBack)
})
}
}
Expand Down
4 changes: 3 additions & 1 deletion internal/services/v1/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,9 @@ func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) {
return nil, nil
}

if a.caveat.CaveatName != "" {
if a.currentBatch[a.numSent].OptionalCaveat != nil {
a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName
a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context
a.current.OptionalCaveat = &a.caveat
} else {
a.current.OptionalCaveat = nil
Expand Down
120 changes: 74 additions & 46 deletions internal/services/v1/permissions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2066,62 +2066,90 @@ func TestImportBulkRelationships(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
for _, withCaveats := range []bool{true, false} {
withCaveats := withCaveats
t.Run(fmt.Sprintf("withCaveats=%t", withCaveats), func(t *testing.T) {
require := require.New(t)

conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema)
client := v1.NewPermissionsServiceClient(conn)
t.Cleanup(cleanup)
conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema)
client := v1.NewPermissionsServiceClient(conn)
t.Cleanup(cleanup)

ctx := context.Background()
ctx := context.Background()

writer, err := client.ImportBulkRelationships(ctx)
require.NoError(err)
writer, err := client.ImportBulkRelationships(ctx)
require.NoError(err)

var expectedTotal uint64
for batchNum := 0; batchNum < tc.numBatches; batchNum++ {
batchSize := tc.batchSize()
batch := make([]*v1.Relationship, 0, batchSize)

for i := uint64(0); i < batchSize; i++ {
batch = append(batch, rel(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
))
}
var expectedTotal uint64
for batchNum := 0; batchNum < tc.numBatches; batchNum++ {
batchSize := tc.batchSize()
batch := make([]*v1.Relationship, 0, batchSize)

for i := uint64(0); i < batchSize; i++ {
if withCaveats {
batch = append(batch, relWithCaveat(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"caveated_viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
"test",
))
} else {
batch = append(batch, rel(
tf.DocumentNS.Name,
strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10),
"viewer",
tf.UserNS.Name,
strconv.FormatUint(i, 10),
"",
))
}
}

err := writer.Send(&v1.ImportBulkRelationshipsRequest{
Relationships: batch,
})
require.NoError(err)
err := writer.Send(&v1.ImportBulkRelationshipsRequest{
Relationships: batch,
})
require.NoError(err)

expectedTotal += batchSize
}
expectedTotal += batchSize
}

resp, err := writer.CloseAndRecv()
require.NoError(err)
require.Equal(expectedTotal, resp.NumLoaded)
resp, err := writer.CloseAndRecv()
require.NoError(err)
require.Equal(expectedTotal, resp.NumLoaded)

readerClient := v1.NewPermissionsServiceClient(conn)
stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{
RelationshipFilter: &v1.RelationshipFilter{
ResourceType: tf.DocumentNS.Name,
},
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true},
},
})
require.NoError(err)
readerClient := v1.NewPermissionsServiceClient(conn)
stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{
RelationshipFilter: &v1.RelationshipFilter{
ResourceType: tf.DocumentNS.Name,
},
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true},
},
})
require.NoError(err)

var readBack uint64
for _, err = stream.Recv(); err == nil; _, err = stream.Recv() {
readBack++
var readBack uint64
var res *v1.ReadRelationshipsResponse
for _, err = stream.Recv(); err == nil; res, err = stream.Recv() {
readBack++
if res == nil {
continue
}

if withCaveats {
require.NotNil(res.Relationship.OptionalCaveat)
require.Equal("test", res.Relationship.OptionalCaveat.CaveatName)
} else {
require.Nil(res.Relationship.OptionalCaveat)
}
}
require.ErrorIs(err, io.EOF)
require.Equal(expectedTotal, readBack)
})
}
require.ErrorIs(err, io.EOF)
require.Equal(expectedTotal, readBack)
})
}
}
Expand Down

0 comments on commit bc1136f

Please sign in to comment.