Skip to content

Commit

Permalink
ast/visit: add SomeDecl to visitor walks (open-policy-agent#5515)
Browse files Browse the repository at this point in the history
These had been overlooked, it seems. Adjusted the test element counts
for the new nodes getting recorded.

It seems like there's no test coverage for the deprecated ast.Walk() and
ast.WalkBeforeAndAfter() methods.

Fixes open-policy-agent#5480.

Signed-off-by: Stephan Renatus <[email protected]>
  • Loading branch information
srenatus authored Jan 2, 2023
1 parent ed455fb commit 2b8cbda
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 269 deletions.
16 changes: 16 additions & 0 deletions ast/visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func walk(v Visitor, x interface{}) {
Walk(w, x.Value)
Walk(w, x.Domain)
Walk(w, x.Body)
case *SomeDecl:
for i := range x.Symbols {
Walk(w, x.Symbols[i])
}
}
}

Expand Down Expand Up @@ -383,6 +387,10 @@ func (vis *GenericVisitor) Walk(x interface{}) {
vis.Walk(x.Value)
vis.Walk(x.Domain)
vis.Walk(x.Body)
case *SomeDecl:
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
}
}

Expand Down Expand Up @@ -519,6 +527,10 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) {
vis.Walk(x.Value)
vis.Walk(x.Domain)
vis.Walk(x.Body)
case *SomeDecl:
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
}
}

Expand Down Expand Up @@ -759,5 +771,9 @@ func (vis *VarVisitor) Walk(x interface{}) {
vis.Walk(x.Value)
vis.Walk(x.Domain)
vis.Walk(x.Body)
case *SomeDecl:
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
}
}
290 changes: 21 additions & 269 deletions ast/visit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ t[x] = y {
y = [[x, z] | x = "x"; z = "z"]
z = {"foo": [x, z] | x = "x"; z = "z"}
s = {1 | a[i] = "foo"}
some x0, y0, z0
count({1, 2, 3}, n) with input.foo.bar as x
}
Expand All @@ -37,264 +38,10 @@ p { false } else { false } else { true }
fn([x, y]) = z { json.unmarshal(x, z); z > y }
`)
vis := &testVis{}

NewGenericVisitor(vis.Visit).Walk(rule)

/*
mod
package
data.a.b
term
data
term
a
term
b
import
term
input.x.y
term
input
term
x
term
y
z
rule
head
t
args
term
x
term
y
body
expr1
term
ref
term
=
term
ref1
term
p
term
x
term
object1
term
"foo"
term
array
term
y
term
2
term
object2
term
"bar"
term
3
expr2
term
ref2
term
q
term
x
expr3
term
ref
term
=
term
y
term
compr
term
array
term
x
term
z
body
expr4
term
ref
term
=
term
x
term
"x"
expr5
term
ref
term
=
term
z
term
"z"
expr4
term
ref
term
=
term
z
term
compr
key
term
"foo"
value
array
term
x
term
z
body
expr1
term
ref
term
=
term
x
term
"x"
expr2
term
ref
term
=
term
z
term
"z"
expr5
term
ref
term
=
term
s
term
compr
term
1
body
expr1
term
ref
term
=
term
ref
term
a
term
i
term
"foo"
expr6
term
ref
term
count
term
set
term
1
term
2
term
3
term
n
with
term
input.foo.bar
term
input
term
foo
term
bar
term
baz
rule
head
p
args
<nil> # not counted
term
true
body
expr
term
false
rule
head
p
args
<nil> # not counted
term
true
body
expr
term
false
rule
head
p
args
<nil> # not counted
term
true
body
expr
term
true
func
head
fn
args
term
array
term
x
term
y
term
z
body
expr1
term
ref
term
json
term
unmarshal
term
x
term
z
expr2
term
ref
term
>
term
z
term
y
*/
if len(vis.elems) != 246 {
t.Errorf("Expected exactly 246 elements in AST but got %d: %v", len(vis.elems), vis.elems)
if exp, act := 254, len(vis.elems); exp != act {
t.Errorf("Expected exactly %d elements in AST but got %d: %v", exp, act, vis.elems)
}
}

Expand Down Expand Up @@ -344,6 +91,7 @@ t[x] = y {
y = [[x, z] | x = "x"; z = "z"]
z = {"foo": [x, z] | x = "x"; z = "z"}
s = {1 | a[i] = "foo"}
some x0, y0, z0
count({1, 2, 3}, n) with input.foo.bar as x
}
Expand All @@ -359,7 +107,7 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y }
})
vis.Walk(rule)

if len(elems) != 246 {
if len(elems) != 254 {
t.Errorf("Expected exactly 246 elements in AST but got %d: %v", len(elems), elems)
}
}
Expand All @@ -375,6 +123,7 @@ t[x] = y {
y = [[x, z] | x = "x"; z = "z"]
z = {"foo": [x, z] | x = "x"; z = "z"}
s = {1 | a[i] = "foo"}
some x0, y0, z0
count({1, 2, 3}, n) with input.foo.bar as x
}
Expand All @@ -393,11 +142,11 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y }
})
vis.Walk(rule)

if exp, act := 256, len(before); exp != act {
if exp, act := 264, len(before); exp != act {
t.Errorf("Expected exactly %d before elements in AST but got %d: %v", exp, act, before)
}

if exp, act := 256, len(before); exp != act {
if exp, act := 264, len(before); exp != act {
t.Errorf("Expected exactly %d after elements in AST but got %d: %v", exp, act, after)
}
}
Expand All @@ -414,22 +163,25 @@ func TestVarVisitor(t *testing.T) {
{"data.foo[x] = bar.baz[y]", VarVisitorParams{SkipRefHead: true}, "[x, y]"},
{`foo = [x | data.a[i] = x]`, VarVisitorParams{SkipClosures: true}, "[foo, eq]"},
{`x = 1; y = 2; z = x + y; count([x, y, z], z)`, VarVisitorParams{}, "[x, y, z, eq, plus, count]"},
{"some x, y", VarVisitorParams{}, "[x, y]"},
}

for _, tc := range tests {
stmt := MustParseStatement(tc.stmt)
t.Run(tc.stmt, func(t *testing.T) {
stmt := MustParseStatement(tc.stmt)

expected := NewVarSet()
MustParseTerm(tc.expected).Value.(*Array).Foreach(func(x *Term) {
expected.Add(x.Value.(Var))
})
expected := NewVarSet()
MustParseTerm(tc.expected).Value.(*Array).Foreach(func(x *Term) {
expected.Add(x.Value.(Var))
})

vis := NewVarVisitor().WithParams(tc.params)
vis.Walk(stmt)
vis := NewVarVisitor().WithParams(tc.params)
vis.Walk(stmt)

if !vis.Vars().Equal(expected) {
t.Errorf("For %v w/ %v expected %v but got: %v", stmt, tc.params, expected, vis.Vars())
}
if !vis.Vars().Equal(expected) {
t.Errorf("Params %#v expected %v but got: %v", tc.params, expected, vis.Vars())
}
})
}
}

Expand Down

0 comments on commit 2b8cbda

Please sign in to comment.