From 5836aea17e55c73bf13622b8b29f8cc0ecd215f6 Mon Sep 17 00:00:00 2001 From: Anders Eknert Date: Mon, 20 Jan 2025 19:37:19 +0100 Subject: [PATCH] Add util.Keys and util.KeysSorted And use them to reduce imperative boilerplate throughout the codebase. Additionally, replace use of sort.Slice with slices.SortFunc which is more efficient since it is generic and as such avoids allocations related to `interface{}` casts. Also a few performance-related minor fixes, but not the main theme of this PR. ``` BenchmarkRegalLintingItself-10 before / after 1832684458 ns/op 3453470360 B/op 66125422 allocs/op 1826601250 ns/op 3449619024 B/op 65999164 allocs/op ```` Signed-off-by: Anders Eknert --- cmd/bench.go | 8 +- cmd/inspect.go | 10 +- internal/pathwatcher/utils.go | 10 +- internal/providers/aws/signing_v4.go | 17 +-- v1/ast/compile.go | 169 ++++++++++++--------------- v1/ast/compile_test.go | 71 +++++------ v1/ast/varset.go | 15 +-- v1/bundle/hash.go | 11 +- v1/cover/cover.go | 17 +-- v1/loader/loader.go | 8 +- v1/server/server_test.go | 57 ++++----- v1/tester/runner.go | 10 +- v1/util/compare.go | 13 +-- v1/util/maps.go | 24 ++++ 14 files changed, 186 insertions(+), 254 deletions(-) diff --git a/cmd/bench.go b/cmd/bench.go index 2ec1e9a125..108377892d 100644 --- a/cmd/bench.go +++ b/cmd/bench.go @@ -13,7 +13,6 @@ import ( "math" "net/http" "os" - "sort" "strconv" "strings" "testing" @@ -608,12 +607,7 @@ func renderBenchmarkResult(params benchmarkCommandParams, br testing.BenchmarkRe }) } - var keys []string - for k := range br.Extra { - keys = append(keys, k) - } - sort.Strings(keys) - for _, k := range keys { + for _, k := range util.KeysSorted(br.Extra) { data = append(data, []string{k, prettyFormatFloat(br.Extra[k])}) } diff --git a/cmd/inspect.go b/cmd/inspect.go index cd5125dde3..96e5a80f89 100644 --- a/cmd/inspect.go +++ b/cmd/inspect.go @@ -81,7 +81,7 @@ Example: bundle.tar.gz $ opa inspect bundle.tar.gz -You can provide exactly one OPA bundle, path to a bundle directory, or direct path to a Rego file to the 'inspect' command +You can provide exactly one OPA bundle, path to a bundle directory, or direct path to a Rego file to the 'inspect' command on the command-line. If you provide a path referring to a directory, the 'inspect' command will load that path as a bundle and summarize its structure and contents. If you provide a path referring to a Rego file, the 'inspect' command will load that file and summarize its structure and contents. @@ -210,13 +210,7 @@ func populateNamespaces(out io.Writer, n map[string][]string) error { t.SetAutoMergeCellsByColumnIndex([]int{0}) var lines [][]string - keys := make([]string, 0, len(n)) - for k := range n { - keys = append(keys, k) - } - sort.Strings(keys) - - for _, k := range keys { + for _, k := range util.KeysSorted(n) { for _, file := range n[k] { lines = append(lines, []string{k, truncateFileName(file)}) } diff --git a/internal/pathwatcher/utils.go b/internal/pathwatcher/utils.go index ee7fb794cf..2a0e5d6ef0 100644 --- a/internal/pathwatcher/utils.go +++ b/internal/pathwatcher/utils.go @@ -9,13 +9,13 @@ import ( "context" "os" "path/filepath" - "sort" "github.com/fsnotify/fsnotify" initload "github.com/open-policy-agent/opa/internal/runtime/init" "github.com/open-policy-agent/opa/v1/ast" "github.com/open-policy-agent/opa/v1/loader" "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/util" ) // CreatePathWatcher creates watchers to monitor for path changes @@ -119,13 +119,7 @@ func getWatchPaths(rootPaths []string) ([]string, error) { } } - u := make([]string, 0, len(unique)) - for k := range unique { - u = append(u, k) - } - sort.Strings(u) - - paths = append(paths, u...) + paths = append(paths, util.KeysSorted(unique)...) } return paths, nil diff --git a/internal/providers/aws/signing_v4.go b/internal/providers/aws/signing_v4.go index 3c152831b8..1e50d01f92 100644 --- a/internal/providers/aws/signing_v4.go +++ b/internal/providers/aws/signing_v4.go @@ -13,13 +13,13 @@ import ( "io" "net/http" "net/url" - "sort" "strings" "time" v4 "github.com/open-policy-agent/opa/internal/providers/aws/v4" "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" ) func stringFromTerm(t *ast.Term) string { @@ -67,19 +67,6 @@ func sha256MAC(message string, key []byte) []byte { return mac.Sum(nil) } -func sortKeys(strMap map[string][]string) []string { - keys := make([]string, len(strMap)) - - i := 0 - for k := range strMap { - keys[i] = k - i++ - } - sort.Strings(keys) - - return keys -} - // SignRequest modifies an http.Request to include an AWS V4 signature based on the provided credentials. func SignRequest(req *http.Request, service string, creds Credentials, theTime time.Time, sigVersion string) error { // General ref. https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html @@ -168,7 +155,7 @@ func SignV4(headers map[string][]string, method string, theURL *url.URL, body [] canonicalReq += theURL.RawQuery + "\n" // RAW Query String // include the values for the signed headers - orderedKeys := sortKeys(headersToSign) + orderedKeys := util.KeysSorted(headersToSign) for _, k := range orderedKeys { canonicalReq += k + ":" + strings.Join(headersToSign[k], ",") + "\n" } diff --git a/v1/ast/compile.go b/v1/ast/compile.go index 76b3c51bda..9b0302474e 100644 --- a/v1/ast/compile.go +++ b/v1/ast/compile.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "io" + "maps" + "slices" "sort" "strconv" "strings" @@ -438,24 +440,21 @@ func (c *Compiler) WithDebug(sink io.Writer) *Compiler { return c } -// WithBuiltins is deprecated. Use WithCapabilities instead. +// WithBuiltins is deprecated. +// Deprecated: Use WithCapabilities instead. func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler { - c.customBuiltins = make(map[string]*Builtin) - for k, v := range builtins { - c.customBuiltins[k] = v - } + c.customBuiltins = maps.Clone(builtins) return c } -// WithUnsafeBuiltins is deprecated. Use WithCapabilities instead. +// WithUnsafeBuiltins is deprecated. +// Deprecated: Use WithCapabilities instead. func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler { - for name := range unsafeBuiltins { - c.unsafeBuiltinsMap[name] = struct{}{} - } + maps.Copy(c.unsafeBuiltinsMap, unsafeBuiltins) return c } -// WithStrict enables strict mode in the compiler. +// WithStrict toggles strict mode in the compiler. func (c *Compiler) WithStrict(strict bool) *Compiler { c.strict = strict return c @@ -560,7 +559,7 @@ func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex { // otherwise, the ref is used to perform a ruleset lookup. func (c *Compiler) GetArity(ref Ref) int { if bi := c.builtins[ref.String()]; bi != nil { - return len(bi.Decl.FuncArgs().Args) + return bi.Decl.Arity() } rules := c.GetRulesExact(ref) if len(rules) == 0 { @@ -668,7 +667,7 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { return rules } -func extractRules(s []util.T) []*Rule { +func extractRules(s []any) []*Rule { rules := make([]*Rule, len(s)) for i := range s { rules[i] = s[i].(*Rule) @@ -811,7 +810,7 @@ func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { } // Utility: add all rule values to the set. -func insertRules(set map[*Rule]struct{}, rules []util.T) { +func insertRules(set map[*Rule]struct{}, rules []any) { for _, rule := range rules { set[rule.(*Rule)] = struct{}{} } @@ -972,6 +971,13 @@ func (c *Compiler) buildComprehensionIndices() { } } +var ( + keywordsTerm = StringTerm("keywords") + pathTerm = StringTerm("path") + annotationsTerm = StringTerm("annotations") + futureKeywordsPrefix = Ref{FutureRootDocument, keywordsTerm} +) + // buildRequiredCapabilities updates the required capabilities on the compiler // to include any keyword and feature dependencies present in the modules. The // built-in function dependencies will have already been added by the type @@ -983,7 +989,7 @@ func (c *Compiler) buildRequiredCapabilities() { // extract required keywords from modules keywords := map[string]struct{}{} - futureKeywordsPrefix := Ref{FutureRootDocument, StringTerm("keywords")} + for _, name := range c.sorted { for _, imp := range c.imports[name] { mod := c.Modules[name] @@ -1026,7 +1032,7 @@ func (c *Compiler) buildRequiredCapabilities() { } } - c.Required.FutureKeywords = stringMapToSortedSlice(keywords) + c.Required.FutureKeywords = util.KeysSorted(keywords) // extract required features from modules @@ -1049,25 +1055,13 @@ func (c *Compiler) buildRequiredCapabilities() { } } - c.Required.Features = stringMapToSortedSlice(features) + c.Required.Features = util.KeysSorted(features) for i, bi := range c.Required.Builtins { c.Required.Builtins[i] = bi.Minimal() } } -func stringMapToSortedSlice(xs map[string]struct{}) []string { - if len(xs) == 0 { - return nil - } - s := make([]string, 0, len(xs)) - for k := range xs { - s = append(s, k) - } - sort.Strings(s) - return s -} - // checkRecursion ensures that there are no recursive definitions, i.e., there are // no cycles in the Graph. func (c *Compiler) checkRecursion() { @@ -1609,6 +1603,10 @@ func (c *Compiler) checkTypes() { } func (c *Compiler) checkUnsafeBuiltins() { + if len(c.unsafeBuiltinsMap) == 0 { + return + } + for _, name := range c.sorted { errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) for _, err := range errs { @@ -1618,6 +1616,17 @@ func (c *Compiler) checkUnsafeBuiltins() { } func (c *Compiler) checkDeprecatedBuiltins() { + checkNeeded := false + for _, b := range c.Required.Builtins { + if _, found := c.deprecatedBuiltinsMap[b.Name]; found { + checkNeeded = true + break + } + } + if !checkNeeded { + return + } + for _, name := range c.sorted { mod := c.Modules[name] if c.strict || mod.regoV1Compatible() { @@ -1776,7 +1785,7 @@ func (c *Compiler) checkImports() { mod := c.Modules[name] for _, imp := range mod.Imports { - if !supportsRegoV1Import && Compare(imp.Path, RegoV1CompatibleRef) == 0 { + if !supportsRegoV1Import && RegoV1CompatibleRef.Equal(imp.Path.Value) { c.err(NewError(CompileErr, imp.Loc(), "rego.v1 import is not supported")) } } @@ -2218,8 +2227,10 @@ func containsPrintCall(x interface{}) bool { return found } +var printRef = Print.Ref() + func isPrintCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(Print.Ref()) + return x.IsCall() && x.Operator().Equal(printRef) } // rewriteRefsInHead will rewrite rules so that the head does not contain any @@ -2454,8 +2465,8 @@ func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations { } // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward - sort.SliceStable(annots, func(i, j int) bool { - return annots[i].Location.Compare(annots[j].Location) > 0 + slices.SortStableFunc(annots, func(a, b *Annotations) int { + return -a.Location.Compare(b.Location) }) return annots[0] @@ -2509,12 +2520,15 @@ func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body return errs } +var regoMetadataChainRef = RegoMetadataChain.Ref() +var regoMetadataRuleRef = RegoMetadataRule.Ref() + func isRegoMetadataChainCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(RegoMetadataChain.Ref()) + return x.IsCall() && x.Operator().Equal(regoMetadataChainRef) } func isRegoMetadataRuleCall(x *Expr) bool { - return x.IsCall() && x.Operator().Equal(RegoMetadataRule.Ref()) + return x.IsCall() && x.Operator().Equal(regoMetadataRuleRef) } func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) { @@ -2524,14 +2538,14 @@ func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) { p := link.Path.toArray(). Slice(1, -1) // Dropping leading 'data' element of path obj := NewObject( - Item(StringTerm("path"), NewTerm(p)), + Item(pathTerm, NewTerm(p)), ) if link.Annotations != nil { annotObj, err := link.Annotations.toObject() if err != nil { return nil, err } - obj.Insert(StringTerm("annotations"), NewTerm(*annotObj)) + obj.Insert(annotationsTerm, NewTerm(*annotObj)) } metaArray = metaArray.Append(NewTerm(obj)) } @@ -2646,9 +2660,7 @@ func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsSta // For rewritten vars use the collection of all variables that // were in the stack at some point in time. - for k, v := range stack.rewritten { - c.RewrittenVars[k] = v - } + maps.Copy(c.RewrittenVars, stack.rewritten) rule.Body = body @@ -2716,9 +2728,7 @@ func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool { stop = true } - for k, v := range stack.rewritten { - xform.RewrittenVars[k] = v - } + maps.Copy(xform.RewrittenVars, stack.rewritten) return stop } @@ -3045,13 +3055,12 @@ func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, err if len(err) != 0 { return nil, err } - qc.rewritten = make(map[Var]Var, len(stack.rewritten)) - for k, v := range stack.rewritten { - // The vars returned during the rewrite will include all seen vars, - // even if they're not declared with an assignment operation. We don't - // want to include these inside the rewritten set though. - qc.rewritten[k] = v - } + + // The vars returned during the rewrite will include all seen vars, + // even if they're not declared with an assignment operation. We don't + // want to include these inside the rewritten set though. + qc.rewritten = maps.Clone(stack.rewritten) + return body, nil } @@ -3279,9 +3288,7 @@ func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarS result = append(result, NewTerm(v)) } - sort.Slice(result, func(i, j int) bool { - return result[i].Value.Compare(result[j].Value) < 0 - }) + slices.SortFunc(result, TermValueCompare) debugRes := make([]*Term, len(result)) for i, r := range result { @@ -3406,12 +3413,7 @@ func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { root := &ModuleTreeNode{ Children: map[Value]*ModuleTreeNode{}, } - names := make([]string, 0, len(mods)) - for name := range mods { - names = append(names, name) - } - sort.Strings(names) - for _, name := range names { + for _, name := range util.KeysSorted(mods) { m := mods[name] node := root for i, x := range m.Package.Path { @@ -3488,7 +3490,7 @@ func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { // rule path. type TreeNode struct { Key Value - Values []util.T + Values []any Children map[Value]*TreeNode Sorted []Value Hide bool @@ -3608,9 +3610,7 @@ func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { } func (n *TreeNode) sort() { - sort.Slice(n.Sorted, func(i, j int) bool { - return n.Sorted[i].Compare(n.Sorted[j]) < 0 - }) + slices.SortFunc(n.Sorted, Value.Compare) } func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { @@ -3621,7 +3621,7 @@ func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { Children: nil, } if rule != nil { - node.Values = []util.T{rule} + node.Values = []any{rule} } for i := len(ref) - 2; i >= 0; i-- { @@ -3648,9 +3648,7 @@ func (n *TreeNode) flattenChildren() []Ref { }) } - sort.Slice(ret.s, func(i, j int) bool { - return ret.s[i].Compare(ret.s[j]) < 0 - }) + slices.SortFunc(ret.s, RefCompare) return ret.s } @@ -3888,8 +3886,8 @@ func (vs unsafeVars) Vars() (result []unsafeVarLoc) { }) } - sort.Slice(result, func(i, j int) bool { - return result[i].Loc.Compare(result[j].Loc) < 0 + slices.SortFunc(result, func(a, b unsafeVarLoc) int { + return a.Loc.Compare(b.Loc) }) return result @@ -5101,23 +5099,13 @@ func (s *localDeclaredVars) Copy() *localDeclaredVars { for i := range s.vars { stack.vars = append(stack.vars, newDeclaredVarSet()) - for k, v := range s.vars[i].vs { - stack.vars[0].vs[k] = v - } - for k, v := range s.vars[i].reverse { - stack.vars[0].reverse[k] = v - } - for k, v := range s.vars[i].count { - stack.vars[0].count[k] = v - } - for k, v := range s.vars[i].occurrence { - stack.vars[0].occurrence[k] = v - } + maps.Copy(stack.vars[0].vs, s.vars[i].vs) + maps.Copy(stack.vars[0].reverse, s.vars[i].reverse) + maps.Copy(stack.vars[0].occurrence, s.vars[i].occurrence) + maps.Copy(stack.vars[0].count, s.vars[i].count) } - for k, v := range s.rewritten { - stack.rewritten[k] = v - } + maps.Copy(stack.rewritten, s.rewritten) return stack } @@ -5889,8 +5877,8 @@ func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) // the latter are not meaningful to the user.) pairs := unsafe.Slice() - sort.Slice(pairs, func(i, j int) bool { - return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0 + slices.SortFunc(pairs, func(a, b unsafePair) int { + return a.Expr.Location.Compare(b.Expr.Location) }) // Report at most one error per generated variable. @@ -5957,12 +5945,7 @@ func newRefSet(x ...Ref) *refSet { // ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. func (rs *refSet) ContainsPrefix(r Ref) bool { - for i := range rs.s { - if r.HasPrefix(rs.s[i]) { - return true - } - } - return false + return slices.ContainsFunc(rs.s, r.HasPrefix) } // AddPrefix inserts r into the set if r is not prefixed by any existing @@ -5987,8 +5970,6 @@ func (rs *refSet) Sorted() []*Term { for i := range rs.s { terms[i] = NewTerm(rs.s[i]) } - sort.Slice(terms, func(i, j int) bool { - return terms[i].Value.Compare(terms[j].Value) < 0 - }) + slices.SortFunc(terms, TermValueCompare) return terms } diff --git a/v1/ast/compile_test.go b/v1/ast/compile_test.go index 793e1d03d5..16fd8596b9 100644 --- a/v1/ast/compile_test.go +++ b/v1/ast/compile_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "maps" "reflect" "slices" "sort" @@ -299,8 +300,8 @@ func TestModuleTree(t *testing.T) { if tree.Children[Var("data")].Children[String("user")].Children[String("system")].Hide { t.Fatalf("Expected user.system node to be visible") } - } + func TestCompilerGetExports(t *testing.T) { tests := []struct { note string @@ -777,7 +778,7 @@ func TestRuleIndices(t *testing.T) { note: "regression test for #6930 (no if)", modules: modules( `package test - + p.q contains "foo" p[q] := r if { @@ -1525,7 +1526,7 @@ func TestCompilerErrorLimit(t *testing.T) { sort.Strings(exp) sort.Strings(result) - if !reflect.DeepEqual(exp, result) { + if !slices.Equal(exp, result) { t.Errorf("Expected errors %v, got %v", exp, result) } } @@ -1958,9 +1959,9 @@ p[r] := 2 if { r := "foo" }`, }) c.WithPathConflictsCheck(func(path []string) (bool, error) { - if reflect.DeepEqual(path, []string{"badrules", "dataoverlap", "p"}) { + if slices.Equal(path, []string{"badrules", "dataoverlap", "p"}) { return true, nil - } else if reflect.DeepEqual(path, []string{"badrules", "existserr", "p"}) { + } else if slices.Equal(path, []string{"badrules", "existserr", "p"}) { return false, fmt.Errorf("unexpected error") } return false, nil @@ -2007,7 +2008,7 @@ p if { true }`, c.WithPathConflictsCheck(func(path []string) (bool, error) { if slices.Contains(path, "dataoverlap") { return true, nil - } else if reflect.DeepEqual(path, []string{"badrules", "existserr", "p"}) { + } else if slices.Equal(path, []string{"badrules", "existserr", "p"}) { return false, fmt.Errorf("unexpected error") } return false, nil @@ -2750,13 +2751,13 @@ func TestCompilerRewriteExprTerms(t *testing.T) { expected: ` package test - p = true { + p = true { plus(1, 2, __local3__) mul(3, 4, __local4__) numbers.range(__local3__, __local4__, __local5__) __local2__ = __local5__ - every __local0__, __local1__ in __local2__ { - __local1__ + every __local0__, __local1__ in __local2__ { + __local1__ } }`, }, @@ -2770,13 +2771,13 @@ func TestCompilerRewriteExprTerms(t *testing.T) { expected: ` package test - p = true { + p = true { div(1, 2, __local3__) abs(-1, __local4__) __local2__ = [__local3__, "foo", __local4__] every __local0__, __local1__ in __local2__ { __local1__ - } + } }`, }, { @@ -2788,13 +2789,13 @@ func TestCompilerRewriteExprTerms(t *testing.T) { expected: ` package test - p = true { + p = true { div(1, 2, __local3__) abs(-1, __local4__) __local2__ = [__local3__, ["foo", __local4__]] - every __local0__, __local1__ in __local2__ { - __local1__ - } + every __local0__, __local1__ in __local2__ { + __local1__ + } }`, }, } @@ -5547,7 +5548,7 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { if result.Compare(exp) != 0 { t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result) } - if !reflect.DeepEqual(c.RewrittenVars, tc.expRewrittenMap) { + if !maps.Equal(c.RewrittenVars, tc.expRewrittenMap) { t.Fatalf("\nExpected Rewritten Vars:\n\n\t%+v\n\nGot:\n\n\t%+v\n\n", tc.expRewrittenMap, c.RewrittenVars) } }) @@ -9721,15 +9722,15 @@ func TestCompilerBuildRequiredCapabilities(t *testing.T) { names = append(names, compiler.Required.Builtins[i].Name) } - if !reflect.DeepEqual(names, tc.builtins) { + if !slices.Equal(names, tc.builtins) { t.Fatalf("expected builtins to be %v but got %v", tc.builtins, names) } - if !reflect.DeepEqual(compiler.Required.FutureKeywords, tc.keywords) { + if !slices.Equal(compiler.Required.FutureKeywords, tc.keywords) { t.Fatalf("expected keywords to be %v but got %v", tc.keywords, compiler.Required.FutureKeywords) } - if !reflect.DeepEqual(compiler.Required.Features, tc.features) { + if !slices.Equal(compiler.Required.Features, tc.features) { t.Fatalf("expected features to be %v but got %v", tc.features, compiler.Required.Features) } }) @@ -11292,12 +11293,12 @@ test_something if { a == b }`, exp: `package test - + a := 1 if { true } b := 2 if { true } -test_something = true if { - data.test.a = data.test.b +test_something = true if { + data.test.a = data.test.b }`, }, { @@ -11314,11 +11315,11 @@ test_something if { // When the test fails on '__local0__ = __local1__', the values for 'a' and 'b' are captured in local bindings, // accessible by the tracer. exp: `package test - + a := 1 if { true } b := 2 if { true } -test_something = true if { +test_something = true if { __local0__ = data.test.a __local1__ = data.test.b __local0__ = __local1__ @@ -11341,7 +11342,7 @@ test_something if { a := 1 if { true } b := 2 if { true } -test_something = true if { +test_something = true if { not data.test.a = data.test.b }`, }, @@ -11364,14 +11365,14 @@ a := 1 if { true } b := 2 if { true } l := [1, 2, 3] if { true } -test_something = true if { +test_something = true if { __local2__ = data.test.l - every __local0__, __local1__ in __local2__ { + every __local0__, __local1__ in __local2__ { __local4__ = data.test.b plus(__local4__, __local1__, __local3__) __local5__ = data.test.a - lt(__local5__, __local3__) - } + lt(__local5__, __local3__) + } }`, }, { @@ -11391,19 +11392,19 @@ test_something if { // When tests contain an 'every' statement, we're interested in the circumstances that made the every fail, // so it's body is rewritten. exp: `package test - + a := 1 if { true } b := 2 if { true } l := [1, 2, 3] if { true } -test_something = true if { - __local2__ = data.test.l; - every __local0__, __local1__ in __local2__ { +test_something = true if { + __local2__ = data.test.l; + every __local0__, __local1__ in __local2__ { __local4__ = data.test.b plus(__local4__, __local1__, __local3__) __local5__ = data.test.a - lt(__local5__, __local3__) - } + lt(__local5__, __local3__) + } }`, }, } diff --git a/v1/ast/varset.go b/v1/ast/varset.go index 14f531494b..d51abbdae6 100644 --- a/v1/ast/varset.go +++ b/v1/ast/varset.go @@ -6,7 +6,9 @@ package ast import ( "fmt" - "sort" + "slices" + + "github.com/open-policy-agent/opa/v1/util" ) // VarSet represents a set of variables. @@ -77,9 +79,7 @@ func (s VarSet) Sorted() []Var { for v := range s { sorted = append(sorted, v) } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 - }) + slices.SortFunc(sorted, VarCompare) return sorted } @@ -91,10 +91,5 @@ func (s VarSet) Update(vs VarSet) { } func (s VarSet) String() string { - tmp := make([]string, 0, len(s)) - for v := range s { - tmp = append(tmp, string(v)) - } - sort.Strings(tmp) - return fmt.Sprintf("%v", tmp) + return fmt.Sprintf("%v", util.KeysSorted(s)) } diff --git a/v1/bundle/hash.go b/v1/bundle/hash.go index 021801bb0a..ab6fcd0f38 100644 --- a/v1/bundle/hash.go +++ b/v1/bundle/hash.go @@ -14,8 +14,9 @@ import ( "fmt" "hash" "io" - "sort" "strings" + + "github.com/open-policy-agent/opa/v1/util" ) // HashingAlgorithm represents a subset of hashing algorithms implemented in Go @@ -97,13 +98,7 @@ func walk(v interface{}, h io.Writer) { case map[string]interface{}: _, _ = h.Write([]byte("{")) - var keys []string - for k := range x { - keys = append(keys, k) - } - sort.Strings(keys) - - for i, key := range keys { + for i, key := range util.KeysSorted(x) { if i > 0 { _, _ = h.Write([]byte(",")) } diff --git a/v1/cover/cover.go b/v1/cover/cover.go index 2e119337b7..ea79fc9ef4 100644 --- a/v1/cover/cover.go +++ b/v1/cover/cover.go @@ -8,10 +8,11 @@ package cover import ( "bytes" "fmt" - "sort" + "slices" "github.com/open-policy-agent/opa/v1/ast" "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/util" ) // Cover computes and reports on coverage. @@ -144,8 +145,8 @@ type PositionSlice []Position // Sort sorts the slice by line number. func (sl PositionSlice) Sort() { - sort.Slice(sl, func(i, j int) bool { - return sl[i].Row < sl[j].Row + slices.SortFunc(sl, func(a, b Position) int { + return a.Row - b.Row }) } @@ -257,13 +258,7 @@ func (e *CoverageThresholdError) Error() string { if e.Report != nil && len(e.Report.Files) > 0 { buffer.WriteString("\nLines not covered:") - sorted := make([]string, 0, len(e.Report.Files)) - for file := range e.Report.Files { - sorted = append(sorted, file) - } - sort.Strings(sorted) - - for _, file := range sorted { + for _, file := range util.KeysSorted(e.Report.Files) { report := e.Report.Files[file] for _, r := range report.NotCovered { if r.Start.Row == r.End.Row { @@ -275,7 +270,7 @@ func (e *CoverageThresholdError) Error() string { } } - return fmt.Sprint(buffer.String()) + return buffer.String() } func sortedPositionSliceToRangeSlice(sorted []Position) (result []Range) { diff --git a/v1/loader/loader.go b/v1/loader/loader.go index 8daf22458b..563f99efca 100644 --- a/v1/loader/loader.go +++ b/v1/loader/loader.go @@ -12,7 +12,6 @@ import ( "io/fs" "os" "path/filepath" - "sort" "strings" "sigs.k8s.io/yaml" @@ -564,12 +563,7 @@ func Dirs(paths []string) []string { unique[dir] = struct{}{} } - u := make([]string, 0, len(unique)) - for k := range unique { - u = append(u, k) - } - sort.Strings(u) - return u + return util.KeysSorted(unique) } // SplitPrefix returns a tuple specifying the document prefix and the file diff --git a/v1/server/server_test.go b/v1/server/server_test.go index 17482c3cc9..6c5b562400 100644 --- a/v1/server/server_test.go +++ b/v1/server/server_test.go @@ -31,7 +31,6 @@ import ( "os" "path/filepath" "reflect" - "sort" "strconv" "strings" "sync/atomic" @@ -880,60 +879,60 @@ func TestCompileV1(t *testing.T) { t.Parallel() v0mod := `package test - + p { input.x = 1 } - + q { data.a[i] = input.x } - + default r = true - + r { input.x = 1 } - + custom_func(x) { data.a[i] == x } - + s { custom_func(input.x) } ` v1mod := `package test - + p if { input.x = 1 } - + q if { data.a[i] = input.x } - + default r = true - + r if { input.x = 1 } - + custom_func(x) if { data.a[i] == x } - + s if { custom_func(input.x) } ` v0v1mod := `package test import rego.v1 - + p if { input.x = 1 } - + q if { data.a[i] = input.x } - + default r = true - + r if { input.x = 1 } - + custom_func(x) if { data.a[i] == x } - + s if { custom_func(input.x) } ` @@ -1122,7 +1121,7 @@ func TestCompileV1(t *testing.T) { }`, 200, expQueryAndSupport( `data.partial.test.s = true`, `package partial.test - + s if { data.partial.test.custom_func(1) } custom_func(__local0__2) if { data.a[i2] = __local0__2 } `, @@ -1194,7 +1193,7 @@ func TestCompileV1Observability(t *testing.T) { err = f.v1(http.MethodPut, "/policies/test", `package test import rego.v1 - + p if { input.x = 1 }`, 200, "") if err != nil { t.Fatal(err) @@ -3083,7 +3082,7 @@ func TestDataPostExplainNotes(t *testing.T) { err := f.v1(http.MethodPut, "/policies/test", ` package test import rego.v1 - + p if { data.a[i] = x; x > 1 trace(sprintf("found x = %d", [x])) @@ -3693,13 +3692,7 @@ func TestPoliciesPutV1Noop(t *testing.T) { // Sort the metric keys and compare to expected value. We're assuming the // server skips parsing if the bytes are equal. - result := []string{} - - for k := range resp.Metrics { - result = append(result, k) - } - - sort.Strings(result) + result := util.KeysSorted(resp.Metrics) if !reflect.DeepEqual(exp, result) { t.Fatalf("Expected %v but got %v", exp, result) @@ -4012,7 +4005,7 @@ func TestStatusV1MetricsWithSystemAuthzPolicy(t *testing.T) { txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) authzPolicy := `package system.authz import rego.v1 - + default allow = false allow if { input.path = ["v1", "status"] @@ -4830,7 +4823,7 @@ func TestAuthorization(t *testing.T) { txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) authzPolicy := `package system.authz - + import rego.v1 import input.identity @@ -4883,7 +4876,7 @@ func TestAuthorization(t *testing.T) { // Reverse the policy. update := identifier.SetIdentity(newReqV1(http.MethodPut, "/policies/test", ` package system.authz - + import rego.v1 import input.identity diff --git a/v1/tester/runner.go b/v1/tester/runner.go index 4469840ebf..23350a9e4a 100644 --- a/v1/tester/runner.go +++ b/v1/tester/runner.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "regexp" - "sort" "strings" "testing" "time" @@ -24,6 +23,7 @@ import ( "github.com/open-policy-agent/opa/v1/storage" "github.com/open-policy-agent/opa/v1/storage/inmem" "github.com/open-policy-agent/opa/v1/topdown" + "github.com/open-policy-agent/opa/v1/util" ) // TestPrefix declares the prefix for all test rules. @@ -349,13 +349,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } } - filenames := make([]string, 0, len(r.compiler.Modules)) - for name := range r.compiler.Modules { - filenames = append(filenames, name) - } - - sort.Strings(filenames) - + filenames := util.KeysSorted(r.compiler.Modules) ch := make(chan *Result) go func() { diff --git a/v1/util/compare.go b/v1/util/compare.go index 8ae7753690..8775a603dd 100644 --- a/v1/util/compare.go +++ b/v1/util/compare.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "math/big" - "sort" ) // Compare returns 0 if a equals b, -1 if a is less than b, and 1 if b is than a. @@ -99,16 +98,8 @@ func Compare(a, b interface{}) int { case map[string]interface{}: switch b := b.(type) { case map[string]interface{}: - var aKeys []string - for k := range a { - aKeys = append(aKeys, k) - } - var bKeys []string - for k := range b { - bKeys = append(bKeys, k) - } - sort.Strings(aKeys) - sort.Strings(bKeys) + aKeys := KeysSorted(a) + bKeys := KeysSorted(b) aLen := len(aKeys) bLen := len(bKeys) minLen := aLen diff --git a/v1/util/maps.go b/v1/util/maps.go index d943b4d0a8..c56fbe98ac 100644 --- a/v1/util/maps.go +++ b/v1/util/maps.go @@ -1,5 +1,29 @@ package util +import ( + "cmp" + "slices" +) + +// Keys returns a slice of keys from any map. +func Keys[M ~map[K]V, K comparable, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// KeysSorted returns a slice of keys from any map, sorted in ascending order. +func KeysSorted[M ~map[K]V, K cmp.Ordered, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + slices.Sort(r) + return r +} + // Values returns a slice of values from any map. Copied from golang.org/x/exp/maps. func Values[M ~map[K]V, K comparable, V any](m M) []V { r := make([]V, 0, len(m))