From 75962f58b9e84b56705fbc6cd5b50133c5b72d05 Mon Sep 17 00:00:00 2001 From: Anders Eknert Date: Mon, 20 Jan 2025 17:41:39 +0100 Subject: [PATCH] Perf: improvements to terms and built-in functions (#7284) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Having worked on performance improvements in OPA on the side for almost a month now, there's a lot of code piling up 😅 So much that a single PR would be way too much to review. Instead, I'm splitting the work into chunks, and will submit the next PR as soon as this one is merged. Using the same benchmark as before — Regal linting itself, these new changes in total reduce the number of allocations by ~13 million, and quite a substantial amount of evaluation time saved as well. This first PR is isolated to improvements to terms, values and built-ins, and saves ~3M allocations. The details can be found below for each change, and of course in the code :) **BenchmarkRegalLintingItself-10 Before** ``` 1885978209 ns/op 3497157312 B/op 69064779 allocs/op ``` **BenchmarkRegalLintingItself-10 After** ``` 1796255084 ns/op 3452379408 B/op 66126623 allocs/op ``` **Terms** - Use pointer receivers consistently for object and set types. This allows changing the sortGuard once lock from a pointer to a non-pointer type, which is really the biggest win performance-wise in this PR. - Comparisons happen all the time, so make sure these take the shortest path possible whenever, possible, such as when one type is compared to another value of the same type. Built-in functions: **Arrays** - Both `array.concat` and `array.slice` will now return the operand on operations where the result isn't different from the input operand (like when concatenating an empty array) instead of allocating a new term/value. **Strings** - Return operand on unchanged result rather than allocating new term/value. - Where applicable, have functions take a cheaper path when string is ASCII and we can avoid the cost of rune conversion. **Crypto** - Hashing functions now optimized, spending less than half the time compared to previously. **Objects** - Avoid heap allocating result boolean escaping its scope, and instead use the return value of the `Until` function. **HTTP** - Use interned terms for keys in configuration object, avoding allocating these each time `http.send` is invoked. **Globs** - Use read/write lock to avoid contention. Use package level vars for "constant" values, avoiding them to escape to the heap each invocation. **Not directly/only related to built-in functions** - Add `ValueName` function replacing the previous `TypeName` functions for getting the name of Value's without paying for `any` interface allocations. - Add a few more interned terms. Signed-off-by: Anders Eknert --- internal/edittree/edittree_test.go | 2 +- internal/planner/planner.go | 2 +- v1/ast/compare.go | 66 ++++++-- v1/ast/compile.go | 2 +- v1/ast/interning.go | 6 + v1/ast/map.go | 2 +- v1/ast/parser.go | 10 +- v1/ast/parser_ext.go | 14 +- v1/ast/strings.go | 36 +++++ v1/ast/strings_bench_test.go | 29 ++++ v1/ast/term.go | 241 ++++++++++++++++++----------- v1/ast/term_test.go | 8 +- v1/format/format.go | 2 +- v1/rego/rego.go | 2 +- v1/topdown/aggregates.go | 8 +- v1/topdown/array.go | 11 ++ v1/topdown/builtins/builtins.go | 12 +- v1/topdown/crypto.go | 46 +++--- v1/topdown/crypto_test.go | 26 ++++ v1/topdown/glob.go | 21 +-- v1/topdown/http.go | 127 ++++++++------- v1/topdown/json.go | 4 +- v1/topdown/jsonschema.go | 2 +- v1/topdown/object.go | 2 +- v1/topdown/print.go | 2 +- v1/topdown/runtime.go | 6 +- v1/topdown/strings.go | 185 +++++++++++++++++----- v1/topdown/subset.go | 47 ++---- v1/util/performance.go | 17 +- 29 files changed, 634 insertions(+), 304 deletions(-) create mode 100644 v1/ast/strings_bench_test.go diff --git a/internal/edittree/edittree_test.go b/internal/edittree/edittree_test.go index 9956930377..ec354eac0a 100644 --- a/internal/edittree/edittree_test.go +++ b/internal/edittree/edittree_test.go @@ -885,7 +885,7 @@ func parsePath(path *ast.Term) (ast.Ref, error) { pathSegments = append(pathSegments, term) }) default: - return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p)) + return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.ValueName(p)) } return pathSegments, nil diff --git a/internal/planner/planner.go b/internal/planner/planner.go index d6b3020413..160775c0e9 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -1519,7 +1519,7 @@ func (p *Planner) planValue(t ast.Value, loc *ast.Location, iter planiter) error p.loc = loc return p.planObjectComprehension(v, iter) default: - return fmt.Errorf("%v term not implemented", ast.TypeName(v)) + return fmt.Errorf("%v term not implemented", ast.ValueName(v)) } } diff --git a/v1/ast/compare.go b/v1/ast/compare.go index 3bb6f2a75d..24e61712e7 100644 --- a/v1/ast/compare.go +++ b/v1/ast/compare.go @@ -151,14 +151,7 @@ func Compare(a, b interface{}) int { } return 1 case Var: - b := b.(Var) - if a.Equal(b) { - return 0 - } - if a < b { - return -1 - } - return 1 + return VarCompare(a, b.(Var)) case Ref: b := b.(Ref) return termSliceCompare(a, b) @@ -181,7 +174,7 @@ func Compare(a, b interface{}) int { if cmp := Compare(a.Term, b.Term); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case *ObjectComprehension: b := b.(*ObjectComprehension) if cmp := Compare(a.Key, b.Key); cmp != 0 { @@ -190,13 +183,13 @@ func Compare(a, b interface{}) int { if cmp := Compare(a.Value, b.Value); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case *SetComprehension: b := b.(*SetComprehension) if cmp := Compare(a.Term, b.Term); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case Call: b := b.(Call) return termSliceCompare(a, b) @@ -394,3 +387,54 @@ func withSliceCompare(a, b []*With) int { } return 0 } + +func VarCompare(a, b Var) int { + if a == b { + return 0 + } + if a < b { + return -1 + } + return 1 +} + +func TermValueCompare(a, b *Term) int { + return a.Value.Compare(b.Value) +} + +func ValueEqual(a, b Value) bool { + // TODO(ae): why doesn't this work the same? + // + // case interface{ Equal(Value) bool }: + // return v.Equal(b) + // + // When put on top, golangci-lint even flags the other cases as unreachable.. + // but TestTopdownVirtualCache will have failing test cases when we replace + // the other cases with the above one.. 🤔 + switch v := a.(type) { + case Null: + return v.Equal(b) + case Boolean: + return v.Equal(b) + case Number: + return v.Equal(b) + case String: + return v.Equal(b) + case Var: + return v.Equal(b) + case Ref: + return v.Equal(b) + case *Array: + return v.Equal(b) + } + + return a.Compare(b) == 0 +} + +func RefCompare(a, b Ref) int { + return termSliceCompare(a, b) +} + +func RefEqual(a, b Ref) bool { + return termSliceEqual(a, b) +} diff --git a/v1/ast/compile.go b/v1/ast/compile.go index a238d454af..76b3c51bda 100644 --- a/v1/ast/compile.go +++ b/v1/ast/compile.go @@ -5500,7 +5500,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e return true } } - errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value))) + errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", ValueName(t.Value))) return true } diff --git a/v1/ast/interning.go b/v1/ast/interning.go index f521af9661..17b10231b7 100644 --- a/v1/ast/interning.go +++ b/v1/ast/interning.go @@ -15,6 +15,8 @@ var ( // since this is by far the most common negative number minusOneTerm = &Term{Value: Number("-1")} + + InternedNullTerm = &Term{Value: Null{}} ) // InternedBooleanTerm returns an interned term with the given boolean value. @@ -1090,3 +1092,7 @@ var intNumberTerms = [...]*Term{ {Value: Number("511")}, {Value: Number("512")}, } + +var InternedEmptyString = StringTerm("") + +var InternedEmptyObject = ObjectTerm() diff --git a/v1/ast/map.go b/v1/ast/map.go index c22d279a68..5a64f32505 100644 --- a/v1/ast/map.go +++ b/v1/ast/map.go @@ -31,7 +31,7 @@ func (vs *ValueMap) MarshalJSON() ([]byte, error) { vs.Iter(func(k Value, v Value) bool { tmp = append(tmp, map[string]interface{}{ "name": k.String(), - "type": TypeName(v), + "type": ValueName(v), "value": v, }) return false diff --git a/v1/ast/parser.go b/v1/ast/parser.go index 6639ca990b..fef9575132 100644 --- a/v1/ast/parser.go +++ b/v1/ast/parser.go @@ -591,7 +591,7 @@ func (p *Parser) parsePackage() *Package { pkg.Path[0] = DefaultRootDocument.Copy().SetLocation(v[0].Location) first, ok := v[0].Value.(Var) if !ok { - p.errorf(v[0].Location, "unexpected %v token: expecting var", TypeName(v[0].Value)) + p.errorf(v[0].Location, "unexpected %v token: expecting var", ValueName(v[0].Value)) return nil } pkg.Path[1] = StringTerm(string(first)).SetLocation(v[0].Location) @@ -600,7 +600,7 @@ func (p *Parser) parsePackage() *Package { case String: pkg.Path[i] = v[i-1] default: - p.errorf(v[i-1].Location, "unexpected %v token: expecting string", TypeName(v[i-1].Value)) + p.errorf(v[i-1].Location, "unexpected %v token: expecting string", ValueName(v[i-1].Value)) return nil } } @@ -643,7 +643,7 @@ func (p *Parser) parseImport() *Import { case Ref: for i := 1; i < len(v); i++ { if _, ok := v[i].Value.(String); !ok { - p.errorf(v[i].Location, "unexpected %v token: expecting string", TypeName(v[i].Value)) + p.errorf(v[i].Location, "unexpected %v token: expecting string", ValueName(v[i].Value)) return nil } } @@ -1717,7 +1717,7 @@ func (p *Parser) parseRef(head *Term, offset int) (term *Term) { case Var, *Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: // ok default: - p.errorf(loc, "illegal ref (head cannot be %v)", TypeName(h)) + p.errorf(loc, "illegal ref (head cannot be %v)", ValueName(h)) } ref := []*Term{head} @@ -2318,7 +2318,7 @@ func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool { switch v := x.Value.(type) { case Var: // do nothing default: - p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v))) + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", ValueName(v))) valid = false return true } diff --git a/v1/ast/parser_ext.go b/v1/ast/parser_ext.go index f08c112a72..db1c3caedc 100644 --- a/v1/ast/parser_ext.go +++ b/v1/ast/parser_ext.go @@ -186,7 +186,7 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } return ParsePartialSetDocRuleFromTerm(module, term) default: - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(v)) } } @@ -277,7 +277,7 @@ func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, erro return nil, fmt.Errorf("ref not ground") } } else { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(lhs.Value)) } head.Value = rhs head.Location = lhs.Location @@ -299,7 +299,7 @@ func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, erro func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) if !ok { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(term.Value)) } if _, ok := ref[0].Value.(Var); !ok { @@ -328,7 +328,7 @@ func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, er func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { ref, ok := lhs.Value.(Ref) if !ok { - return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) + return nil, fmt.Errorf("%v cannot be used as rule name", ValueName(lhs.Value)) } if _, ok := ref[0].Value.(Var); !ok { @@ -363,7 +363,7 @@ func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) if !ok || len(ref) == 1 { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + return nil, fmt.Errorf("%vs cannot be used for rule head", ValueName(term.Value)) } if _, ok := ref[0].Value.(Var); !ok { return nil, fmt.Errorf("invalid rule head: %v", ref) @@ -373,7 +373,7 @@ func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { if len(ref) == 2 { v, ok := ref[0].Value.(Var) if !ok { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + return nil, fmt.Errorf("%vs cannot be used for rule head", ValueName(term.Value)) } // Modify the code to add the location to the head ref // and set the head ref's jsonOptions. @@ -408,7 +408,7 @@ func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { ref, ok := call[0].Value.(Ref) if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) + return nil, fmt.Errorf("%vs cannot be used in function signature", ValueName(call[0].Value)) } if _, ok := ref[0].Value.(Var); !ok { return nil, fmt.Errorf("invalid rule head: %v", ref) diff --git a/v1/ast/strings.go b/v1/ast/strings.go index e489f6977c..40d66753f5 100644 --- a/v1/ast/strings.go +++ b/v1/ast/strings.go @@ -16,3 +16,39 @@ func TypeName(x interface{}) string { } return strings.ToLower(reflect.Indirect(reflect.ValueOf(x)).Type().Name()) } + +// ValueName returns a human readable name for the AST Value type. +// This is preferrable over calling TypeName when the argument is known to be +// a Value, as this doesn't require reflection (= heap allocations). +func ValueName(x Value) string { + switch x.(type) { + case String: + return "string" + case Boolean: + return "boolean" + case Number: + return "number" + case Null: + return "null" + case Var: + return "var" + case Object: + return "object" + case Set: + return "set" + case Ref: + return "ref" + case Call: + return "call" + case *Array: + return "array" + case *ArrayComprehension: + return "arraycomprehension" + case *ObjectComprehension: + return "objectcomprehension" + case *SetComprehension: + return "setcomprehension" + } + + return TypeName(x) +} diff --git a/v1/ast/strings_bench_test.go b/v1/ast/strings_bench_test.go new file mode 100644 index 0000000000..c7cce82bc2 --- /dev/null +++ b/v1/ast/strings_bench_test.go @@ -0,0 +1,29 @@ +package ast + +import "testing" + +// BenchmarkTypeName-10 32207775 38.93 ns/op 8 B/op 1 allocs/op +func BenchmarkTypeName(b *testing.B) { + term := StringTerm("foo") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + name := TypeName(term.Value) + if name != "string" { + b.Fatalf("expected string but got %v", name) + } + } +} + +// BenchmarkValueName-10 508312227 2.374 ns/op 0 B/op 0 allocs/op +func BenchmarkValueName(b *testing.B) { + term := StringTerm("foo") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + name := ValueName(term.Value) + if name != "string" { + b.Fatalf("expected string but got %v", name) + } + } +} diff --git a/v1/ast/term.go b/v1/ast/term.go index d79f4418bd..1350150f1a 100644 --- a/v1/ast/term.go +++ b/v1/ast/term.go @@ -14,7 +14,7 @@ import ( "math/big" "net/url" "regexp" - "sort" + "slices" "strconv" "strings" "sync" @@ -56,10 +56,16 @@ type Value interface { func InterfaceToValue(x interface{}) (Value, error) { switch x := x.(type) { case nil: - return Null{}, nil + return NullValue, nil case bool: - return Boolean(x), nil + if x { + return InternedBooleanTerm(true).Value, nil + } + return InternedBooleanTerm(false).Value, nil case json.Number: + if interned := InternedIntNumberTermFromString(string(x)); interned != nil { + return interned.Value, nil + } return Number(x), nil case int64: return int64Number(x), nil @@ -85,11 +91,7 @@ func InterfaceToValue(x interface{}) (Value, error) { kvs := util.NewPtrSlice[Term](len(x) * 2) idx := 0 for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - kvs[idx].Value = k + kvs[idx].Value = String(k) v, err := InterfaceToValue(v) if err != nil { return nil, err @@ -105,15 +107,7 @@ func InterfaceToValue(x interface{}) (Value, error) { case map[string]string: r := newobject(len(x)) for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - v, err := InterfaceToValue(v) - if err != nil { - return nil, err - } - r.Insert(NewTerm(k), NewTerm(v)) + r.Insert(StringTerm(k), StringTerm(v)) } return r, nil default: @@ -136,7 +130,7 @@ func ValueFromReader(r io.Reader) (Value, error) { // As converts v into a Go native type referred to by x. func As(v Value, x interface{}) error { - return util.NewJSONDecoder(bytes.NewBufferString(v.String())).Decode(x) + return util.NewJSONDecoder(strings.NewReader(v.String())).Decode(x) } // Resolver defines the interface for resolving references to native Go values. @@ -363,7 +357,7 @@ func (term *Term) Copy() *Term { } // Equal returns true if this term equals the other term. Equality is -// defined for each kind of term. +// defined for each kind of term, and does not compare the Location. func (term *Term) Equal(other *Term) bool { if term == nil && other != nil { return false @@ -375,28 +369,7 @@ func (term *Term) Equal(other *Term) bool { return true } - // TODO(tsandall): This early-exit avoids allocations for types that have - // Equal() functions that just use == underneath. We should revisit the - // other types and implement Equal() functions that do not require - // allocations. - switch v := term.Value.(type) { - case Null: - return v.Equal(other.Value) - case Boolean: - return v.Equal(other.Value) - case Number: - return v.Equal(other.Value) - case String: - return v.Equal(other.Value) - case Var: - return v.Equal(other.Value) - case Ref: - return v.Equal(other.Value) - case *Array: - return v.Equal(other.Value) - } - - return term.Value.Compare(other.Value) == 0 + return ValueEqual(term.Value, other.Value) } // Get returns a value referred to by name from the term. @@ -441,7 +414,7 @@ func (term *Term) setJSONOptions(opts astJSON.Options) { // Specialized marshalling logic is required to include a type hint for Value. func (term *Term) MarshalJSON() ([]byte, error) { d := map[string]interface{}{ - "type": TypeName(term.Value), + "type": ValueName(term.Value), "value": term.Value, } if term.jsonOptions.MarshalOptions.IncludeLocation.Term { @@ -553,13 +526,7 @@ func ContainsClosures(v interface{}) bool { // IsScalar returns true if the AST value is a scalar. func IsScalar(v Value) bool { switch v.(type) { - case String: - return true - case Number: - return true - case Boolean: - return true - case Null: + case String, Number, Boolean, Null: return true } return false @@ -568,9 +535,11 @@ func IsScalar(v Value) bool { // Null represents the null value defined by JSON. type Null struct{} +var NullValue Value = Null{} + // NullTerm creates a new Term with a Null value. func NullTerm() *Term { - return &Term{Value: Null{}} + return &Term{Value: NullValue} } // Equal returns true if the other term Value is also Null. @@ -586,13 +555,16 @@ func (null Null) Equal(other Value) bool { // Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (null Null) Compare(other Value) int { - return Compare(null, other) + if _, ok := other.(Null); ok { + return 0 + } + return -1 } // Find returns the current value or a not found error. func (null Null) Find(path Ref) (Value, error) { if len(path) == 0 { - return null, nil + return NullValue, nil } return nil, errFindNotFound } @@ -616,7 +588,10 @@ type Boolean bool // BooleanTerm creates a new Term with a Boolean value. func BooleanTerm(b bool) *Term { - return &Term{Value: Boolean(b)} + if b { + return &Term{Value: InternedBooleanTerm(true).Value} + } + return &Term{Value: InternedBooleanTerm(false).Value} } // Equal returns true if the other Value is a Boolean and is equal. @@ -632,13 +607,29 @@ func (bol Boolean) Equal(other Value) bool { // Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (bol Boolean) Compare(other Value) int { - return Compare(bol, other) + switch other := other.(type) { + case Boolean: + if bol == other { + return 0 + } + if !bol { + return -1 + } + return 1 + case Null: + return 1 + } + + return -1 } // Find returns the current value or a not found error. func (bol Boolean) Find(path Ref) (Value, error) { if len(path) == 0 { - return bol, nil + if bol { + return InternedBooleanTerm(true).Value, nil + } + return InternedBooleanTerm(false).Value, nil } return nil, errFindNotFound } @@ -688,13 +679,14 @@ func FloatNumberTerm(f float64) *Term { func (num Number) Equal(other Value) bool { switch other := other.(type) { case Number: - n1, ok1 := num.Int64() - n2, ok2 := other.Int64() - if ok1 && ok2 && n1 == n2 { - return true + if n1, ok1 := num.Int64(); ok1 { + n2, ok2 := other.Int64() + if ok1 && ok2 && n1 == n2 { + return true + } } - return Compare(num, other) == 0 + return num.Compare(other) == 0 default: return false } @@ -703,6 +695,21 @@ func (num Number) Equal(other Value) bool { // Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (num Number) Compare(other Value) int { + // Optimize for the common case, as calling Compare allocates on heap. + if otherNum, yes := other.(Number); yes { + if ai, ok := num.Int64(); ok { + if bi, ok := otherNum.Int64(); ok { + if ai == bi { + return 0 + } + if ai < bi { + return -1 + } + return 1 + } + } + } + return Compare(num, other) } @@ -800,6 +807,19 @@ func (str String) Equal(other Value) bool { // Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (str String) Compare(other Value) int { + // Optimize for the common case of one string being compared to another by + // using a direct comparison of values. This avoids the allocation performed + // when calling Compare and its interface{} argument conversion. + if otherStr, ok := other.(String); ok { + if str == otherStr { + return 0 + } + if str < otherStr { + return -1 + } + return 1 + } + return Compare(str, other) } @@ -848,6 +868,9 @@ func (v Var) Equal(other Value) bool { // Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (v Var) Compare(other Value) int { + if otherVar, ok := other.(Var); ok { + return strings.Compare(string(v), string(otherVar)) + } return Compare(v, other) } @@ -1020,6 +1043,10 @@ func (ref Ref) Equal(other Value) bool { // Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (ref Ref) Compare(other Value) int { + if o, ok := other.(Ref); ok { + return termSliceCompare(ref, o) + } + return Compare(ref, other) } @@ -1051,32 +1078,32 @@ func (ref Ref) HasPrefix(other Ref) bool { // ConstantPrefix returns the constant portion of the ref starting from the head. func (ref Ref) ConstantPrefix() Ref { - ref = ref.Copy() - i := ref.Dynamic() if i < 0 { - return ref + return ref.Copy() } - return ref[:i] + return ref[:i].Copy() } func (ref Ref) StringPrefix() Ref { - r := ref.Copy() - for i := 1; i < len(ref); i++ { - switch r[i].Value.(type) { + switch ref[i].Value.(type) { case String: // pass default: // cut off - return r[:i] + return ref[:i].Copy() } } - return r + return ref.Copy() } // GroundPrefix returns the ground portion of the ref starting from the head. By // definition, the head of the reference is always ground. func (ref Ref) GroundPrefix() Ref { + if ref.IsGround() { + return ref + } + prefix := make(Ref, 0, len(ref)) for i, x := range ref { @@ -1260,6 +1287,19 @@ func (arr *Array) Equal(other Value) bool { // Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (arr *Array) Compare(other Value) int { + if b, ok := other.(*Array); ok { + return termSliceCompare(arr.elems, b.elems) + } + + sortA := sortOrder(arr) + sortB := sortOrder(other) + + if sortA < sortB { + return -1 + } else if sortB < sortA { + return 1 + } + return Compare(arr, other) } @@ -1307,7 +1347,9 @@ func (arr *Array) Sorted() *Array { for i := range cpy { cpy[i] = arr.elems[i] } - sort.Sort(termSlice(cpy)) + + slices.SortFunc(cpy, TermValueCompare) + a := NewArray(cpy...) a.hashs = arr.hashs return a @@ -1480,7 +1522,7 @@ func newset(n int) *set { keys: keys, hash: 0, ground: true, - sortGuard: new(sync.Once), + sortGuard: sync.Once{}, } } @@ -1493,11 +1535,15 @@ func SetTerm(t ...*Term) *Term { } type set struct { - elems map[int]*Term - keys []*Term - hash int - ground bool - sortGuard *sync.Once // Prevents race condition around sorting. + elems map[int]*Term + keys []*Term + hash int + ground bool + // Prevents race condition around sorting. + // We can avoid (the allocation cost of) using a pointer here as all + // methods of `set` use a pointer receiver, and the `sync.Once` value + // is never copied. + sortGuard sync.Once } // Copy returns a deep copy of s. @@ -1547,7 +1593,7 @@ func (s *set) String() string { func (s *set) sortedKeys() []*Term { s.sortGuard.Do(func() { - sort.Sort(termSlice(s.keys)) + slices.SortFunc(s.keys, TermValueCompare) }) return s.keys } @@ -1717,7 +1763,7 @@ func (s *set) clear() { s.keys = s.keys[:0] s.hash = 0 s.ground = true - s.sortGuard = new(sync.Once) + s.sortGuard = sync.Once{} } func (s *set) insertNoGuard(x *Term) { @@ -1825,7 +1871,7 @@ func (s *set) insert(x *Term, resetSortGuard bool) { // Note that this will always be the case when external code calls insert via // Add, or otherwise. Internal code may however benefit from not having to // re-create this pointer when it's known not to be needed. - s.sortGuard = new(sync.Once) + s.sortGuard = sync.Once{} } s.hash += hash @@ -2094,7 +2140,8 @@ func (l *lazyObj) Keys() []*Term { for k := range l.native { ret = append(ret, StringTerm(k)) } - sort.Sort(termSlice(ret)) + slices.SortFunc(ret, TermValueCompare) + return ret } @@ -2148,7 +2195,7 @@ type object struct { ground int // number of key and value grounds. Counting is // required to support insert's key-value replace. hash int - sortGuard *sync.Once // Prevents race condition around sorting. + sortGuard sync.Once // Prevents race condition around sorting. } func newobject(n int) *object { @@ -2161,7 +2208,7 @@ func newobject(n int) *object { keys: keys, ground: 0, hash: 0, - sortGuard: new(sync.Once), + sortGuard: sync.Once{}, } } @@ -2185,7 +2232,9 @@ func Item(key, value *Term) [2]*Term { func (obj *object) sortedKeys() objectElemSlice { obj.sortGuard.Do(func() { - sort.Sort(obj.keys) + slices.SortFunc(obj.keys, func(a, b *objectElem) int { + return a.key.Value.Compare(b.key.Value) + }) }) return obj.keys } @@ -2376,7 +2425,7 @@ func (obj *object) MarshalJSON() ([]byte, error) { // overlapping keys between obj and other, the values of associated with the keys are merged. Only // objects can be merged with other objects. If the values cannot be merged, the second turn value // will be false. -func (obj object) Merge(other Object) (Object, bool) { +func (obj *object) Merge(other Object) (Object, bool) { return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { obj1, ok1 := v1.Value.(Object) obj2, ok2 := v2.Value.(Object) @@ -2395,7 +2444,7 @@ func (obj object) Merge(other Object) (Object, bool) { // If there are overlapping keys between obj and other, the conflictResolver // is called. The conflictResolver can return a merged value and a boolean // indicating if the merge has failed and should stop. -func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { +func (obj *object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { result := NewObject() stop := obj.Until(func(k, v *Term) bool { v2 := other.Get(k) @@ -2438,11 +2487,11 @@ func (obj *object) Filter(filter Object) (Object, error) { } // Len returns the number of elements in the object. -func (obj object) Len() int { +func (obj *object) Len() int { return len(obj.keys) } -func (obj object) String() string { +func (obj *object) String() string { sb := sbPool.Get().(*strings.Builder) sb.Reset() sb.Grow(obj.Len() * 32) @@ -2667,8 +2716,8 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) { // See https://github.com/golang/go/issues/25955 for why we do it this way. // Note that this will always be the case when external code calls insert via // Add, or otherwise. Internal code may however benefit from not having to - // re-create this pointer when it's known not to be needed. - obj.sortGuard = new(sync.Once) + // re-create this when it's known not to be needed. + obj.sortGuard = sync.Once{} } obj.hash += hash + v.Hash() @@ -2695,7 +2744,7 @@ func (obj *object) rehash() { } func filterObject(o Value, filter Value) (Value, error) { - if filter.Compare(Null{}) == 0 { + if (Null{}).Equal(filter) { return o, nil } @@ -3013,12 +3062,16 @@ func (c Call) String() string { func termSliceCopy(a []*Term) []*Term { cpy := make([]*Term, len(a)) - for i := range a { - cpy[i] = a[i].Copy() - } + termSliceCopyTo(a, cpy) return cpy } +func termSliceCopyTo(src, dst []*Term) { + for i := range src { + dst[i] = src[i].Copy() + } +} + func termSliceEqual(a, b []*Term) bool { if len(a) == len(b) { for i := range a { @@ -3243,7 +3296,7 @@ func unmarshalValue(d map[string]interface{}) (Value, error) { v := d["value"] switch d["type"] { case "null": - return Null{}, nil + return NullValue, nil case "boolean": if b, ok := v.(bool); ok { return Boolean(b), nil diff --git a/v1/ast/term_test.go b/v1/ast/term_test.go index 63d4e7e6ca..2d28a05895 100644 --- a/v1/ast/term_test.go +++ b/v1/ast/term_test.go @@ -277,7 +277,7 @@ func TestTermBadJSON(t *testing.T) { term := Term{} err := util.UnmarshalJSON([]byte(input), &term) expected := fmt.Errorf("ast: unable to unmarshal term") - if !reflect.DeepEqual(expected, err) { + if expected.Error() != err.Error() { t.Errorf("Expected %v but got: %v", expected, err) } } @@ -756,7 +756,7 @@ func TestSetMap(t *testing.T) { return nil, fmt.Errorf("oops") }) - if !reflect.DeepEqual(err, fmt.Errorf("oops")) { + if err.Error() != "oops" { t.Fatalf("Expected oops to be returned but got: %v, %v", result, err) } } @@ -1418,7 +1418,7 @@ func TestLazyObjectKeys(t *testing.T) { }) act := x.Keys() exp := []*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} - if !reflect.DeepEqual(exp, act) { + if !termSliceEqual(exp, act) { t.Errorf("expected Keys() %v, got %v", exp, act) } assertForced(t, x, false) @@ -1436,7 +1436,7 @@ func TestLazyObjectKeysIterator(t *testing.T) { act = append(act, k) } exp := []*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} - if !reflect.DeepEqual(exp, act) { + if !termSliceEqual(exp, act) { t.Errorf("expected Keys() %v, got %v", exp, act) } assertForced(t, x, false) diff --git a/v1/format/format.go b/v1/format/format.go index 56c30171dd..e86964d1b4 100644 --- a/v1/format/format.go +++ b/v1/format/format.go @@ -1637,7 +1637,7 @@ func ArityFormatMismatchError(operands []*ast.Term, operator string, loc *ast.Lo have := make([]string, len(operands)) for i := 0; i < len(operands); i++ { - have[i] = ast.TypeName(operands[i].Value) + have[i] = ast.ValueName(operands[i].Value) } err := ast.NewError(ast.TypeErr, loc, "%s: %s", operator, "arity mismatch") err.Details = &ArityFormatErrDetail{ diff --git a/v1/rego/rego.go b/v1/rego/rego.go index 1b7ea47bdd..ede02439dd 100644 --- a/v1/rego/rego.go +++ b/v1/rego/rego.go @@ -2598,7 +2598,7 @@ func (r *Rego) rewriteQueryForPartialEval(_ ast.QueryCompiler, query ast.Body) ( ref, ok := term.Value.(ast.Ref) if !ok { - return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.TypeName(term.Value)) + return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.ValueName(term.Value)) } if !ref.IsGround() { diff --git a/v1/topdown/aggregates.go b/v1/topdown/aggregates.go index e7d0578224..02425d2411 100644 --- a/v1/topdown/aggregates.go +++ b/v1/topdown/aggregates.go @@ -99,7 +99,7 @@ func builtinMax(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - var max = ast.Value(ast.Null{}) + max := ast.InternedNullTerm.Value a.Foreach(func(x *ast.Term) { if ast.Compare(max, x.Value) <= 0 { max = x.Value @@ -110,7 +110,7 @@ func builtinMax(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { + max, err := a.Reduce(ast.InternedNullTerm, func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { if ast.Compare(max, elem) <= 0 { return elem, nil } @@ -142,11 +142,11 @@ func builtinMin(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { + min, err := a.Reduce(ast.InternedNullTerm, func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { // The null term is considered to be less than any other term, // so in order for min of a set to make sense, we need to check // for it. - if min.Value.Compare(ast.Null{}) == 0 { + if min.Value.Compare(ast.InternedNullTerm.Value) == 0 { return elem, nil } diff --git a/v1/topdown/array.go b/v1/topdown/array.go index d37204bef0..4a2a2ed148 100644 --- a/v1/topdown/array.go +++ b/v1/topdown/array.go @@ -20,6 +20,13 @@ func builtinArrayConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T return err } + if arrA.Len() == 0 { + return iter(operands[1]) + } + if arrB.Len() == 0 { + return iter(operands[0]) + } + arrC := make([]*ast.Term, arrA.Len()+arrB.Len()) i := 0 @@ -68,6 +75,10 @@ func builtinArraySlice(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te startIndex = stopIndex } + if startIndex == 0 && stopIndex >= arr.Len() { + return iter(operands[0]) + } + return iter(ast.NewTerm(arr.Slice(startIndex, stopIndex))) } diff --git a/v1/topdown/builtins/builtins.go b/v1/topdown/builtins/builtins.go index c788cf2536..45a0b88408 100644 --- a/v1/topdown/builtins/builtins.go +++ b/v1/topdown/builtins/builtins.go @@ -128,23 +128,23 @@ func NewOperandErr(pos int, f string, a ...interface{}) error { func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error { if len(expected) == 1 { - return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got)) + return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.ValueName(got)) } - return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got)) + return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.ValueName(got)) } // NewOperandElementErr returns an operand error indicating an element in the // composite operand was wrong. func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error { - tpe := ast.TypeName(composite) + tpe := ast.ValueName(composite) if len(expected) == 1 { - return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got)) + return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.ValueName(got)) } - return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got)) + return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.ValueName(got)) } // NewOperandEnumErr returns an operand error indicating a value was wrong. @@ -233,7 +233,7 @@ func ObjectOperand(x ast.Value, pos int) (ast.Object, error) { func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) { a, ok := x.(*ast.Array) if !ok { - return ast.NewArray(), NewOperandTypeErr(pos, x, "array") + return nil, NewOperandTypeErr(pos, x, "array") } return a, nil } diff --git a/v1/topdown/crypto.go b/v1/topdown/crypto.go index ff53550748..ab499e3e8f 100644 --- a/v1/topdown/crypto.go +++ b/v1/topdown/crypto.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/hex" "encoding/json" "encoding/pem" "fmt" @@ -373,7 +374,7 @@ func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter } if len(rawKeys) == 0 { - return iter(ast.NullTerm()) + return iter(ast.InternedNullTerm) } key, err := jwk.New(rawKeys[0]) @@ -407,7 +408,7 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter } if string(input) == "" { - return iter(ast.NullTerm()) + return iter(ast.InternedNullTerm) } // get the raw private key @@ -417,7 +418,7 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter } if len(rawKeys) == 0 { - return iter(ast.NewTerm(ast.NewArray())) + return iter(emptyArr) } bs, err := json.Marshal(rawKeys) @@ -438,36 +439,43 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter return iter(ast.NewTerm(value)) } -func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) { - s, err := builtins.StringOperand(a, 1) - if err != nil { - return nil, err - } - return ast.String(h(s)), nil +func toHexEncodedString(src []byte) string { + dst := make([]byte, hex.EncodedLen(len(src))) + hex.Encode(dst, src) + return util.ByteSliceToString(dst) } func builtinCryptoMd5(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) }) + s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + md5sum := md5.Sum([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(md5sum[:]))) } func builtinCryptoSha1(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", sha1.Sum([]byte(s))) }) + s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + sha1sum := sha1.Sum([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(sha1sum[:]))) } func builtinCryptoSha256(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) }) + s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + sha256sum := sha256.Sum256([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(sha256sum[:]))) } func hmacHelper(operands []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error { @@ -724,9 +732,11 @@ func readCertFromFile(localCertFile string) ([]byte, error) { return certPEM, nil } +var beginPrefix = []byte("-----BEGIN ") + func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls.Certificate, error) { - if !strings.HasPrefix(string(certPemBlock), "-----BEGIN") { + if !bytes.HasPrefix(certPemBlock, beginPrefix) { s, err := base64.StdEncoding.DecodeString(string(certPemBlock)) if err != nil { return nil, err @@ -734,7 +744,7 @@ func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls. certPemBlock = s } - if !strings.HasPrefix(string(keyPemBlock), "-----BEGIN") { + if !bytes.HasPrefix(keyPemBlock, beginPrefix) { s, err := base64.StdEncoding.DecodeString(string(keyPemBlock)) if err != nil { return nil, err @@ -743,7 +753,7 @@ func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls. } // we assume it a DER certificate and try to convert it to a PEM. - if !bytes.HasPrefix(certPemBlock, []byte("-----BEGIN")) { + if !bytes.HasPrefix(certPemBlock, beginPrefix) { pemBlock := &pem.Block{ Type: "CERTIFICATE", diff --git a/v1/topdown/crypto_test.go b/v1/topdown/crypto_test.go index 839f03cf95..b5bd82b1a0 100644 --- a/v1/topdown/crypto_test.go +++ b/v1/topdown/crypto_test.go @@ -865,3 +865,29 @@ func TestExtractX509VerifyOptions(t *testing.T) { } } } + +// Before/after replacing sprintf("%x", ...) with hex.EncodeToString(...), and using +// util.ByteSliceToString to convert the resulting byte slice: +// BenchmarkMd5-10 3294998 435.2 ns/op 128 B/op 5 allocs/op +// BenchmarkMd5-10 6193455 180.9 ns/op 96 B/op 3 allocs/op +// ... +func BenchmarkMd5(b *testing.B) { + bctx := BuiltinContext{} + operands := []*ast.Term{ast.StringTerm("hello")} + expect := ast.String("5d41402abc4b2a76b9719d911017c592") + iter := func(result *ast.Term) error { + if !expect.Equal(result.Value) { + return fmt.Errorf("unexpected result: %v", result.Value) + } + return nil + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + err := builtinCryptoMd5(bctx, operands, iter) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } + } +} diff --git a/v1/topdown/glob.go b/v1/topdown/glob.go index cda17f3827..4de17d06c5 100644 --- a/v1/topdown/glob.go +++ b/v1/topdown/glob.go @@ -13,8 +13,10 @@ import ( const globCacheMaxSize = 100 const globInterQueryValueCacheHits = "rego_builtin_glob_interquery_value_cache_hits" -var globCacheLock = sync.Mutex{} -var globCache map[string]glob.Glob +var noDelimiters = []rune{} +var dotDelimiters = []rune{'.'} +var globCacheLock = sync.RWMutex{} +var globCache = map[string]glob.Glob{} func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { pattern, err := builtins.StringOperand(operands[0].Value, 1) @@ -25,14 +27,14 @@ func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast. var delimiters []rune switch operands[1].Value.(type) { case ast.Null: - delimiters = []rune{} + delimiters = noDelimiters case *ast.Array: delimiters, err = builtins.RuneSliceOperand(operands[1].Value, 2) if err != nil { return err } if len(delimiters) == 0 { - delimiters = []rune{'.'} + delimiters = dotDelimiters } default: return builtins.NewOperandTypeErr(2, operands[1].Value, "array", "null") @@ -86,14 +88,15 @@ func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimit return res.Match(match), nil } - globCacheLock.Lock() - defer globCacheLock.Unlock() + globCacheLock.RLock() p, ok := globCache[id] + globCacheLock.RUnlock() if !ok { var err error if p, err = glob.Compile(pattern, delimiters...); err != nil { return false, err } + globCacheLock.Lock() if len(globCache) >= globCacheMaxSize { // Delete a (semi-)random key to make room for the new one. for k := range globCache { @@ -102,9 +105,10 @@ func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimit } } globCache[id] = p + globCacheLock.Unlock() } - out := p.Match(match) - return out, nil + + return p.Match(match), nil } func builtinGlobQuoteMeta(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -117,7 +121,6 @@ func builtinGlobQuoteMeta(_ BuiltinContext, operands []*ast.Term, iter func(*ast } func init() { - globCache = map[string]glob.Glob{} RegisterBuiltinFunc(ast.GlobMatch.Name, builtinGlobMatch) RegisterBuiltinFunc(ast.GlobQuoteMeta.Name, builtinGlobQuoteMeta) } diff --git a/v1/topdown/http.go b/v1/topdown/http.go index 20fea2d7a6..71c7c7d9eb 100644 --- a/v1/topdown/http.go +++ b/v1/topdown/http.go @@ -86,11 +86,24 @@ var cacheableHTTPStatusCodes = [...]int{ http.StatusNotImplemented, } +var ( + codeTerm = ast.StringTerm("code") + messageTerm = ast.StringTerm("message") + statusCodeTerm = ast.StringTerm("status_code") + errorTerm = ast.StringTerm("error") + methodTerm = ast.StringTerm("method") + urlTerm = ast.StringTerm("url") + + httpSendNetworkErrTerm = ast.StringTerm(HTTPSendNetworkErr) + httpSendInternalErrTerm = ast.StringTerm(HTTPSendInternalErr) +) + var ( allowedKeys = ast.NewSet() + keyCache = make(map[string]*ast.Term, len(allowedKeyNames)) cacheableCodes = ast.NewSet() - requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) - httpSendLatencyMetricKey = "rego_builtin_" + strings.ReplaceAll(ast.HTTPSend.Name, ".", "_") + requiredKeys = ast.NewSet(methodTerm, urlTerm) + httpSendLatencyMetricKey = "rego_builtin_http_send" httpSendInterQueryCacheHits = httpSendLatencyMetricKey + "_interquery_cache_hits" ) @@ -151,22 +164,24 @@ func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.T } func generateRaiseErrorResult(err error) *ast.Term { - obj := ast.NewObject() - obj.Insert(ast.StringTerm("status_code"), ast.InternedIntNumberTerm(0)) - - errObj := ast.NewObject() - + var errObj ast.Object switch err.(type) { case *url.Error: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendNetworkErr)) + errObj = ast.NewObject( + ast.Item(codeTerm, httpSendNetworkErrTerm), + ast.Item(messageTerm, ast.StringTerm(err.Error())), + ) default: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendInternalErr)) + errObj = ast.NewObject( + ast.Item(codeTerm, httpSendInternalErrTerm), + ast.Item(messageTerm, ast.StringTerm(err.Error())), + ) } - errObj.Insert(ast.StringTerm("message"), ast.StringTerm(err.Error())) - obj.Insert(ast.StringTerm("error"), ast.NewTerm(errObj)) - - return ast.NewTerm(obj) + return ast.NewTerm(ast.NewObject( + ast.Item(statusCodeTerm, ast.InternedIntNumberTerm(0)), + ast.Item(errorTerm, ast.NewTerm(errObj)), + )) } func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { @@ -212,21 +227,21 @@ func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { func getKeyFromRequest(req ast.Object) (ast.Object, error) { // deep copy so changes to key do not reflect in the request object key := req.Copy() - cacheIgnoredHeadersTerm := req.Get(ast.StringTerm("cache_ignored_headers")) + cacheIgnoredHeadersTerm := req.Get(keyCache["cache_ignored_headers"]) allHeadersTerm := req.Get(ast.StringTerm("headers")) // skip because no headers to delete if cacheIgnoredHeadersTerm == nil || allHeadersTerm == nil { // need to explicitly set cache_ignored_headers to null // equivalent requests might have different sets of exclusion lists - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + key.Insert(ast.StringTerm("cache_ignored_headers"), ast.InternedNullTerm) return key, nil } var cacheIgnoredHeaders []string - var allHeaders map[string]interface{} err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders) if err != nil { return nil, err } + var allHeaders map[string]interface{} err = ast.As(allHeadersTerm.Value, &allHeaders) if err != nil { return nil, err @@ -238,14 +253,14 @@ func getKeyFromRequest(req ast.Object) (ast.Object, error) { if err != nil { return nil, err } - key.Insert(ast.StringTerm("headers"), ast.NewTerm(val)) + key.Insert(keyCache["headers"], ast.NewTerm(val)) // remove cache_ignored_headers key - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + key.Insert(keyCache["cache_ignored_headers"], ast.InternedNullTerm) return key, nil } func init() { - createAllowedKeys() + createKeys() createCacheableHTTPStatusCodes() initDefaults() RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend) @@ -389,33 +404,24 @@ func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error { } func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) { - var url string - var method string - - // Additional CA certificates loading options. - var tlsCaCert []byte - var tlsCaCertEnvVar string - var tlsCaCertFile string - - // Client TLS certificate and key options. Each input source - // comes in a matched pair. - var tlsClientCert []byte - var tlsClientKey []byte - - var tlsClientCertEnvVar string - var tlsClientKeyEnvVar string - - var tlsClientCertFile string - var tlsClientKeyFile string - - var tlsServerName string - var body *bytes.Buffer - var rawBody *bytes.Buffer - var enableRedirect bool - var tlsUseSystemCerts *bool - var tlsConfig tls.Config - var customHeaders map[string]interface{} - var tlsInsecureSkipVerify bool + var ( + url, method string + // Additional CA certificates loading options. + tlsCaCert []byte + tlsCaCertEnvVar, tlsCaCertFile string + // Client TLS certificate and key options. Each input source + // comes in a matched pair. + tlsClientCert, tlsClientKey []byte + tlsClientCertEnvVar, tlsClientKeyEnvVar string + tlsClientCertFile, tlsClientKeyFile, tlsServerName string + + body, rawBody *bytes.Buffer + enableRedirect, tlsInsecureSkipVerify bool + tlsUseSystemCerts *bool + tlsConfig tls.Config + customHeaders map[string]interface{} + ) + timeout := defaultHTTPRequestTimeout for _, val := range obj.Keys() { @@ -724,7 +730,7 @@ func executeHTTPRequest(req *http.Request, client *http.Client, inputReqObj ast. var err error var retry int - retry, err = getNumberValFromReqObj(inputReqObj, ast.StringTerm("max_retry_attempts")) + retry, err = getNumberValFromReqObj(inputReqObj, keyCache["max_retry_attempts"]) if err != nil { return nil, err } @@ -1009,9 +1015,12 @@ func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp return nil } -func createAllowedKeys() { +func createKeys() { for _, element := range allowedKeyNames { - allowedKeys.Add(ast.StringTerm(element)) + term := ast.StringTerm(element) + + allowedKeys.Add(term) + keyCache[element] = term } } @@ -1045,7 +1054,7 @@ func parseTimeout(timeoutVal ast.Value) (time.Duration, error) { } return timeout, nil default: - return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t)) + return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.ValueName(t)) } } @@ -1078,7 +1087,7 @@ func getNumberValFromReqObj(req ast.Object, key *ast.Term) (int, error) { } func getCachingMode(req ast.Object) (cachingMode, error) { - key := ast.StringTerm("caching_mode") + key := keyCache["caching_mode"] var s ast.String var ok bool if v := req.Get(key); v != nil { @@ -1477,11 +1486,11 @@ func (c *interQueryCache) CheckCache() (ast.Value, error) { return resp, nil } - c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + c.forceJSONDecode, err = getBoolValFromReqObj(c.key, keyCache["force_json_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } - c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, keyCache["force_yaml_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } @@ -1545,11 +1554,11 @@ func (c *intraQueryCache) CheckCache() (ast.Value, error) { // InsertIntoCache inserts the key set on this object into the cache with the given value func (c *intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { - forceJSONDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + forceJSONDecode, err := getBoolValFromReqObj(c.key, keyCache["force_json_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } - forceYAMLDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + forceYAMLDecode, err := getBoolValFromReqObj(c.key, keyCache["force_yaml_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } @@ -1580,12 +1589,12 @@ func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) { } func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) { - value, err := getBoolValFromReqObj(req, ast.StringTerm("cache")) + value, err := getBoolValFromReqObj(req, keyCache["cache"]) if err != nil { return false, nil, err } - valueForceCache, err := getBoolValFromReqObj(req, ast.StringTerm("force_cache")) + valueForceCache, err := getBoolValFromReqObj(req, keyCache["force_cache"]) if err != nil { return false, nil, err } @@ -1603,7 +1612,7 @@ type forceCacheParams struct { } func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { - term := req.Get(ast.StringTerm("force_cache_duration_seconds")) + term := req.Get(keyCache["force_cache_duration_seconds"]) if term == nil { return nil, fmt.Errorf("'force_cache' set but 'force_cache_duration_seconds' parameter is missing") } @@ -1621,7 +1630,7 @@ func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { func getRaiseErrorValue(req ast.Object) (bool, error) { result := ast.Boolean(true) var ok bool - if v := req.Get(ast.StringTerm("raise_error")); v != nil { + if v := req.Get(keyCache["raise_error"]); v != nil { if result, ok = v.Value.(ast.Boolean); !ok { return false, fmt.Errorf("invalid value for raise_error field") } diff --git a/v1/topdown/json.go b/v1/topdown/json.go index 57e079d2e0..5b7c414e40 100644 --- a/v1/topdown/json.go +++ b/v1/topdown/json.go @@ -189,7 +189,7 @@ func parsePath(path *ast.Term) (ast.Ref, error) { pathSegments = append(pathSegments, term) }) default: - return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p)) + return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.ValueName(p)) } return pathSegments, nil @@ -231,7 +231,7 @@ func pathsToObject(paths []ast.Ref) ast.Object { } if !done { - node.Insert(path[len(path)-1], ast.NullTerm()) + node.Insert(path[len(path)-1], ast.InternedNullTerm) } } diff --git a/v1/topdown/jsonschema.go b/v1/topdown/jsonschema.go index 588b7ec4ce..b1609fb044 100644 --- a/v1/topdown/jsonschema.go +++ b/v1/topdown/jsonschema.go @@ -61,7 +61,7 @@ func builtinJSONSchemaVerify(_ BuiltinContext, operands []*ast.Term, iter func(* return iter(newResultTerm(false, ast.StringTerm("jsonschema: "+err.Error()))) } - return iter(newResultTerm(true, ast.NullTerm())) + return iter(newResultTerm(true, ast.InternedNullTerm)) } // builtinJSONMatchSchema accepts 2 arguments both can be string or object and verifies if the document matches the JSON schema. diff --git a/v1/topdown/object.go b/v1/topdown/object.go index 11671da5f3..4db8fa8272 100644 --- a/v1/topdown/object.go +++ b/v1/topdown/object.go @@ -92,7 +92,7 @@ func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast. filterObj := ast.NewObject() keys.Foreach(func(key *ast.Term) { - filterObj.Insert(key, ast.NullTerm()) + filterObj.Insert(key, ast.InternedNullTerm) }) // Actually do the filtering diff --git a/v1/topdown/print.go b/v1/topdown/print.go index 2d16c2baab..f852f3e320 100644 --- a/v1/topdown/print.go +++ b/v1/topdown/print.go @@ -62,7 +62,7 @@ func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operand xs, ok := operands.Elem(i).Value.(ast.Set) if !ok { - return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))} + return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.ValueName(operands.Elem(i).Value)))} } if xs.Len() == 0 { diff --git a/v1/topdown/runtime.go b/v1/topdown/runtime.go index f892f1751e..9323225832 100644 --- a/v1/topdown/runtime.go +++ b/v1/topdown/runtime.go @@ -12,14 +12,16 @@ import ( var configStringTerm = ast.StringTerm("config") +var nothingResolver ast.Resolver = illegalResolver{} + func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error { if bctx.Runtime == nil { - return iter(ast.ObjectTerm()) + return iter(ast.InternedEmptyObject) } if bctx.Runtime.Get(configStringTerm) != nil { - iface, err := ast.ValueToInterface(bctx.Runtime.Value, illegalResolver{}) + iface, err := ast.ValueToInterface(bctx.Runtime.Value, nothingResolver) if err != nil { return err } diff --git a/v1/topdown/strings.go b/v1/topdown/strings.go index 8d6c753e6d..929a18ea0a 100644 --- a/v1/topdown/strings.go +++ b/v1/topdown/strings.go @@ -10,6 +10,8 @@ import ( "sort" "strconv" "strings" + "unicode" + "unicode/utf8" "github.com/tchap/go-patricia/v2/patricia" @@ -153,33 +155,48 @@ func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return err } - strs := []string{} + var strs []string switch b := operands[1].Value.(type) { case *ast.Array: - err := b.Iter(func(x *ast.Term) error { - s, ok := x.Value.(ast.String) + var l int + for i := 0; i < b.Len(); i++ { + s, ok := b.Elem(i).Value.(ast.String) if !ok { - return builtins.NewOperandElementErr(2, operands[1].Value, x.Value, "string") + return builtins.NewOperandElementErr(2, operands[1].Value, b.Elem(i).Value, "string") } - strs = append(strs, string(s)) - return nil - }) - if err != nil { - return err + l += len(string(s)) + } + + if b.Len() == 1 { + return iter(b.Elem(0)) } + + strs = make([]string, 0, l) + for i := 0; i < b.Len(); i++ { + strs = append(strs, string(b.Elem(i).Value.(ast.String))) + } + case ast.Set: - err := b.Iter(func(x *ast.Term) error { - s, ok := x.Value.(ast.String) + var l int + terms := b.Slice() + for i := 0; i < len(terms); i++ { + s, ok := terms[i].Value.(ast.String) if !ok { - return builtins.NewOperandElementErr(2, operands[1].Value, x.Value, "string") + return builtins.NewOperandElementErr(2, operands[1].Value, terms[i].Value, "string") } - strs = append(strs, string(s)) - return nil - }) - if err != nil { - return err + l += len(string(s)) + } + + if b.Len() == 1 { + return iter(b.Slice()[0]) + } + + strs = make([]string, 0, l) + for i := 0; i < b.Len(); i++ { + strs = append(strs, string(terms[i].Value.(ast.String))) } + default: return builtins.NewOperandTypeErr(2, operands[1].Value, "set", "array") } @@ -213,6 +230,10 @@ func builtinIndexOf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return fmt.Errorf("empty search character") } + if isASCII(string(base)) && isASCII(string(search)) { + return iter(ast.InternedIntNumberTerm(strings.Index(string(base), string(search)))) + } + baseRunes := []rune(string(base)) searchRunes := []rune(string(search)) searchLen := len(searchRunes) @@ -268,15 +289,10 @@ func builtinSubstring(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter if err != nil { return err } - runes := []rune(base) startIndex, err := builtins.IntOperand(operands[1].Value, 2) if err != nil { return err - } else if startIndex >= len(runes) { - return iter(ast.StringTerm("")) - } else if startIndex < 0 { - return fmt.Errorf("negative offset") } length, err := builtins.IntOperand(operands[2].Value, 3) @@ -284,18 +300,60 @@ func builtinSubstring(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - var s ast.String + if startIndex < 0 { + return fmt.Errorf("negative offset") + } + + sbase := string(base) + if sbase == "" { + return iter(ast.InternedEmptyString) + } + + // Optimized path for the likely common case of ASCII strings. + // This allocates less memory and runs in about 1/3 the time. + if isASCII(sbase) { + if startIndex >= len(sbase) { + return iter(ast.InternedEmptyString) + } + + if length < 0 { + return iter(ast.StringTerm(sbase[startIndex:])) + } + + upto := startIndex + length + if len(sbase) < upto { + upto = len(sbase) + } + return iter(ast.StringTerm(sbase[startIndex:upto])) + } + + runes := []rune(base) + + if startIndex >= len(runes) { + return iter(ast.InternedEmptyString) + } + + var s string if length < 0 { - s = ast.String(runes[startIndex:]) + s = string(runes[startIndex:]) } else { upto := startIndex + length if len(runes) < upto { upto = len(runes) } - s = ast.String(runes[startIndex:upto]) + s = string(runes[startIndex:upto]) } - return iter(ast.NewTerm(s)) + return iter(ast.StringTerm(s)) +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true } func builtinContains(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -325,7 +383,6 @@ func builtinStringCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T baseTerm := string(s) searchTerm := string(substr) - count := strings.Count(baseTerm, searchTerm) return iter(ast.InternedIntNumberTerm(count)) @@ -382,15 +439,22 @@ func builtinSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e if err != nil { return err } + d, err := builtins.StringOperand(operands[1].Value, 2) if err != nil { return err } + + if !strings.Contains(string(s), string(d)) { + return iter(ast.ArrayTerm(operands[0])) + } + elems := strings.Split(string(s), string(d)) arr := util.NewPtrSlice[ast.Term](len(elems)) for i := range elems { arr[i].Value = ast.String(elems[i]) } + return iter(ast.ArrayTerm(arr...)) } @@ -410,7 +474,12 @@ func builtinReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return err } - return iter(ast.StringTerm(strings.Replace(string(s), string(old), string(n), -1))) + replaced := strings.Replace(string(s), string(old), string(n), -1) + if replaced == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(replaced)) } func builtinReplaceN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -454,6 +523,11 @@ func builtinTrim(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) er return err } + trimmed := strings.Trim(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + return iter(ast.StringTerm(strings.Trim(string(s), string(c)))) } @@ -468,7 +542,12 @@ func builtinTrimLeft(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term return err } - return iter(ast.StringTerm(strings.TrimLeft(string(s), string(c)))) + trimmed := strings.TrimLeft(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimPrefix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -482,7 +561,12 @@ func builtinTrimPrefix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - return iter(ast.StringTerm(strings.TrimPrefix(string(s), string(pre)))) + trimmed := strings.TrimPrefix(string(s), string(pre)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimRight(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -496,7 +580,12 @@ func builtinTrimRight(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - return iter(ast.StringTerm(strings.TrimRight(string(s), string(c)))) + trimmed := strings.TrimRight(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimSuffix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -510,7 +599,12 @@ func builtinTrimSuffix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - return iter(ast.StringTerm(strings.TrimSuffix(string(s), string(suf)))) + trimmed := strings.TrimSuffix(string(s), string(suf)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimSpace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -519,7 +613,12 @@ func builtinTrimSpace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - return iter(ast.StringTerm(strings.TrimSpace(string(s)))) + trimmed := strings.TrimSpace(string(s)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinSprintf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -577,15 +676,23 @@ func builtinReverse(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) } func reverseString(str string) string { - sRunes := []rune(str) - length := len(sRunes) - reversedRunes := make([]rune, length) + var buf []byte + var arr [255]byte + size := len(str) + + if size < 255 { + buf = arr[:size:size] + } else { + buf = make([]byte, size) + } - for index, r := range sRunes { - reversedRunes[length-index-1] = r + for start := 0; start < size; { + r, n := utf8.DecodeRuneInString(str[start:]) + start += n + utf8.EncodeRune(buf[size-start:], r) } - return string(reversedRunes) + return string(buf) } func init() { diff --git a/v1/topdown/subset.go b/v1/topdown/subset.go index 08bdc8db45..29354d9730 100644 --- a/v1/topdown/subset.go +++ b/v1/topdown/subset.go @@ -88,9 +88,8 @@ func arraySet(t1, t2 *ast.Term) (bool, *ast.Array, ast.Set) { // associated with a key. func objectSubset(super ast.Object, sub ast.Object) bool { var superTerm *ast.Term - isSubset := true - sub.Until(func(key, subTerm *ast.Term) bool { + notSubset := sub.Until(func(key, subTerm *ast.Term) bool { // This really wants to be a for loop, hence the somewhat // weird internal structure. However, using Until() in this // was is a performance optimization, as it avoids performing @@ -98,10 +97,9 @@ func objectSubset(super ast.Object, sub ast.Object) bool { superTerm = super.Get(key) - // subTerm is can't be nil because we got it from Until(), so + // subTerm can't be nil because we got it from Until(), so // we only need to verify that super is non-nil. if superTerm == nil { - isSubset = false return true // break, not a subset } @@ -114,58 +112,39 @@ func objectSubset(super ast.Object, sub ast.Object) bool { // them normally. If only one term is an object, then we // do a normal comparison which will come up false. if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok { - if !objectSubset(superObj, subObj) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !objectSubset(superObj, subObj) } if ok, superSet, subSet := bothSets(superTerm, subTerm); ok { - if !setSubset(superSet, subSet) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !setSubset(superSet, subSet) } if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok { - if !arraySubset(superArray, subArray) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !arraySubset(superArray, subArray) } // We have already checked for exact equality, as well as for // all of the types of nested subsets we care about, so if we // get here it means this isn't a subset. - isSubset = false return true // break, not a subset }) - return isSubset + return !notSubset } // setSubset implements the subset operation on sets. // // Unlike in the object case, this is not recursive, we just compare values -// using ast.Set.Contains() because we have no well defined way to "match up" +// using ast.Set.Contains() because we have no well-defined way to "match up" // objects that are in different sets. func setSubset(super ast.Set, sub ast.Set) bool { - isSubset := true - sub.Until(func(t *ast.Term) bool { - if !super.Contains(t) { - isSubset = false - return true + for _, elem := range sub.Slice() { + if !super.Contains(elem) { + return false } - return false - }) + } - return isSubset + return true } // arraySubset implements the subset operation on arrays. @@ -197,12 +176,12 @@ func arraySubset(super, sub *ast.Array) bool { return false } - subElem := sub.Elem(subCursor) superElem := super.Elem(superCursor + subCursor) if superElem == nil { return false } + subElem := sub.Elem(subCursor) if superElem.Value.Compare(subElem.Value) == 0 { subCursor++ } else { diff --git a/v1/util/performance.go b/v1/util/performance.go index 03dc7d0601..b7222b23cb 100644 --- a/v1/util/performance.go +++ b/v1/util/performance.go @@ -1,6 +1,9 @@ package util -import "slices" +import ( + "slices" + "unsafe" +) // NewPtrSlice returns a slice of pointers to T with length n, // with only 2 allocations performed no matter the size of n. @@ -22,3 +25,15 @@ func GrowPtrSlice[T any](s []*T, n int) []*T { } return s } + +// Allocation free conversion from []byte to string (unsafe) +// Note that the byte slice must not be modified after conversion +func ByteSliceToString(bs []byte) string { + return unsafe.String(unsafe.SliceData(bs), len(bs)) +} + +// Allocation free conversion from ~string to []byte (unsafe) +// Note that the byte slice must not be modified after conversion +func StringToByteSlice[T ~string](s T) []byte { + return unsafe.Slice(unsafe.StringData(string(s)), len(s)) +}