Skip to content

Commit

Permalink
Tweaks to reduce number of allocations in regal lint hot path
Browse files Browse the repository at this point in the history
Concluding my quest to reduce the number of allocations in the
hot path for `regal lint` for this time around. This PR mainly
does so by reusing pointers to boolean and integer terms where
these are determined not to be mutated later.

The result is another ~4 million allocations reduced when
linting Regal against its own bundle. These improvements should
however help reduce allocations in pretty much any evaluation.

**opa main**
```
BenchmarkRegalLintingItself-10    1	3195257584 ns/op	6496097784 B/op	120108808 allocs/op
```

**PR branch**
```
BenchmarkRegalLintingItself-10    1	3132126333 ns/op	6376318224 B/op	116163318 allocs/op
```

Signed-off-by: Anders Eknert <[email protected]>
  • Loading branch information
anderseknert authored and srenatus committed Nov 14, 2024
1 parent d0dfc04 commit d3f5102
Show file tree
Hide file tree
Showing 28 changed files with 828 additions and 176 deletions.
561 changes: 561 additions & 0 deletions ast/interning.go

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,11 @@ func (p *Parser) parseTermInfixCallInList() *Term {
return p.parseTermIn(nil, false, p.s.loc.Offset)
}

// use static references to avoid allocations, and
// copy them to the call term only when needed
var memberWithKeyRef = MemberWithKey.Ref()
var memberRef = Member.Ref()

func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
// NOTE(sr): `in` is a bit special: besides `lhs in rhs`, it also
// supports `key, val in rhs`, so it can have an optional second lhs.
Expand All @@ -1365,7 +1370,8 @@ func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
s := p.save()
p.scan()
if mhs := p.parseTermRelation(nil, offset); mhs != nil {
if op := p.parseTermOpName(MemberWithKey.Ref(), tokens.In); op != nil {

if op := p.parseTermOpName(memberWithKeyRef, tokens.In); op != nil {
if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil {
call := p.setLoc(CallTerm(op, lhs, mhs, rhs), lhs.Location, offset, p.s.lastEnd)
switch p.s.tok {
Expand All @@ -1379,7 +1385,7 @@ func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
}
p.restore(s)
}
if op := p.parseTermOpName(Member.Ref(), tokens.In); op != nil {
if op := p.parseTermOpName(memberRef, tokens.In); op != nil {
if rhs := p.parseTermRelation(nil, p.s.loc.Offset); rhs != nil {
call := p.setLoc(CallTerm(op, lhs, rhs), lhs.Location, offset, p.s.lastEnd)
switch p.s.tok {
Expand Down Expand Up @@ -1616,8 +1622,7 @@ func (p *Parser) parseNumber() *Term {

// Note: Use the original string, do *not* round trip from
// the big.Float as it can cause precision loss.
r := NumberTerm(json.Number(s)).SetLocation(loc)
return r
return NumberTerm(json.Number(s)).SetLocation(loc)
}

func (p *Parser) parseString() *Term {
Expand Down Expand Up @@ -2055,10 +2060,11 @@ func (p *Parser) parseTermOp(values ...tokens.Token) *Term {
func (p *Parser) parseTermOpName(ref Ref, values ...tokens.Token) *Term {
for i := range values {
if p.s.tok == values[i] {
for _, r := range ref {
cp := ref.Copy()
for _, r := range cp {
r.SetLocation(p.s.Loc())
}
t := RefTerm(ref...)
t := RefTerm(cp...)
t.SetLocation(p.s.Loc())
p.scan()
return t
Expand Down
4 changes: 2 additions & 2 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var RootDocumentRefs = NewSet(

// SystemDocumentKey is the name of the top-level key that identifies the system
// document.
var SystemDocumentKey = String("system")
const SystemDocumentKey = String("system")

// ReservedVars is the set of names that refer to implicitly ground vars.
var ReservedVars = NewVarSet(
Expand All @@ -97,7 +97,7 @@ var Wildcard = &Term{Value: Var("_")}

// WildcardPrefix is the special character that all wildcard variables are
// prefixed with when the statement they are contained in is parsed.
var WildcardPrefix = "$"
const WildcardPrefix = "$"

// Keywords contains strings that map to language keywords.
var Keywords = KeywordsForRegoVersion(DefaultRegoVersion)
Expand Down
101 changes: 73 additions & 28 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,12 @@ func FloatNumberTerm(f float64) *Term {
func (num Number) Equal(other Value) bool {
switch other := other.(type) {
case Number:
n1, ok1 := num.Int64()
n2, ok2 := other.Int64()
if ok1 && ok2 && n1 == n2 {
return true
}

return Compare(num, other) == 0
default:
return false
Expand Down Expand Up @@ -1108,26 +1114,46 @@ func IsVarCompatibleString(s string) bool {
return varRegexp.MatchString(s)
}

var sbPool = sync.Pool{
New: func() any {
return &strings.Builder{}
},
}

func (ref Ref) String() string {
if len(ref) == 0 {
return ""
}
buf := []string{ref[0].Value.String()}
path := ref[1:]
for _, p := range path {

sb := sbPool.Get().(*strings.Builder)
sb.Reset()

defer sbPool.Put(sb)

sb.Grow(10 * len(ref))

sb.WriteString(ref[0].Value.String())

for _, p := range ref[1:] {
switch p := p.Value.(type) {
case String:
str := string(p)
if varRegexp.MatchString(str) && len(buf) > 0 && !IsKeyword(str) {
buf = append(buf, "."+str)
if varRegexp.MatchString(str) && !IsKeyword(str) {
sb.WriteByte('.')
sb.WriteString(str)
} else {
buf = append(buf, "["+p.String()+"]")
sb.WriteString(`["`)
sb.WriteString(str)
sb.WriteString(`"]`)
}
default:
buf = append(buf, "["+p.String()+"]")
sb.WriteByte('[')
sb.WriteString(p.String())
sb.WriteByte(']')
}
}
return strings.Join(buf, "")

return sb.String()
}

// OutputVars returns a VarSet containing variables that would be bound by evaluating
Expand Down Expand Up @@ -1271,16 +1297,22 @@ func (arr *Array) MarshalJSON() ([]byte, error) {
}

func (arr *Array) String() string {
var b strings.Builder
b.WriteRune('[')
sb := sbPool.Get().(*strings.Builder)
sb.Reset()
sb.Grow(len(arr.elems) * 16)

defer sbPool.Put(sb)

sb.WriteRune('[')
for i, e := range arr.elems {
if i > 0 {
b.WriteString(", ")
sb.WriteString(", ")
}
b.WriteString(e.String())
sb.WriteString(e.String())
}
b.WriteRune(']')
return b.String()
sb.WriteRune(']')

return sb.String()
}

// Len returns the number of elements in the array.
Expand Down Expand Up @@ -1460,16 +1492,23 @@ func (s *set) String() string {
if s.Len() == 0 {
return "set()"
}
var b strings.Builder
b.WriteRune('{')

sb := sbPool.Get().(*strings.Builder)
sb.Reset()
sb.Grow(s.Len() * 16)

defer sbPool.Put(sb)

sb.WriteRune('{')
for i := range s.sortedKeys() {
if i > 0 {
b.WriteString(", ")
sb.WriteString(", ")
}
b.WriteString(s.keys[i].Value.String())
sb.WriteString(s.keys[i].Value.String())
}
b.WriteRune('}')
return b.String()
sb.WriteRune('}')

return sb.String()
}

func (s *set) sortedKeys() []*Term {
Expand Down Expand Up @@ -2367,19 +2406,25 @@ func (obj object) Len() int {
}

func (obj object) String() string {
var b strings.Builder
b.WriteRune('{')
sb := sbPool.Get().(*strings.Builder)
sb.Reset()
sb.Grow(obj.Len() * 32)

defer sbPool.Put(sb)

sb.WriteRune('{')

for i, elem := range obj.sortedKeys() {
if i > 0 {
b.WriteString(", ")
sb.WriteString(", ")
}
b.WriteString(elem.key.String())
b.WriteString(": ")
b.WriteString(elem.value.String())
sb.WriteString(elem.key.String())
sb.WriteString(": ")
sb.WriteString(elem.value.String())
}
b.WriteRune('}')
return b.String()
sb.WriteRune('}')

return sb.String()
}

func (obj *object) get(k *Term) *objectElem {
Expand Down
5 changes: 2 additions & 3 deletions internal/edittree/edittree.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ func (e *EditTree) Unfold(path ast.Ref) (*EditTree, error) {

// Fall back to looking up the key in e.value.
// Extend the tree if key is present. Error otherwise.
if v, err := x.Find(ast.Ref{ast.IntNumberTerm(idx)}); err == nil {
if v, err := x.Find(ast.Ref{ast.InternedIntNumberTerm(idx)}); err == nil {
// TODO: Consider a more efficient "Replace" function that special-cases this for arrays instead?
_, err := e.Delete(ast.IntNumberTerm(idx))
if err != nil {
Expand Down Expand Up @@ -1026,8 +1026,7 @@ func (e *EditTree) Exists(path ast.Ref) bool {
}
// Fallback if child lookup failed.
// We have to ensure that the lookup term is a number here, or Find will fail.
k := ast.Ref{ast.IntNumberTerm(idx)}.Concat(path[1:])
_, err = x.Find(k)
_, err = x.Find(ast.Ref{ast.InternedIntNumberTerm(idx)}.Concat(path[1:]))
return err == nil
default:
// Catch all primitive types.
Expand Down
55 changes: 24 additions & 31 deletions topdown/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
func builtinCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
switch a := operands[0].Value.(type) {
case *ast.Array:
return iter(ast.IntNumberTerm(a.Len()))
return iter(ast.InternedIntNumberTerm(a.Len()))
case ast.Object:
return iter(ast.IntNumberTerm(a.Len()))
return iter(ast.InternedIntNumberTerm(a.Len()))
case ast.Set:
return iter(ast.IntNumberTerm(a.Len()))
return iter(ast.InternedIntNumberTerm(a.Len()))
case ast.String:
return iter(ast.IntNumberTerm(len([]rune(a))))
return iter(ast.InternedIntNumberTerm(len([]rune(a))))
}
return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "object", "set", "string")
}
Expand Down Expand Up @@ -178,26 +178,26 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err
switch val := operands[0].Value.(type) {
case ast.Set:
res := true
match := ast.BooleanTerm(true)
match := ast.InternedBooleanTerm(true)
val.Until(func(term *ast.Term) bool {
if !match.Equal(term) {
res = false
return true
}
return false
})
return iter(ast.BooleanTerm(res))
return iter(ast.InternedBooleanTerm(res))
case *ast.Array:
res := true
match := ast.BooleanTerm(true)
match := ast.InternedBooleanTerm(true)
val.Until(func(term *ast.Term) bool {
if !match.Equal(term) {
res = false
return true
}
return false
})
return iter(ast.BooleanTerm(res))
return iter(ast.InternedBooleanTerm(res))
default:
return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set")
}
Expand All @@ -206,19 +206,19 @@ func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err
func builtinAny(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
switch val := operands[0].Value.(type) {
case ast.Set:
res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true))
return iter(ast.BooleanTerm(res))
res := val.Len() > 0 && val.Contains(ast.InternedBooleanTerm(true))
return iter(ast.InternedBooleanTerm(res))
case *ast.Array:
res := false
match := ast.BooleanTerm(true)
match := ast.InternedBooleanTerm(true)
val.Until(func(term *ast.Term) bool {
if match.Equal(term) {
res = true
return true
}
return false
})
return iter(ast.BooleanTerm(res))
return iter(ast.InternedBooleanTerm(res))
default:
return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set")
}
Expand All @@ -228,27 +228,20 @@ func builtinMember(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
containee := operands[0]
switch c := operands[1].Value.(type) {
case ast.Set:
return iter(ast.BooleanTerm(c.Contains(containee)))
return iter(ast.InternedBooleanTerm(c.Contains(containee)))
case *ast.Array:
ret := false
c.Until(func(v *ast.Term) bool {
if v.Value.Compare(containee.Value) == 0 {
ret = true
for i := 0; i < c.Len(); i++ {
if c.Elem(i).Value.Compare(containee.Value) == 0 {
return iter(ast.InternedBooleanTerm(true))
}
return ret
})
return iter(ast.BooleanTerm(ret))
}
return iter(ast.InternedBooleanTerm(false))
case ast.Object:
ret := false
c.Until(func(_, v *ast.Term) bool {
if v.Value.Compare(containee.Value) == 0 {
ret = true
}
return ret
})
return iter(ast.BooleanTerm(ret))
return iter(ast.InternedBooleanTerm(c.Until(func(_, v *ast.Term) bool {
return v.Value.Compare(containee.Value) == 0
})))
}
return iter(ast.BooleanTerm(false))
return iter(ast.InternedBooleanTerm(false))
}

func builtinMemberWithKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
Expand All @@ -259,9 +252,9 @@ func builtinMemberWithKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast
if act := c.Get(key); act != nil {
ret = act.Value.Compare(val.Value) == 0
}
return iter(ast.BooleanTerm(ret))
return iter(ast.InternedBooleanTerm(ret))
}
return iter(ast.BooleanTerm(false))
return iter(ast.InternedBooleanTerm(false))
}

func init() {
Expand Down
Loading

0 comments on commit d3f5102

Please sign in to comment.