Skip to content

Commit

Permalink
Implement circular import detection + import dedup
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Nov 7, 2024
1 parent ef6e183 commit 56a5c24
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 30 deletions.
38 changes: 29 additions & 9 deletions pkg/composableschemadsl/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,31 @@ type Option func(*config)

type ObjectPrefixOption func(*config)

type compilationContext struct {
// The set of definition names that we've seen as we compile.
// If these collide we throw an error.
existingNames *mapz.Set[string]
// The global set of files we've visited in the import process.
// If these collide we short circuit, preventing duplicate imports.
globallyVisitedFiles *mapz.Set[string]
// The set of files that we've visited on a particular leg of the recursion.
// This allows for detection of circular imports.
// NOTE: This depends on an assumption that a depth-first search will always
// find a cycle, even if we're otherwise marking globally visited nodes.
locallyVisitedFiles *mapz.Set[string]
}

// Compile compilers the input schema into a set of namespace definition protos.
func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
names := mapz.NewSet[string]()
return compileImpl(schema, names, prefix, opts...)
cctx := compilationContext{
existingNames: mapz.NewSet[string](),
globallyVisitedFiles: mapz.NewSet[string](),
locallyVisitedFiles: mapz.NewSet[string](),
}
return compileImpl(schema, cctx, prefix, opts...)
}

func compileImpl(schema InputSchema, existingNames *mapz.Set[string], prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
cfg := &config{}
prefix(cfg) // required option

Expand All @@ -109,12 +127,14 @@ func compileImpl(schema InputSchema, existingNames *mapz.Set[string], prefix Obj
}

compiled, err := translate(translationContext{
objectTypePrefix: cfg.objectTypePrefix,
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
sourceFolder: cfg.sourceFolder,
existingNames: existingNames,
objectTypePrefix: cfg.objectTypePrefix,
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
sourceFolder: cfg.sourceFolder,
existingNames: cctx.existingNames,
locallyVisitedFiles: cctx.locallyVisitedFiles,
globallyVisitedFiles: cctx.globallyVisitedFiles,
}, root)
if err != nil {
var errorWithNode errorWithNode
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
definition user {}

definition persona {}

definition resource {
relation viewer: user
permission view = viewer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .subjects import user

definition resource {
relation viewer: user
permission view = viewer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .user import user

definition persona {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .subjects import persona

definition user {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
definition user {}

definition persona {}

definition resource {
relation viewer: user
permission view = viewer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .subjects import user
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .subjects import persona
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .left import user
from .right import persona

definition resource {
relation viewer: user
permission view = viewer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
definition user {}
definition persona {}
57 changes: 47 additions & 10 deletions pkg/composableschemadsl/compiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,75 @@ import (
)

type importContext struct {
pathSegments []string
sourceFolder string
names *mapz.Set[string]
pathSegments []string
sourceFolder string
names *mapz.Set[string]
locallyVisitedFiles *mapz.Set[string]
globallyVisitedFiles *mapz.Set[string]
}

const SchemaFileSuffix = ".zed"

type ErrCircularImport struct {
error
filePath string
}

func importFile(importContext importContext) (*CompiledSchema, error) {
relativeFilepath := constructFilePath(importContext.pathSegments)
filePath := path.Join(importContext.sourceFolder, relativeFilepath)

newSourceFolder := filepath.Dir(filePath)

var schemaBytes []byte
currentLocallyVisitedFiles := importContext.locallyVisitedFiles.Copy()

if ok := currentLocallyVisitedFiles.Add(filePath); !ok {
// If we've already visited the file on this particular branch walk, it's
// a circular import issue.
return nil, &ErrCircularImport{
error: fmt.Errorf("circular import detected: %s has been visited on this branch", filePath),
filePath: filePath,
}
}

if ok := importContext.globallyVisitedFiles.Add(filePath); !ok {
// If the file has already been visited, we short-circuit the import process
// by not reading the schema file in and compiling a schema with an empty string.
// This prevents duplicate definitions from ending up in the output, as well
// as preventing circular imports.
log.Debug().Str("filepath", filePath).Msg("file %s has already been visited in another part of the walk")
return compileImpl(InputSchema{
Source: input.Source(filePath),
SchemaString: "",
},
compilationContext{
existingNames: importContext.names,
locallyVisitedFiles: currentLocallyVisitedFiles,
globallyVisitedFiles: importContext.globallyVisitedFiles,
},
AllowUnprefixedObjectType(),
SourceFolder(newSourceFolder),
)
}

schemaBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read schema file: %w", err)
}
log.Trace().Str("schema", string(schemaBytes)).Str("file", filePath).Msg("read schema from file")

compiled, err := compileImpl(InputSchema{
return compileImpl(InputSchema{
Source: input.Source(filePath),
SchemaString: string(schemaBytes),
},
importContext.names,
compilationContext{
existingNames: importContext.names,
locallyVisitedFiles: currentLocallyVisitedFiles,
globallyVisitedFiles: importContext.globallyVisitedFiles,
},
AllowUnprefixedObjectType(),
SourceFolder(newSourceFolder),
)
if err != nil {
return nil, err
}
return compiled, nil
}

func constructFilePath(segments []string) string {
Expand Down
21 changes: 21 additions & 0 deletions pkg/composableschemadsl/compiler/importer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func TestImporter(t *testing.T) {
{"nested local import", "nested-local"},
{"nested local import with transitive hop", "nested-local-with-hop"},
{"nested local two layers deep import", "nested-two-layer-local"},
{"diamond-shaped imports are fine", "diamond-shaped"},
}

for _, test := range importerTests {
Expand Down Expand Up @@ -89,3 +90,23 @@ func TestImporter(t *testing.T) {
})
}
}

func TestImportCycleCausesError(t *testing.T) {
t.Parallel()

workingDir, err := os.Getwd()
require.NoError(t, err)
test := importerTest{"", "circular-import"}

sourceFolder := path.Join(workingDir, test.relativePath())

inputSchema := test.input()

_, err = compiler.Compile(compiler.InputSchema{
Source: input.Source("schema"),
SchemaString: inputSchema,
}, compiler.AllowUnprefixedObjectType(),
compiler.SourceFolder(sourceFolder))

require.ErrorContains(t, err, "circular import")
}
36 changes: 25 additions & 11 deletions pkg/composableschemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package compiler

import (
"bufio"
"errors"
"fmt"
"strings"

Expand All @@ -19,12 +20,14 @@ import (
)

type translationContext struct {
objectTypePrefix *string
mapper input.PositionMapper
schemaString string
skipValidate bool
existingNames *mapz.Set[string]
sourceFolder string
objectTypePrefix *string
mapper input.PositionMapper
schemaString string
skipValidate bool
existingNames *mapz.Set[string]
locallyVisitedFiles *mapz.Set[string]
globallyVisitedFiles *mapz.Set[string]
sourceFolder string
}

func (tctx translationContext) prefixedPath(definitionName string) (string, error) {
Expand Down Expand Up @@ -696,7 +699,6 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All
func translateImport(tctx translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) {
// NOTE: this function currently just grabs everything that's in the target file.
// TODO: only grab the requested definitions
// TODO: import cycle tracking
pathNodes := importNode.List(dslshape.NodeImportPredicatePathSegment)
pathSegments := make([]string, 0, len(pathNodes))

Expand All @@ -709,9 +711,21 @@ func translateImport(tctx translationContext, importNode *dslNode, names *mapz.S
pathSegments = append(pathSegments, segment)
}

return importFile(importContext{
names: names,
pathSegments: pathSegments,
sourceFolder: tctx.sourceFolder,
compiledSchema, err := importFile(importContext{
names: names,
pathSegments: pathSegments,
sourceFolder: tctx.sourceFolder,
globallyVisitedFiles: tctx.globallyVisitedFiles,
locallyVisitedFiles: tctx.locallyVisitedFiles,
})
if err != nil {
var circularImportError *ErrCircularImport
if errors.As(err, &circularImportError) {
// NOTE: The "%s" is an empty format string to keep with the form of ErrorWithSourcef
return nil, importNode.ErrorWithSourcef(circularImportError.filePath, "%s", circularImportError.error.Error())
}
return nil, err
}

return compiledSchema, nil
}

0 comments on commit 56a5c24

Please sign in to comment.