Skip to content

Commit

Permalink
Merge pull request #40 from Herb-AI/dev
Browse files Browse the repository at this point in the history
Update add_rule and rulenode2expr
  • Loading branch information
THinnerichs authored Jan 16, 2024
2 parents 27f0618 + b75ad77 commit 0a6a3c6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
39 changes: 32 additions & 7 deletions src/cfg/cfg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,39 @@ end
parse_rule!(v::Vector{Any}, r) = push!(v, r)

function parse_rule!(v::Vector{Any}, ex::Expr)
if ex.head == :call && ex.args[1] == :|
terms = length(ex.args) == 2 ?
collect(eval(ex.args[2])) : #|(a:c) case
ex.args[2:end] #a|b|c case
for t in terms
parse_rule!(v, t)
# Strips `LineNumberNode`s from the expression
Base.remove_linenums!(ex)

if ex.head == :call && ex.args[1] == :|
terms = _expand_shorthand(ex.args)

for t in terms
parse_rule!(v, t)
end
else
push!(v, ex)
end
end

function _expand_shorthand(args::Vector{Any})
# expand a rule using the `|` symbol:
# `X = |(1:3)`, `X = 1|2|3`, `X = |([1,2,3])`
# these should all be equivalent and should expand to
# the following 3 rules: `X = 1`, `X = 2`, and `X = 3`
if args[1] != :|
throw(ArgumentError("Tried to parse: $ex as a shorthand rule, but it is not a shorthand rule."))
end

if length(args) == 2
to_expand = args[2]
if to_expand.args[1] == :(:)
expanded = collect(to_expand.args[2]:to_expand.args[3]) # (1:3) case
else
expanded = to_expand.args # ([1,2,3]) case
end
elseif length(args) == 3
expanded = args[2:end] # 1|2|3 case
else
push!(v, ex)
throw(ArgumentError("Failed to parse shorthand for rule: $ex"))
end
end
2 changes: 1 addition & 1 deletion src/nodelocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ function Base.insert!(root::RuleNode, loc::NodeLoc, rulenode::RuleNode)
root.children = rulenode.children
end
return root
end
end
11 changes: 11 additions & 0 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,23 @@ function rulenode2expr(rulenode::RuleNode, grammar::Grammar)
end


function _rulenode2expr(rulenode::Hole, grammar::Grammar)
# Find the index of the first element that is true
index = findfirst(==(true), rulenode.domain)
return isnothing(index) ? :Nothing : grammar.types[index]
end
rulenode2expr(rulenode::Hole, grammar::Grammar) = _rulenode2expr(rulenode::Hole, grammar::Grammar)

function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::Grammar, j=0)
for (k,arg) in enumerate(expr.args)
if isa(arg, Expr)
expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j)
elseif haskey(grammar.bytype, arg)
child = rulenode.children[j+=1]
if isa(child, Hole)
expr.args[k] = _rulenode2expr(child, grammar)
continue
end
expr.args[k] = (child._val !== nothing) ?
child._val : deepcopy(grammar.rules[child.ind])
if !isterminal(grammar, child)
Expand Down

0 comments on commit 0a6a3c6

Please sign in to comment.