Skip to content

Commit

Permalink
Merge pull request #98 from Herb-AI/expr2RuleNode
Browse files Browse the repository at this point in the history
Expr2 rule node
  • Loading branch information
ReubenJ authored Nov 27, 2024
2 parents 68b1edf + b133e0b commit 2224ae3
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/HerbGrammar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export
clearconstraints!,
addconstraint!,
merge_grammars!,
expr2rulenode,

@pcfgrammar,

Expand Down
170 changes: 170 additions & 0 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,176 @@ function _rulenode2expr(typ::Symbol, rulenode::AbstractRuleNode, grammar::Abstra
retval, j
end

# ---------------------------------------------
# expr2rulenode and associated functions
# ---------------------------------------------

function grammar_map_right_to_left(grammar::AbstractGrammar)
tags = Dict{Any,Any}()
for (l, r) in zip(grammar.types, grammar.rules)
tags[r] = l
end
return tags
end

function _expr2rulenode(expr::Expr, grammar::AbstractGrammar, tags::Dict{Any,Any})
if expr.head == :call

if !haskey(tags, expr)

parameters = [_expr2rulenode(expr.args[i], grammar, tags) for i in (2:length(expr.args))]
pl = map( x -> x[1], parameters)
pr = map( x -> x[2], parameters)

temp = [expr.args[1] ;pl]
newexpr = Expr(:call, temp...)
rule = findfirst(==(newexpr), grammar.rules)


oldpl = copy(pl)
oldpr = copy(pr)
pnr = length(pl)

while isnothing(rule)

updatedrule = findfirst(==(pl[pnr]), grammar.rules)

if isnothing(updatedrule)
pl[pnr] = oldpl[pnr]
pr[pnr] = oldpr[pnr]
pnr = pnr - 1
continue
end

pl[pnr] = tags[pl[pnr]]
pr[pnr] = RuleNode(updatedrule, [pr[pnr]])

temp = [expr.args[1] ;pl]
newexpr = Expr(:call, temp...)
rule = findfirst(==(newexpr), grammar.rules)

pnr = length(pl)
end
return (tags[newexpr], RuleNode(rule, pr))
else
rule = findfirst(==(expr), grammar.rules)
return (tags[expr], RuleNode(rule, []))
end
elseif expr.head == :block
(l1, r1) = _expr2rulenode( expr.args[1], grammar, tags)
(l2, r2) = _expr2rulenode( expr.args[3], grammar, tags)

temp = (l1, l2)

newexpr = Expr(:block, temp...)
rule = findfirst(==(newexpr), grammar.rules)

pl = [l1, l2]
pr = [r1, r2]

oldpl = copy(pl)
oldpr = copy(pr)
pnr = length(pl)

while isnothing(rule)

updatedrule = findfirst(==(pl[pnr]), grammar.rules)

if isnothing(updatedrule)
pl[pnr] = oldpl[pnr]
pr[pnr] = oldpr[pnr]
pnr = pnr - 1
continue
end

pl[pnr] = tags[pl[pnr]]
pr[pnr] = RuleNode(updatedrule, [pr[pnr]])

temp = (pl[1], pl[2])
newexpr = Expr(:block, temp...)
rule = findfirst(==(newexpr), grammar.rules)

pnr = length(pl)
end
return (tags[newexpr], RuleNode(rule, pr))

elseif expr.head == :quote
return _expr2rulenode(expr.args[1], grammar, tags)
else
error("Only call and block expressions are supported")
end
end

function _expr2rulenode(expr::Any, grammar::AbstractGrammar, tags::Dict{Any,Any})
rule = findfirst(==(expr), grammar.rules)
return (tags[expr], RuleNode(rule, []))
end

"""
expr2rulenode(expr::Expr, grammar::AbstractGrammar, startSymbol::Symbol)
Converts an expression into a [`AbstractRuleNode`](@ref) corresponding to the rule definitions in the grammar.
"""
function expr2rulenode(expr::Expr, grammar::AbstractGrammar, startSymbol::Symbol)
tags = grammar_map_right_to_left(grammar)
(s, rn) = _expr2rulenode(expr, grammar, tags)
while s != startSymbol

updatedrule = findfirst(==(s), grammar.rules)

if isnothing(updatedrule)
error("INVALID STARTING SYMBOL")
end

s = tags[s]
rn = RuleNode(updatedrule, [rn])
end
return rn
end

"""
expr2rulenode(expr::Expr, grammar::AbstractGrammar)
Converts an expression into a [`AbstractRuleNode`](@ref) corresponding to the rule definitions in the grammar.
"""
function expr2rulenode(expr::Expr, grammar::AbstractGrammar)
tags = grammar_map_right_to_left(grammar)
(s, rn) = _expr2rulenode(expr, grammar, tags)
return rn
end

"""
expr2rulenode(expr::Symbol, grammar::AbstractGrammar, startSymbol::Symbol)
Converts an expression into a [`AbstractRuleNode`](@ref) corresponding to the rule definitions in the grammar.
"""
function expr2rulenode(expr::Symbol, grammar::AbstractGrammar, startSymbol::Symbol)
tags = get_tags(grammar)
(s, rn) = expr2rulenode(expr, grammar, tags)
while s != startSymbol

updatedrule = findfirst(==(s), grammar.rules)

if isnothing(updatedrule)
error("INVALID STARTING SYMBOL")
end

s = tags[s]
rn = RuleNode(updatedrule, [rn])
end
return rn
end

"""
expr2rulenode(expr::Symbol, grammar::AbstractGrammar)
Converts an expression into a [`AbstractRuleNode`](@ref) corresponding to the rule definitions in the grammar.
"""
function expr2rulenode(expr::Symbol, grammar::AbstractGrammar)
tags = get_tags(grammar)
(s, rn) = expr2rulenode(expr, grammar, tags)
return rn
end

"""
Calculates the log probability associated with a rulenode in a probabilistic grammar.
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ using HerbCore
using HerbGrammar
using Test


@testset "HerbGrammar.jl" verbose=true begin
include("test_csg.jl")
include("test_rulenode_operators.jl")
include("test_rulenode2expr.jl")
include("test_expr2rulenode.jl")
end
47 changes: 47 additions & 0 deletions test/test_expr2rulenode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
@testset verbose=true "expr2rulenode" begin

g1 = @cfgrammar begin
Number = |(1:2)
Number = x
Number = Number + Number
Number = Number * Number
Number = DiffNumber
DiffNumber = |(3:4)
end

expr1 = :(1 + 2)
expr2 = :((x * (1 + 3)) + (4 * x))
@test expr2rulenode(expr1, g1) == RuleNode(4, [RuleNode(1, []), RuleNode(2, [])])
@test expr2rulenode(expr2, g1) == RuleNode(4, [ RuleNode(5, [RuleNode(3, []), RuleNode(4, [RuleNode(1, []), RuleNode(6,[RuleNode(7, [])])])]), RuleNode(5,[RuleNode(6, [RuleNode(8, [])]), RuleNode(3, [])])])


g2 = @csgrammar begin
Start = Sequence #1

Sequence = Operation #2
Sequence = (Operation; Sequence) #3
Operation = Transformation #4
Operation = ControlStatement #5

Transformation = moveRight() | moveDown() | moveLeft() | moveUp() | drop() | grab() #6
ControlStatement = IF(Condition, Sequence, Sequence) #12
ControlStatement = WHILE(Condition, Sequence) #13

Condition = atTop() | atBottom() | atLeft() | atRight() | notAtTop() | notAtBottom() | notAtLeft() | notAtRight() #14
end

expr3 = :(moveUp())
expr4 = :(moveUp(); (moveRight()))
expr5 = :(IF(atTop(), ((moveUp(); (moveRight()))), moveRight()))

@test expr2rulenode(expr3, g2) == RuleNode(9, [])
@test expr2rulenode(expr3, g2, :Start) == RuleNode(1, [RuleNode(2, [RuleNode(4, [RuleNode(9, [])])])])

@test expr2rulenode(expr4, g2) == RuleNode(3, [RuleNode(4, [RuleNode(9, [])]) , RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])])
@test expr2rulenode(expr4, g2, :Start) == RuleNode(1, [RuleNode(3, [RuleNode(4, [RuleNode(9, [])]) , RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])])])

@test expr2rulenode(expr5, g2) == RuleNode(12, [RuleNode(14, []), RuleNode(3, [RuleNode(4, [RuleNode(9, [])]) , RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])]), RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])])
@test expr2rulenode(expr5, g2, :Start) == RuleNode(1, [RuleNode(2, [RuleNode(5, [RuleNode(12, [RuleNode(14, []), RuleNode(3, [RuleNode(4, [RuleNode(9, [])]) , RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])]), RuleNode(2, [RuleNode(4, [RuleNode(6, [])])])])])])])

end

0 comments on commit 2224ae3

Please sign in to comment.