diff --git a/src/csg/probabilistic_csg.jl b/src/csg/probabilistic_csg.jl index 6c0d636..8612250 100644 --- a/src/csg/probabilistic_csg.jl +++ b/src/csg/probabilistic_csg.jl @@ -37,13 +37,10 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar end alltypes = collect(keys(bytype)) # Normalize probabilities for each type - for t ∈ alltypes - total_prob = sum(probabilities[i] for i ∈ bytype[t]) - if !(total_prob ≈ 1) - @warn "The probabilities for type $t don't add up to 1, so they will be normalized." - for i ∈ bytype[t] - probabilities[i] /= total_prob - end + total_prob = sum(probabilities) + if !(total_prob ≈ 1) + for i ∈ eachindex(probabilities) + probabilities[i] /= total_prob end end @@ -54,7 +51,7 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar domains = Dict(type => BitArray(r ∈ bytype[type] for r ∈ 1:length(rules)) for type ∈ alltypes) bychildtypes = [BitVector([childtypes[i1] == childtypes[i2] for i2 ∈ 1:length(rules)]) for i1 ∈ 1:length(rules)] - normalize!(ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, bychildtypes, log_probabilities)) + ContextSensitiveGrammar(rules, types, is_terminal, is_eval, bytype, domains, childtypes, bychildtypes, log_probabilities) end """