Skip to content

Commit

Permalink
Merge pull request #56 from Herb-AI/HerbCore-0.2
Browse files Browse the repository at this point in the history
Grammar -> AbstractGrammar according to HerbCore 0.2
  • Loading branch information
THinnerichs authored Feb 26, 2024
2 parents 7ccaa2e + e4664e9 commit d9f62f8
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 92 deletions.
10 changes: 5 additions & 5 deletions src/csg/csg.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
ContextSensitiveGrammar <: Grammar
ContextSensitiveGrammar <: AbstractGrammar
Represents a context-sensitive grammar.
Extends [`Grammar`](@ref) with constraints.
Extends [`AbstractGrammar`](@ref) with constraints.
Consists of:
Expand All @@ -22,7 +22,7 @@ Consists of:
Use the [`@csgrammar`](@ref) macro to create a [`ContextSensitiveGrammar`](@ref) object.
Use the [`@pcsgrammar`](@ref) macro to create a [`ContextSensitiveGrammar`](@ref) object with probabilities.
"""
mutable struct ContextSensitiveGrammar <: Grammar
mutable struct ContextSensitiveGrammar <: AbstractGrammar
rules::Vector{Any}
types::Vector{Union{Symbol, Nothing}}
isterminal::BitVector
Expand Down Expand Up @@ -190,11 +190,11 @@ function Base.display(rulenode::RuleNode, grammar::ContextSensitiveGrammar)
end

"""
merge_grammars!(merge_to::Grammar, merge_from::Grammar)
merge_grammars!(merge_to::AbstractGrammar, merge_from::AbstractGrammar)
Adds all rules and constraints from `merge_from` to `merge_to`.
"""
function merge_grammars!(merge_to::Grammar, merge_from::Grammar)
function merge_grammars!(merge_to::AbstractGrammar, merge_from::AbstractGrammar)
for i in eachindex(merge_from.rules)
expression = :($(merge_from.types[i]) = $(merge_from.rules[i]))
add_rule!(merge_to, expression)
Expand Down
78 changes: 39 additions & 39 deletions src/grammar_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,94 +45,94 @@ function get_childtypes(rule::Any, types::AbstractVector{Symbol})
return retval
end

Base.getindex(grammar::Grammar, typ::Symbol) = grammar.bytype[typ]
Base.getindex(grammar::AbstractGrammar, typ::Symbol) = grammar.bytype[typ]

"""
nonterminals(grammar::Grammar)::Vector{Symbol}
nonterminals(grammar::AbstractGrammar)::Vector{Symbol}
Returns a list of the nonterminals or types in the [`Grammar`](@ref).
Returns a list of the nonterminals or types in the [`AbstractGrammar`](@ref).
"""
nonterminals(grammar::Grammar)::Vector{Symbol} = collect(keys(grammar.bytype))
nonterminals(grammar::AbstractGrammar)::Vector{Symbol} = collect(keys(grammar.bytype))


"""
return_type(grammar::Grammar, rule_index::Int)::Symbol
return_type(grammar::AbstractGrammar, rule_index::Int)::Symbol
Returns the type of the production rule at `rule_index`.
"""
return_type(grammar::Grammar, rule_index::Int) = grammar.types[rule_index]
return_type(grammar::AbstractGrammar, rule_index::Int) = grammar.types[rule_index]


"""
child_types(grammar::Grammar, rule_index::Int)
child_types(grammar::AbstractGrammar, rule_index::Int)
Returns the types of the children (nonterminals) of the production rule at `rule_index`.
"""
child_types(grammar::Grammar, rule_index::Int) = grammar.childtypes[rule_index]
child_types(grammar::AbstractGrammar, rule_index::Int) = grammar.childtypes[rule_index]


"""
get_domain(g::Grammar, type::Symbol)::BitVector
get_domain(g::AbstractGrammar, type::Symbol)::BitVector
Returns the domain for the hole of a certain type as a `BitVector` of the same length as the number of
rules in the grammar. Bit `i` is set to `true` iff rule `i` is of type `type`.
!!! info
Since this function can be intensively used when exploring a program space defined by a grammar,
the outcomes of this function are precomputed and stored in the `domains` field in a [`Grammar`](@ref).
the outcomes of this function are precomputed and stored in the `domains` field in a [`AbstractGrammar`](@ref).
"""
get_domain(g::Grammar, type::Symbol)::BitVector = deepcopy(g.domains[type])
get_domain(g::AbstractGrammar, type::Symbol)::BitVector = deepcopy(g.domains[type])


"""
get_domain(g::Grammar, rules::Vector{Int})::BitVector
get_domain(g::AbstractGrammar, rules::Vector{Int})::BitVector
Takes a domain `rules` defined as a vector of ints and converts it to a domain defined as a `BitVector`.
"""
get_domain(g::Grammar, rules::Vector{Int})::BitVector = BitArray(r rules for r 1:length(g.rules))
get_domain(g::AbstractGrammar, rules::Vector{Int})::BitVector = BitArray(r rules for r 1:length(g.rules))


"""
isterminal(grammar::Grammar, rule_index::Int)::Bool
isterminal(grammar::AbstractGrammar, rule_index::Int)::Bool
Returns true if the production rule at `rule_index` is terminal, i.e., does not contain any nonterminal symbols.
"""
isterminal(grammar::Grammar, rule_index::Int)::Bool = grammar.isterminal[rule_index]
isterminal(grammar::AbstractGrammar, rule_index::Int)::Bool = grammar.isterminal[rule_index]


"""
iseval(grammar::Grammar)::Bool
iseval(grammar::AbstractGrammar)::Bool
Returns true if any production rules in grammar contain the special _() eval function.
!!! compat
evaluate immediately functionality is not yet supported by most of Herb.jl
"""
iseval(grammar::Grammar)::Bool = any(grammar.iseval)
iseval(grammar::AbstractGrammar)::Bool = any(grammar.iseval)


"""
iseval(grammar::Grammar, index::Int)::Bool
iseval(grammar::AbstractGrammar, index::Int)::Bool
Returns true if the production rule at rule_index contains the special _() eval function.
!!! compat
evaluate immediately functionality is not yet supported by most of Herb.jl
"""
iseval(grammar::Grammar, index::Int)::Bool = grammar.iseval[index]
iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index]


"""
log_probability(grammar::Grammar, index::Int)::Real
log_probability(grammar::AbstractGrammar, index::Int)::Real
Returns the log probability for the rule at `index` in the grammar.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function log_probability(grammar::Grammar, index::Int)::Real
function log_probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
# Assume uniform probability
Expand All @@ -142,15 +142,15 @@ function log_probability(grammar::Grammar, index::Int)::Real
end

"""
probability(grammar::Grammar, index::Int)::Real
probability(grammar::AbstractGrammar, index::Int)::Real
Return the probability for a rule in the grammar.
Use [`log_probability`](@ref) whenever possible.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function probability(grammar::Grammar, index::Int)::Real
function probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
# Assume uniform probability
Expand All @@ -160,38 +160,38 @@ function probability(grammar::Grammar, index::Int)::Real
end

"""
isprobabilistic(grammar::Grammar)::Bool
isprobabilistic(grammar::AbstractGrammar)::Bool
Function returns whether a [`Grammar`](@ref) is probabilistic.
Function returns whether a [`AbstractGrammar`](@ref) is probabilistic.
"""
isprobabilistic(grammar::Grammar)::Bool = !(grammar.log_probabilities nothing)
isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities nothing)


"""
nchildren(grammar::Grammar, rule_index::Int)::Int
nchildren(grammar::AbstractGrammar, rule_index::Int)::Int
Returns the number of children (nonterminals) of the production rule at `rule_index`.
"""
nchildren(grammar::Grammar, rule_index::Int)::Int = length(grammar.childtypes[rule_index])
nchildren(grammar::AbstractGrammar, rule_index::Int)::Int = length(grammar.childtypes[rule_index])


"""
max_arity(grammar::Grammar)::Int
max_arity(grammar::AbstractGrammar)::Int
Returns the maximum arity (number of children) over all production rules in the [`Grammar`](@ref).
Returns the maximum arity (number of children) over all production rules in the [`AbstractGrammar`](@ref).
"""
max_arity(grammar::Grammar)::Int = maximum(length(cs) for cs in grammar.childtypes)
max_arity(grammar::AbstractGrammar)::Int = maximum(length(cs) for cs in grammar.childtypes)


function Base.show(io::IO, grammar::Grammar)
function Base.show(io::IO, grammar::AbstractGrammar)
for i in eachindex(grammar.rules)
println(io, i, ": ", grammar.types[i], " = ", grammar.rules[i])
end
end


"""
add_rule!(g::Grammar, e::Expr)
add_rule!(g::AbstractGrammar, e::Expr)
Adds a rule to the grammar.
Expand All @@ -204,7 +204,7 @@ The syntax is identical to the syntax of [`@csgrammar`](@ref) and [`@cfgrammar`]
!!! warning
Calls to this function are ignored if a rule is already in the grammar.
"""
function add_rule!(g::Grammar, e::Expr)
function add_rule!(g::AbstractGrammar, e::Expr)
if e.head == :(=) && typeof(e.args[1]) == Symbol
s = e.args[1] # Name of return type
rule = e.args[2] # expression?
Expand Down Expand Up @@ -237,7 +237,7 @@ end
"""
Adds a probabilistic derivation rule.
"""
function add_rule!(g::Grammar, p::Real, e::Expr)
function add_rule!(g::AbstractGrammar, p::Real, e::Expr)
isprobabilistic(g) || throw(ArgumentError("adding a probabilistic rule to a non-probabilistic grammar"))
len₀ = length(g.rules)
add_rule!(g, e)
Expand All @@ -248,13 +248,13 @@ function add_rule!(g::Grammar, p::Real, e::Expr)
end

"""
remove_rule!(g::Grammar, idx::Int)
remove_rule!(g::AbstractGrammar, idx::Int)
Removes the rule corresponding to `idx` from the grammar.
In order to avoid shifting indices, the rule is replaced with `nothing`,
and all other data structures are updated accordingly.
"""
function remove_rule!(g::Grammar, idx::Int)
function remove_rule!(g::AbstractGrammar, idx::Int)
type = g.types[idx]
g.rules[idx] = nothing
g.iseval[idx] = false
Expand All @@ -275,7 +275,7 @@ end


"""
cleanup_removed_rules!(g::Grammar)
cleanup_removed_rules!(g::AbstractGrammar)
Removes any placeholders for previously deleted rules.
This means that indices get shifted.
Expand All @@ -285,7 +285,7 @@ This means that indices get shifted.
[`AbstractRuleNode`](@ref) trees created before the call to this function.
These trees become meaningless.
"""
function cleanup_removed_rules!(g::Grammar)
function cleanup_removed_rules!(g::AbstractGrammar)
rules_to_cleanup = findall(isequal(nothing), g.rules)
# highest indices are removed first, otherwise their index will have shifted
for v [g.rules, g.types, g.isterminal, g.iseval, g.childtypes]
Expand Down
Loading

0 comments on commit d9f62f8

Please sign in to comment.