Skip to content

Commit

Permalink
Ensure cursored LRv2 calls are dispatched to LRv2
Browse files Browse the repository at this point in the history
  • Loading branch information
josephschorr committed Aug 26, 2024
1 parent d77601b commit 78e8cb6
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 61 deletions.
2 changes: 1 addition & 1 deletion internal/middleware/consistency/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestAddRevisionToContextWithCursor(t *testing.T) {
ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once()

// cursor is at `optimized`
cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized)
cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, nil)
require.NoError(err)

// revision in context is at `exact`
Expand Down
25 changes: 21 additions & 4 deletions internal/services/v1/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,22 @@ func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRela
}
}

const lrv2CursorFlag = "lrv2"

func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error {
// If the cursor specifies that this is a LookupResources2 request, then that implementation must
// be used.
if req.OptionalCursor != nil {
_, ok, err := cursor.GetCursorFlag(req.OptionalCursor, lrv2CursorFlag)
if err != nil {
return ps.rewriteError(resp.Context(), err)
}

if ok {
return ps.lookupResources2(req, resp)
}
}

if ps.config.UseExperimentalLookupResources2 {
return ps.lookupResources2(req, resp)
}
Expand Down Expand Up @@ -445,7 +460,7 @@ func (ps *permissionServer) lookupResources1(req *v1.LookupResourcesRequest, res
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -476,7 +491,7 @@ func (ps *permissionServer) lookupResources1(req *v1.LookupResourcesRequest, res
alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{}
}

encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, nil)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -573,7 +588,7 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -605,7 +620,9 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res
alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{}
}

encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{
lrv2CursorFlag: "1",
})
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/services/v1/relationships.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest,
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest,
}

dispatchCursor.Sections[0] = tuple.StringWithoutCaveat(tpl)
encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down
29 changes: 23 additions & 6 deletions pkg/cursor/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Decode(encoded *v1.Cursor) (*impl.DecodedCursor, error) {
// consumption, including the provided call context to ensure the API cursor reflects the calling
// API method. The call hash should contain all the parameters of the calling API function,
// as well as its revision and name.
func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision) (*v1.Cursor, error) {
func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) {
if dispatchCursor == nil {
return nil, spiceerrors.MustBugf("got nil dispatch cursor")
}
Expand All @@ -60,34 +60,51 @@ func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterH
DispatchVersion: dispatchCursor.DispatchVersion,
Sections: dispatchCursor.Sections,
CallAndParametersHash: callAndParameterHash,
Flags: flags,
},
},
})
}

// GetCursorFlag retrieves a flag from an encoded API cursor, if any.
func GetCursorFlag(encoded *v1.Cursor, flagName string) (string, bool, error) {
decoded, err := Decode(encoded)
if err != nil {
return "", false, err
}

v1decoded := decoded.GetV1()
if v1decoded == nil {
return "", false, NewInvalidCursorErr(ErrNilCursor)
}

value, ok := v1decoded.Flags[flagName]
return value, ok, nil
}

// DecodeToDispatchCursor decodes an encoded API cursor into an internal dispatching cursor,
// ensuring that the provided call context matches that encoded into the API cursor. The call
// hash should contain all the parameters of the calling API function, as well as its revision
// and name.
func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*dispatch.Cursor, error) {
func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*dispatch.Cursor, map[string]string, error) {
decoded, err := Decode(encoded)
if err != nil {
return nil, err
return nil, nil, err
}

v1decoded := decoded.GetV1()
if v1decoded == nil {
return nil, NewInvalidCursorErr(ErrNilCursor)
return nil, nil, NewInvalidCursorErr(ErrNilCursor)
}

if v1decoded.CallAndParametersHash != callAndParameterHash {
return nil, NewInvalidCursorErr(ErrHashMismatch)
return nil, nil, NewInvalidCursorErr(ErrHashMismatch)
}

return &dispatch.Cursor{
DispatchVersion: v1decoded.DispatchVersion,
Sections: v1decoded.Sections,
}, nil
}, v1decoded.Flags, nil
}

// DecodeToDispatchRevision decodes an encoded API cursor into an internal dispatch revision.
Expand Down
7 changes: 4 additions & 3 deletions pkg/cursor/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ func TestEncodeDecode(t *testing.T) {
require := require.New(t)
encoded, err := EncodeFromDispatchCursor(&dispatch.Cursor{
Sections: tc.sections,
}, tc.hash, tc.revision)
}, tc.hash, tc.revision, map[string]string{"some": "flag"})
require.NoError(err)
require.NotNil(encoded)

decoded, err := DecodeToDispatchCursor(encoded, tc.hash)
decoded, flags, err := DecodeToDispatchCursor(encoded, tc.hash)
require.NoError(err)
require.NotNil(decoded)
require.Equal(map[string]string{"some": "flag"}, flags)

require.Equal(tc.sections, decoded.Sections)

Expand Down Expand Up @@ -123,7 +124,7 @@ func TestDecode(t *testing.T) {
t.Run(testName, func(t *testing.T) {
require := require.New(t)

decoded, err := DecodeToDispatchCursor(&v1.Cursor{
decoded, _, err := DecodeToDispatchCursor(&v1.Cursor{
Token: testCase.token,
}, testCase.expectedHash)

Expand Down
108 changes: 63 additions & 45 deletions pkg/proto/impl/v1/impl.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/proto/impl/v1/impl.pb.validate.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 78e8cb6

Please sign in to comment.