diff --git a/bundle/regal/ast/rule_heads.rego b/bundle/regal/ast/rule_heads.rego new file mode 100644 index 00000000..29c08097 --- /dev/null +++ b/bundle/regal/ast/rule_heads.rego @@ -0,0 +1,22 @@ +package regal.ast + +import rego.v1 + +# METADATA +# description: | +# For a given rule head name, this rule contains a list of locations where +# there is a rule head with that name. +rule_head_locations[name] contains info if { + some rule in input.rules + + name := concat(".", [ + "data", + package_name, + ref_static_to_string(rule.head.ref), + ]) + + info := { + "row": rule.head.location.row, + "col": rule.head.location.col, + } +} diff --git a/bundle/regal/ast/rule_heads_test.rego b/bundle/regal/ast/rule_heads_test.rego new file mode 100644 index 00000000..baf1223c --- /dev/null +++ b/bundle/regal/ast/rule_heads_test.rego @@ -0,0 +1,35 @@ +package regal.ast_test + +import rego.v1 + +import data.regal.ast + +test_rule_head_locations if { + policy := `package policy + +import rego.v1 + +default allow := false + +allow if true + +reasons contains "foo" +reasons contains "bar" + +default my_func(_) := false +my_func(1) := true + +ref_rule[foo] := true if { + some foo in [1,2,3] +} +` + + result := ast.rule_head_locations with input as regal.parse_module("p.rego", policy) + + result == { + "data.policy.allow": {{"col": 9, "row": 5}, {"col": 1, "row": 7}}, + "data.policy.reasons": {{"col": 1, "row": 9}, {"col": 1, "row": 10}}, + "data.policy.my_func": {{"col": 9, "row": 12}, {"col": 1, "row": 13}}, + "data.policy.ref_rule": {{"col": 1, "row": 15}}, + } +} diff --git a/internal/lsp/rego/rego.go b/internal/lsp/rego/rego.go index e87caf4b..92df92be 100644 --- a/internal/lsp/rego/rego.go +++ b/internal/lsp/rego/rego.go @@ -30,6 +30,8 @@ type KeywordUse struct { Location KeywordUseLocation `json:"location"` } +type RuleHeads map[string][]*ast.Location + type KeywordUseLocation struct { Row uint `json:"row"` Col uint `json:"col"` @@ -94,39 +96,54 @@ func AllBuiltinCalls(module *ast.Module) []BuiltInCall { } //nolint:gochecknoglobals -var keywordPreparedQuery *rego.PreparedEvalQuery +var keywordsPreparedQuery *rego.PreparedEvalQuery + +//nolint:gochecknoglobals +var ruleHeadLocationsPreparedQuery *rego.PreparedEvalQuery //nolint:gochecknoglobals -var keywordPreparedQueryInitOnce sync.Once +var preparedQueriesInitOnce sync.Once func initialize() { regalRules := rio.MustLoadRegalBundleFS(rbundle.Bundle) - regoArgs := []func(*rego.Rego){ - rego.ParsedBundle("regal", ®alRules), - rego.Query("data.regal.ast.keywords"), - rego.Function2(builtins.RegalParseModuleMeta, builtins.RegalParseModule), - rego.Function1(builtins.RegalLastMeta, builtins.RegalLast), + createArgs := func(args ...func(*rego.Rego)) []func(*rego.Rego) { + return append([]func(*rego.Rego){ + rego.ParsedBundle("regal", ®alRules), + rego.Function2(builtins.RegalParseModuleMeta, builtins.RegalParseModule), + rego.Function1(builtins.RegalLastMeta, builtins.RegalLast), + }, args...) + } + + keywordRegoArgs := createArgs(rego.Query("data.regal.ast.keywords")) + + kwpq, err := rego.New(keywordRegoArgs...).PrepareForEval(context.Background()) + if err != nil { + panic(err) } - preparedQuery, err := rego.New(regoArgs...).PrepareForEval(context.Background()) + keywordsPreparedQuery = &kwpq + + ruleHeadLocationsRegoArgs := createArgs(rego.Query("data.regal.ast.rule_head_locations")) + + rhlpq, err := rego.New(ruleHeadLocationsRegoArgs...).PrepareForEval(context.Background()) if err != nil { panic(err) } - keywordPreparedQuery = &preparedQuery + ruleHeadLocationsPreparedQuery = &rhlpq } // AllKeywords returns all keywords in the module. func AllKeywords(ctx context.Context, fileName, contents string, module *ast.Module) (map[string][]KeywordUse, error) { - keywordPreparedQueryInitOnce.Do(initialize) + preparedQueriesInitOnce.Do(initialize) enhancedInput, err := parse.PrepareAST(fileName, contents, module) if err != nil { return nil, fmt.Errorf("failed enhancing input: %w", err) } - rs, err := keywordPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput)) + rs, err := keywordsPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput)) if err != nil { return nil, fmt.Errorf("failed evaluating keywords: %w", err) } @@ -149,6 +166,42 @@ func AllKeywords(ctx context.Context, fileName, contents string, module *ast.Mod return result, nil } +// AllRuleHeadLocations returns mapping of rules names to the head locations. +func AllRuleHeadLocations(ctx context.Context, fileName, contents string, module *ast.Module) (RuleHeads, error) { + preparedQueriesInitOnce.Do(initialize) + + enhancedInput, err := parse.PrepareAST(fileName, contents, module) + if err != nil { + return nil, fmt.Errorf("failed enhancing input: %w", err) + } + + rs, err := ruleHeadLocationsPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput)) + if err != nil { + return nil, fmt.Errorf("failed evaluating keywords: %w", err) + } + + if len(rs) == 0 { + return nil, errors.New("no results returned from evaluation") + } + + if len(rs) != 1 { + return nil, errors.New("expected exactly one result from evaluation") + } + + if len(rs[0].Expressions) != 1 { + return nil, errors.New("expected exactly one expression in result") + } + + var result RuleHeads + + err = rio.JSONRoundTrip(rs[0].Expressions[0].Value, &result) + if err != nil { + return nil, fmt.Errorf("failed unmarshaling keywords: %w", err) + } + + return result, nil +} + // ToInput prepares a module with Regal additions to be used as input for evaluation. func ToInput( fileURI string, diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 9e4e056f..9edd939e 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -32,6 +32,7 @@ import ( "github.com/styrainc/regal/internal/lsp/examples" "github.com/styrainc/regal/internal/lsp/hover" "github.com/styrainc/regal/internal/lsp/opa/oracle" + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/lsp/uri" rparse "github.com/styrainc/regal/internal/parse" @@ -478,6 +479,30 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) { break } + currentModule, ok := l.cache.GetModule(file) + if !ok { + l.logError(fmt.Errorf("failed to get module for file %q", file)) + + break + } + + currentContents, ok := l.cache.GetFileContents(file) + if !ok { + l.logError(fmt.Errorf("failed to get contents for file %q", file)) + + break + } + + allRuleHeadLocations, err := rego.AllRuleHeadLocations(ctx, filepath.Base(file), currentContents, currentModule) + if err != nil { + l.logError(fmt.Errorf("failed to get rule head locations: %w", err)) + + break + } + + // if there are none, then it's a package evaluation + ruleHeadLocations := allRuleHeadLocations[path] + workspacePath := uri.ToPath(l.clientIdentifier, l.workspaceRootURI) input := FindInput(uri.ToPath(l.clientIdentifier, file), workspacePath) @@ -498,10 +523,20 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) { break } + target := "package" + if len(ruleHeadLocations) > 0 { + target = strings.TrimPrefix(path, currentModule.Package.Path.String()+".") + } + if l.clientIdentifier == clients.IdentifierVSCode { responseParams := map[string]any{ "result": result, "line": line, + "target": target, + // only used when the target is 'package' + "package": strings.TrimPrefix(currentModule.Package.Path.String(), "data."), + // only used when the target is a rule + "rule_head_locations": ruleHeadLocations, } responseResult := map[string]any{}