Skip to content

Commit

Permalink
Add normalizing fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaefilat committed Jul 17, 2024
1 parent 21cb2e9 commit 074672d
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/csg/probabilistic_csg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar
end
alltypes = collect(keys(bytype))
# Normalize probabilities for each type
for t alltypes
total_prob = sum(probabilities[i] for i bytype[t])
if !(total_prob 1)
@warn "The probabilities for type $t don't add up to 1, so they will be normalized."
for i bytype[t]
probabilities[i] /= total_prob
end
total_prob = sum(probabilities)
if !(total_prob 1)
for i eachindex(probabilities)
probabilities[i] /= total_prob
end
end

Expand All @@ -54,7 +51,7 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar
domains = Dict(type => BitArray(r bytype[type] for r 1:length(rules)) for type alltypes)
bychildtypes = [BitVector([childtypes[i1] == childtypes[i2] for i2 1:length(rules)]) for i1 1:length(rules)]

normalize!(ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, bychildtypes, log_probabilities))
ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, bychildtypes, log_probabilities)
end

"""
Expand Down

0 comments on commit 074672d

Please sign in to comment.