diff --git a/src/Gen.jl b/src/Gen.jl index 29113ff11..2ba70f6a5 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -2,6 +2,8 @@ module Gen +using Random: AbstractRNG, default_rng + """ load_generated_functions(__module__=Main) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index d83055444..a76c8dbec 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -45,12 +45,15 @@ end accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad -mutable struct GFUntracedState +mutable struct GFUntracedState{R<:AbstractRNG} params::Dict{Symbol,Any} + rng::R end -function (gen_fn::DynamicDSLFunction)(args...) - state = GFUntracedState(gen_fn.params) +(gen_fn::DynamicDSLFunction)(args...) = gen_fn(default_rng(), args...) + +function (gen_fn::DynamicDSLFunction)(rng::AbstractRNG, args...) + state = GFUntracedState(gen_fn.params, rng) gen_fn.julia_function(state, args...) end @@ -82,13 +85,13 @@ end # Defaults for untraced execution @inline traceat(state::GFUntracedState, gen_fn::GenerativeFunction, args, key) = - gen_fn(args...) + gen_fn(state.rng, args...) @inline traceat(state::GFUntracedState, dist::Distribution, args, key) = - random(dist, args...) + random(state.rng, dist, args...) @inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) = - gen_fn(args...) + gen_fn(state.rng, args...) ######################## # trainable parameters # diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index a89e0c352..55a81c193 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -1,14 +1,15 @@ -mutable struct GFGenerateState +mutable struct GFGenerateState{R<:AbstractRNG} trace::DynamicDSLTrace constraints::ChoiceMap weight::Float64 visitor::AddressVisitor params::Dict{Symbol,Any} + rng::R end -function GFGenerateState(gen_fn, args, constraints, params) +function GFGenerateState(gen_fn, args, constraints, params, rng::AbstractRNG) trace = DynamicDSLTrace(gen_fn, args) - GFGenerateState(trace, constraints, 0., AddressVisitor(), params) + GFGenerateState(trace, constraints, 0., AddressVisitor(), params, rng) end function traceat(state::GFGenerateState, dist::Distribution{T}, @@ -26,7 +27,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, if constrained retval = get_value(state.constraints, key) else - retval = random(dist, args...) + retval = random(state.rng, dist, args...) 
end # compute logpdf @@ -55,7 +56,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, constraints = get_submap(state.constraints, key) # get subtrace - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate(state.rng, gen_fn, args, constraints) # add to the trace add_call!(state.trace, key, subtrace) @@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction, retval end -function generate(gen_fn::DynamicDSLFunction, args::Tuple, - constraints::ChoiceMap) - state = GFGenerateState(gen_fn, args, constraints, gen_fn.params) +generate(gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap) = + generate(default_rng(), gen_fn, args, constraints) + +function generate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple, + constraints::ChoiceMap) + state = GFGenerateState(gen_fn, args, constraints, gen_fn.params, rng) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) (state.trace, state.weight) diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index 7c630cc19..7eda0174e 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -1,12 +1,13 @@ -mutable struct GFProposeState +mutable struct GFProposeState{R<:AbstractRNG} choices::DynamicChoiceMap weight::Float64 visitor::AddressVisitor params::Dict{Symbol,Any} + rng::R end -function GFProposeState(params::Dict{Symbol,Any}) - GFProposeState(choicemap(), 0., AddressVisitor(), params) +function GFProposeState(params::Dict{Symbol,Any}, rng::AbstractRNG) + GFProposeState(choicemap(), 0., AddressVisitor(), params, rng) end function traceat(state::GFProposeState, dist::Distribution{T}, @@ -17,7 +18,7 @@ function traceat(state::GFProposeState, dist::Distribution{T}, visit!(state.visitor, key) # sample return value - retval = random(dist, args...) + retval = random(state.rng, dist, args...) 
# update assignment set_value!(state.choices, key, retval) @@ -36,7 +37,7 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # get subtrace - (submap, weight, retval) = propose(gen_fn, args) + (submap, weight, retval) = propose(state.rng, gen_fn, args) # update assignment set_submap!(state.choices, key, submap) @@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple) retval end -function propose(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFProposeState(gen_fn.params) +function propose(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple) + state = GFProposeState(gen_fn.params, rng) retval = exec(gen_fn, state, args) (state.choices, state.weight, retval) end diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 13d14d86f..0264b7ec8 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -1,17 +1,18 @@ -mutable struct GFRegenerateState +mutable struct GFRegenerateState{R<:AbstractRNG} prev_trace::DynamicDSLTrace trace::DynamicDSLTrace selection::Selection weight::Float64 visitor::AddressVisitor params::Dict{Symbol,Any} + rng::R end function GFRegenerateState(gen_fn, args, prev_trace, - selection, params) + selection, params, rng::AbstractRNG) visitor = AddressVisitor() GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection, - 0., visitor, params) + 0., visitor, params, rng) end function traceat(state::GFRegenerateState, dist::Distribution{T}, @@ -35,11 +36,11 @@ function traceat(state::GFRegenerateState, dist::Distribution{T}, # get return value if has_previous && in_selection - retval = random(dist, args...) + retval = random(state.rng, dist, args...) elseif has_previous retval = prev_retval else - retval = random(dist, args...) + retval = random(state.rng, dist, args...) 
end # compute logpdf @@ -75,9 +76,9 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) (subtrace, weight, _) = regenerate( - prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) + state.rng, prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) else - (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap()) + (subtrace, weight) = generate(state.rng, gen_fn, args, EmptyChoiceMap()) end # update weight @@ -130,10 +131,10 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, noise end -function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) +function regenerate(rng::AbstractRNG, trace::DynamicDSLTrace, args::Tuple, + argdiffs::Tuple, selection::Selection) gen_fn = trace.gen_fn - state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params) + state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params, rng) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) visited = state.visitor.visited diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 7db1a213a..dac2627f2 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -1,12 +1,13 @@ -mutable struct GFSimulateState +mutable struct GFSimulateState{R<:AbstractRNG} trace::DynamicDSLTrace visitor::AddressVisitor params::Dict{Symbol,Any} + rng::R end -function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) +function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params, rng::AbstractRNG) trace = DynamicDSLTrace(gen_fn, args) - GFSimulateState(trace, AddressVisitor(), params) + GFSimulateState(trace, AddressVisitor(), params, rng) end function traceat(state::GFSimulateState, dist::Distribution{T}, @@ -16,7 +17,7 @@ function traceat(state::GFSimulateState, dist::Distribution{T}, # check that key was not already visited, and mark it as visited visit!(state.visitor, key) - retval = random(dist, args...) + retval = random(state.rng, dist, args...) # compute logpdf score = logpdf(dist, retval, args...) 
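To make the threaded RNG concrete, here is a usage sketch (not part of the diff; `coin_model` is a hypothetical model, and the sketch assumes the `simulate(rng, ...)` entry points added in this changeset):

```julia
using Gen, Random

# hypothetical model, for illustration only
@gen function coin_model(n::Int)
    p = @trace(beta(1, 1), :p)
    for i in 1:n
        @trace(bernoulli(p), :flip => i)
    end
end

# all sampling flows through the supplied RNG, so equal seeds should
# produce identical traces
t1 = simulate(MersenneTwister(42), coin_model, (5,))
t2 = simulate(MersenneTwister(42), coin_model, (5,))
@assert get_choices(t1)[:p] == get_choices(t2)[:p]
```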
@@ -36,7 +37,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # get subtrace - subtrace = simulate(gen_fn, args) + subtrace = simulate(state.rng, gen_fn, args) # add to the trace add_call!(state.trace, key, subtrace) @@ -56,8 +57,8 @@ function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction, retval end -function simulate(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFSimulateState(gen_fn, args, gen_fn.params) +function simulate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple) + state = GFSimulateState(gen_fn, args, gen_fn.params, rng) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) state.trace diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 3e3605f59..15af33137 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -1,4 +1,4 @@ -mutable struct GFUpdateState +mutable struct GFUpdateState{R<:AbstractRNG} prev_trace::DynamicDSLTrace trace::DynamicDSLTrace constraints::Any @@ -6,14 +6,15 @@ mutable struct GFUpdateState visitor::AddressVisitor params::Dict{Symbol,Any} discard::DynamicChoiceMap + rng::R end -function GFUpdateState(gen_fn, args, prev_trace, constraints, params) +function GFUpdateState(gen_fn, args, prev_trace, constraints, params, rng::AbstractRNG) visitor = AddressVisitor() discard = choicemap() trace = DynamicDSLTrace(gen_fn, args) GFUpdateState(prev_trace, trace, constraints, - 0., visitor, params, discard) + 0., visitor, params, discard, rng) end function traceat(state::GFUpdateState, dist::Distribution{T}, @@ -48,7 +49,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, elseif has_previous retval = prev_retval else - retval = random(dist, args...) + retval = random(state.rng, dist, args...) end # compute logpdf @@ -87,10 +88,10 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, prev_call = get_call(state.prev_trace, key) prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) == gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _, discard) = update(prev_subtrace, + (subtrace, weight, _, discard) = update(state.rng, prev_subtrace, args, map((_) -> UnknownChange(), args), constraints) else - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate(state.rng, gen_fn, args, constraints) end # update the weight @@ -184,10 +185,10 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap, end end -function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, +function update(rng::AbstractRNG, trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, constraints::ChoiceMap) gen_fn = trace.gen_fn - state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params) + state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params, rng) retval = exec(gen_fn, state, arg_values) set_retval!(state.trace, retval) visited = get_visited(state.visitor) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index b9ae77632..332746cc3 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -135,7 +135,7 @@ Return an iterable over the trainable parameters of the generative function. get_params(::GenerativeFunction) = () """ - trace = simulate(gen_fn, args) + trace = simulate([rng::AbstractRNG], gen_fn, args) Execute the generative function and return the trace. 
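As a sketch of the backward-compatible convention described in this docstring (reusing the hypothetical `coin_model` above), the two-argument form is expected to forward to the global RNG:

```julia
# these two calls should be equivalent under this changeset
trace = simulate(coin_model, (5,))
trace = simulate(Random.default_rng(), coin_model, (5,))
```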
@@ -143,18 +143,26 @@ Given arguments (`args`), sample \$(r, t) \\sim p(\\cdot; x)\$ and return a trac
 If `gen_fn` has optional trailing arguments (i.e., default values are provided),
 the optional arguments can be omitted from the `args` tuple. The generated trace
-    will have default values filled in.
+will have default values filled in.
+
+The RNG state can optionally be supplied as the first argument; if it is omitted,
+`Random.default_rng()` is used.
 """
-function simulate(::GenerativeFunction, ::Tuple)
-    error("Not implemented")
+function simulate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
+    # TODO: For backwards compatibility only. Remove in next breaking version.
+    @warn "Missing concrete implementation of `simulate(::AbstractRNG, ::$(typeof(gen_fn)), ::Tuple)`, " *
+          "falling back to `simulate(::$(typeof(gen_fn)), ::Tuple)`."
+    return simulate(gen_fn, args)
 end
+simulate(gen_fn::GenerativeFunction, args::Tuple) = simulate(default_rng(), gen_fn, args)
+
 """
-    (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)
+    (trace::U, weight) = generate([rng::AbstractRNG], gen_fn::GenerativeFunction{T,U}, args::Tuple)
 
 Return a trace of a generative function.
 
-    (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
+    (trace::U, weight) = generate([rng::AbstractRNG], gen_fn::GenerativeFunction{T,U}, args::Tuple,
                                   constraints::ChoiceMap)
 
 Return a trace of a generative function that is consistent with the given
@@ -167,6 +175,9 @@ Also return the weight (`weight`):
 \\log \\frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}
 ```
 
+The RNG state can optionally be supplied as the first argument; if it is omitted,
+`Random.default_rng()` is used.
+
 If `gen_fn` has optional trailing arguments (i.e., default values are provided),
 the optional arguments can be omitted from the `args` tuple. The generated trace
 will have default values filled in.
@@ -181,14 +192,21 @@ Example with constraint that address `:z` takes value `true`.
-    (trace, weight) = generate(foo, (2, 4), choicemap((:z, true))
+    (trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))
 ```
 """
-function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap)
-    error("Not implemented")
+function generate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple, constraints::ChoiceMap)
+    # TODO: For backwards compatibility only. Remove in next breaking version.
+    @warn "Missing concrete implementation of `generate(::AbstractRNG, ::$(typeof(gen_fn)), ::Tuple, ::ChoiceMap)`, " *
+          "falling back to `generate(::$(typeof(gen_fn)), ::Tuple, ::ChoiceMap)`."
+    return generate(gen_fn, args, constraints)
 end
-function generate(gen_fn::GenerativeFunction, args::Tuple)
-    generate(gen_fn, args, EmptyChoiceMap())
+generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) = generate(default_rng(), gen_fn, args, choices)
+
+function generate(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
+    generate(rng, gen_fn, args, EmptyChoiceMap())
 end
+generate(gen_fn::GenerativeFunction, args::Tuple) = generate(default_rng(), gen_fn, args)
+
 """
     weight = project(trace::U, selection::Selection)
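A sketch of constrained generation with an explicit RNG, matching the docstring above (hypothetical `coin_model` again; the weight accounts for the constrained address, while unconstrained flips are drawn from the supplied RNG):

```julia
constraints = choicemap((:p, 0.5))
(trace, weight) = generate(MersenneTwister(7), coin_model, (5,), constraints)
```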
@@ -207,7 +225,7 @@ function project(trace, selection::Selection)
 end
 
 """
-    (choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)
+    (choices, weight, retval) = propose([rng::AbstractRNG], gen_fn::GenerativeFunction, args::Tuple)
 
 Sample an assignment and compute the probability of proposing that assignment.
@@ -217,13 +235,18 @@ t)\$, and return \$t\$
 ```math
 \\log \\frac{p(r, t; x)}{q(r; x, t)}
 ```
+
+The RNG state can optionally be supplied as the first argument; if it is omitted,
+`Random.default_rng()` is used.
 """
-function propose(gen_fn::GenerativeFunction, args::Tuple)
-    trace = simulate(gen_fn, args)
+function propose(rng::AbstractRNG, gen_fn::GenerativeFunction, args::Tuple)
+    trace = simulate(rng, gen_fn, args)
     weight = get_score(trace)
     (get_choices(trace), weight, get_retval(trace))
 end
+propose(gen_fn::GenerativeFunction, args::Tuple) = propose(default_rng(), gen_fn, args)
+
 """
     (weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
@@ -243,8 +266,8 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
 end
 
 """
-    (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
-                                                   constraints::ChoiceMap)
+    (new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, args::Tuple,
+                                                   argdiffs::Tuple, constraints::ChoiceMap)
 
 Update a trace by changing the arguments and/or providing new values for some
 existing random choice(s) and values for some newly introduced random choice(s).
@@ -271,26 +294,37 @@ then these arguments can be omitted from `args` and `argdiffs`. Note that if the
 original `trace` was generated using non-default argument values, then for each
 optional argument that is omitted, the old value will be over-written by the
 default argument value in the updated trace.
+
+The RNG state can optionally be supplied as the first argument; if it is omitted,
+`Random.default_rng()` is used.
 """
-function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
-    error("Not implemented")
+function update(rng::AbstractRNG, trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
+    # TODO: For backwards compatibility only. Remove in next breaking version.
+    @warn "Missing concrete implementation of `update(::AbstractRNG, ::$(typeof(trace)), ::Tuple, ::Tuple, ::ChoiceMap)`, " *
+          "falling back to `update(::$(typeof(trace)), ::Tuple, ::Tuple, ::ChoiceMap)`."
+    return update(trace, args, argdiffs, constraints)
 end
+update(trace, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) =
+    update(default_rng(), trace, args, argdiffs, choices)
+
 """
-    (new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)
+    (new_trace, weight, retdiff, discard) = update([rng::AbstractRNG], trace, constraints::ChoiceMap)
 
 Shorthand variant of [`update`](@ref update(::Any, ::Tuple, ::Tuple, ::ChoiceMap))
 which assumes the arguments are unchanged.
 """
-function update(trace, constraints::ChoiceMap)
+function update(rng::AbstractRNG, trace, constraints::ChoiceMap)
     args = get_args(trace)
     argdiffs = Tuple(NoChange() for _ in args)
-    return update(trace, args, argdiffs, constraints)
+    return update(rng, trace, args, argdiffs, constraints)
 end
+update(trace, constraints::ChoiceMap) = update(default_rng(), trace, constraints)
+
 """
-    (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
+    (new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, args::Tuple, argdiffs::Tuple,
                                               selection::Selection)
 
 Update a trace by changing the arguments and/or randomly sampling new values
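Continuing the sketch, the shorthand `update` with an explicit RNG keeps the arguments fixed and draws any newly introduced choices from the supplied RNG (assuming `trace` from the `generate` example above):

```julia
(new_trace, w, retdiff, discard) = update(MersenneTwister(7), trace, choicemap((:p, 0.9)))
```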
""" -function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) - error("Not implemented") +function regenerate(::AbstractRNG, trace, ::Tuple, ::Tuple, ::Selection) + # TODO: For backwards compatibility only. Remove in next breaking version. + @warn "Missing concrete implementation of `regenerate(::AbstractRNG, ::$(typeof(gen_fn)), ::Tuple), `" * + "falling back to `regenerate(::$(typeof(gen_fn)), ::Tuple)`." + return regenerate(gen_fn, args) end +regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) = + regenerate(default_rng(), trace, args, argdiffs, selection) + """ - (new_trace, weight, retdiff) = regenerate(trace, selection::Selection) + (new_trace, weight, retdiff) = regenerate([rng::AbstractRNG], trace, selection::Selection) Shorthand variant of [`regenerate`](@ref regenerate(::Any, ::Tuple, ::Tuple, ::Selection)) which assumes the arguments are unchanged. """ -function regenerate(trace, selection::Selection) +function regenerate(rng::AbstractRNG, trace, selection::Selection) args = get_args(trace) argdiffs = Tuple(NoChange() for _ in args) - return regenerate(trace, args, argdiffs, selection) + return regenerate(rng, trace, args, argdiffs, selection) end +regenerate(trace, selection::Selection) = regenerate(default_rng(), trace, selection) + """ arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 95a6fdeeb..10c904bb0 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -1,5 +1,5 @@ -function sample_momenta(n::Int) - Float64[random(normal, 0, 1) for _=1:n] +function sample_momenta(rng::AbstractRNG, n::Int) + Float64[random(rng, normal, 0, 1) for _=1:n] end function assess_momenta(momenta) @@ -13,7 +13,8 @@ end """ (new_trace, accepted) = hmc( trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not. @@ -23,8 +24,10 @@ Hamilton's equations are numerically integrated using leapfrog integration with Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf """ function hmc( - trace::Trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + trace::Trace, selection::Selection; L=10, eps=0.1, + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng() +) prev_model_score = get_score(trace) args = get_args(trace) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? 
zero(get_retval(trace)) : nothing @@ -35,7 +38,7 @@ function hmc( (_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) values = to_array(values_trie, Float64) gradient = to_array(gradient_trie, Float64) - momenta = sample_momenta(length(values)) + momenta = sample_momenta(rng, length(values)) prev_momenta_score = assess_momenta(momenta) for step=1:L @@ -47,7 +50,7 @@ function hmc( # get new gradient values_trie = from_array(values_trie, values) - (new_trace, _, _) = update(new_trace, args, argdiffs, values_trie) + (new_trace, _, _) = update(rng, new_trace, args, argdiffs, values_trie) (_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) gradient = to_array(gradient_trie, Float64) @@ -64,7 +67,7 @@ function hmc( # accept or reject alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score - if log(rand()) < alpha + if log(rand(rng)) < alpha (new_trace, true) else (trace, false) diff --git a/src/inference/mala.jl b/src/inference/mala.jl index 033a45a7e..1be82d556 100644 --- a/src/inference/mala.jl +++ b/src/inference/mala.jl @@ -2,15 +2,18 @@ """ (new_trace, accepted) = mala( trace, selection::Selection, tau::Real; - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) Apply a Metropolis-Adjusted Langevin Algorithm (MALA) update. [Reference URL](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) """ function mala( - trace, selection::Selection, tau::Real; - check=false, observations=EmptyChoiceMap()) + trace, selection::Selection, tau::Real; + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng() +) args = get_args(trace) argdiffs = map((_) -> NoChange(), args) std = sqrt(2 * tau) @@ -24,13 +27,13 @@ function mala( forward_score = 0. proposed_values = Vector{Float64}(undef, length(values)) for i=1:length(values) - proposed_values[i] = random(normal, forward_mu[i], std) + proposed_values[i] = random(rng, normal, forward_mu[i], std) forward_score += logpdf(normal, proposed_values[i], forward_mu[i], std) end # evaluate model weight constraints = from_array(values_trie, proposed_values) - (new_trace, weight, _, discard) = update(trace, + (new_trace, weight, _, discard) = update(rng, trace, args, argdiffs, constraints) check && check_observations(get_choices(new_trace), observations) @@ -46,7 +49,7 @@ function mala( # accept or reject alpha = weight - forward_score + backward_score - if log(rand()) < alpha + if log(rand(rng)) < alpha (new_trace, true) else (trace, false) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index bdb89fd73..092a6784f 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -7,16 +7,18 @@ reversal(::typeof(metropolis_hastings)) = metropolis_hastings """ (new_trace, accepted) = metropolis_hastings( trace, selection::Selection; - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) Perform a Metropolis-Hastings update that proposes new values for the selected addresses from the internal proposal (often using ancestral sampling), returning the new trace (which is equal to the previous trace if the move was not accepted) and a Bool indicating whether the move was accepted or not. 
""" function metropolis_hastings( trace, selection::Selection; - check=false, observations=EmptyChoiceMap()) - (new_trace, weight) = regenerate(trace, selection) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) + (new_trace, weight) = regenerate(rng, trace, selection) check && check_observations(get_choices(new_trace), observations) - if log(rand()) < weight + if log(rand(rng)) < weight # accept return (new_trace, true) else @@ -38,19 +40,20 @@ If the proposal modifies addresses that determine the control flow in the model, """ function metropolis_hastings( trace, proposal::GenerativeFunction, proposal_args::Tuple; - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) # TODO add a round trip check model_args = get_args(trace) argdiffs = map((_) -> NoChange(), model_args) proposal_args_forward = (trace, proposal_args...,) - (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward) - (new_trace, weight, _, discard) = update(trace, + (fwd_choices, fwd_weight, _) = propose(rng, proposal, proposal_args_forward) + (new_trace, weight, _, discard) = update(rng, trace, model_args, argdiffs, fwd_choices) proposal_args_backward = (new_trace, proposal_args...,) (bwd_weight, _) = assess(proposal, proposal_args_backward, discard) alpha = weight - fwd_weight + bwd_weight check && check_observations(get_choices(new_trace), observations) - if log(rand()) < alpha + if log(rand(rng)) < alpha # accept return (new_trace, true) else @@ -83,10 +86,11 @@ The `check` keyword argument to the involution can be used to enable or disable function metropolis_hastings( trace, proposal::GenerativeFunction, proposal_args::Tuple, involution::Union{TraceTransformDSLProgram,Function}; - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), + rng::Random.AbstractRNG=Random.default_rng()) trace_translator = SymmetricTraceTranslator(proposal, proposal_args, involution) (new_trace, log_weight) = trace_translator(trace; check=check, observations=observations) - if log(rand()) < log_weight + if log(rand(rng)) < log_weight # accept (new_trace, true) else diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..899c87eef 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -76,24 +76,24 @@ function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) assess(gen_fn.kernel, kernel_args, submap) end -function propose(gen_fn::CallAtCombinator, args::Tuple) +function propose(rng::AbstractRNG, gen_fn::CallAtCombinator, args::Tuple) (key, kernel_args) = unpack_call_at_args(args) - (submap, weight, retval) = propose(gen_fn.kernel, kernel_args) + (submap, weight, retval) = propose(rng, gen_fn.kernel, kernel_args) choices = CallAtChoiceMap(key, submap) (choices, weight, retval) end -function simulate(gen_fn::CallAtCombinator, args::Tuple) +function simulate(rng::AbstractRNG, gen_fn::CallAtCombinator, args::Tuple) (key, kernel_args) = unpack_call_at_args(args) - subtrace = simulate(gen_fn.kernel, kernel_args) + subtrace = simulate(rng, gen_fn.kernel, kernel_args) CallAtTrace(gen_fn, subtrace, key) end -function generate(gen_fn::CallAtCombinator{T,U,K}, args::Tuple, +function generate(rng::AbstractRNG, gen_fn::CallAtCombinator{T,U,K}, args::Tuple, choices::ChoiceMap) where {T,U,K} (key, kernel_args) = unpack_call_at_args(args) submap = 
get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap) trace = CallAtTrace(gen_fn, subtrace, key) (trace, weight) end @@ -103,26 +103,26 @@ function project(trace::CallAtTrace, selection::Selection) project(trace.subtrace, subselection) end -function update(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) +function update(rng::AbstractRNG, trace::CallAtTrace, args::Tuple, + argdiffs::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) key_changed = (key != trace.key) submap = get_submap(choices, key) if key_changed - (subtrace, weight) = generate(trace.gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(rng, trace.gen_fn.kernel, kernel_args, submap) weight -= get_score(trace.subtrace) discard = get_choices(trace) retdiff = UnknownChange() else (subtrace, weight, retdiff, subdiscard) = update( - trace.subtrace, kernel_args, argdiffs[1:end-1], submap) + rng, trace.subtrace, kernel_args, argdiffs[1:end-1], submap) discard = CallAtChoiceMap(key, subdiscard) end new_trace = CallAtTrace(trace.gen_fn, subtrace, key) (new_trace, weight, retdiff, discard) end -function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, +function regenerate(rng::AbstractRNG, trace::CallAtTrace, args::Tuple, argdiffs::Tuple, selection::Selection) (key, kernel_args) = unpack_call_at_args(args) key_changed = (key != trace.key) @@ -131,12 +131,12 @@ function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, if !isempty(subselection) error("Cannot select addresses under new key $key in regenerate") end - (subtrace, weight) = generate(trace.gen_fn.kernel, kernel_args, EmptyChoiceMap()) + (subtrace, weight) = generate(rng, trace.gen_fn.kernel, kernel_args, EmptyChoiceMap()) weight -= project(trace.subtrace, EmptySelection()) retdiff = UnknownChange() else (subtrace, weight, retdiff) = regenerate( - trace.subtrace, kernel_args, argdiffs[1:end-1], subselection) + rng, trace.subtrace, kernel_args, argdiffs[1:end-1], subselection) end new_trace = CallAtTrace(trace.gen_fn, subtrace, key) (new_trace, weight, retdiff) diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 09cb922fa..0348cf596 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -59,29 +59,35 @@ function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap (weight, value) end -function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} +propose(gen_fn::ChoiceAtCombinator, args::Tuple) = propose(default_rng(), gen_fn, args) + +function propose(rng::AbstractRNG, gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} local key::K local value::T (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) + value = random(rng, gen_fn.dist, kernel_args...) score = logpdf(gen_fn.dist, value, kernel_args...) choices = ChoiceAtChoiceMap(key, value) (choices, score, value) end -function simulate(gen_fn::ChoiceAtCombinator, args::Tuple) +simulate(gen_fn::ChoiceAtCombinator, args::Tuple) = simulate(default_rng(), gen_fn, args) + +function simulate(rng::AbstractRNG, gen_fn::ChoiceAtCombinator, args::Tuple) (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) + value = random(rng, gen_fn.dist, kernel_args...) score = logpdf(gen_fn.dist, value, kernel_args...) 
ChoiceAtTrace(gen_fn, value, key, kernel_args, score) end -function generate(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} +generate(gen_fn::ChoiceAtCombinator, args::Tuple, choices::ChoiceMap) = generate(default_rng(), gen_fn, args, choices) + +function generate(rng::AbstractRNG, gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} local key::K local value::T (key, kernel_args) = unpack_choice_at_args(args) constrained = has_value(choices, key) - value = constrained ? get_value(choices, key) : random(gen_fn.dist, kernel_args...) + value = constrained ? get_value(choices, key) : random(rng, gen_fn.dist, kernel_args...) score = logpdf(gen_fn.dist, value, kernel_args...) trace = ChoiceAtTrace(gen_fn, value, key, kernel_args, score) weight = constrained ? score : 0. @@ -115,17 +121,20 @@ function update(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, (new_trace, weight, UnknownChange(), discard) end -function regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) +regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, selection::Selection) = + regenerate(default_rng(), trace, args, argdiffs, selection) + +function regenerate(rng::AbstractRNG, trace::ChoiceAtTrace, args::Tuple, + argdiffs::Tuple, selection::Selection) (key, kernel_args) = unpack_choice_at_args(args) key_changed = (key != trace.key) selected = key in selection if !key_changed && selected - new_value = random(trace.gen_fn.dist, kernel_args...) + new_value = random(rng, trace.gen_fn.dist, kernel_args...) elseif !key_changed && !selected new_value = trace.value elseif key_changed && !selected - new_value = random(trace.gen_fn.dist, kernel_args...) + new_value = random(rng, trace.gen_fn.dist, kernel_args...) else error("Cannot select new address $key in regenerate") end diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 24d6d90f2..1058fc771 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -93,11 +93,15 @@ function accumulate_param_gradients_determ!( gradient_with_state(gen_fn, state, args, retgrad) end +simulate(::AbstractRNG, gen_fn::CustomDetermGF, args::Tuple) = simulate(gen_fn, args) + function simulate(gen_fn::CustomDetermGF{T,S}, args::Tuple) where {T,S} retval, state = apply_with_state(gen_fn, args) CustomDetermGFTrace{T,S}(retval, state, args, gen_fn) end +generate(::AbstractRNG, gen_fn::CustomDetermGF, args::Tuple, choices::ChoiceMap) = generate(gen_fn, args, choices) + function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") @@ -107,6 +111,8 @@ function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) trace, 0. 
end +update(::AbstractRNG, trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) = update(trace, args, argdiffs, choices) + function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") @@ -116,6 +122,8 @@ function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, c (new_trace, 0., retdiff, choicemap()) end +regenerate(::AbstractRNG, trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection) = regenerate(trace, args, argdiffs, selection) + function regenerate(trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection) update(trace, args, argdiffs, EmptyChoiceMap()) end diff --git a/src/modeling_library/dist_dsl/dist_dsl.jl b/src/modeling_library/dist_dsl/dist_dsl.jl index 3f3fe6b11..02e731ece 100644 --- a/src/modeling_library/dist_dsl/dist_dsl.jl +++ b/src/modeling_library/dist_dsl/dist_dsl.jl @@ -159,14 +159,15 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T return (self_output_grad, self_arg_grads...) end -function random(d::CompiledDistWithArgs{T}, args...)::T where T +function random(rng::AbstractRNG, d::CompiledDistWithArgs{T}, args...)::T where T concrete_args = [eval_arg(arg, args) for arg in d.arglist] - random(d.base, concrete_args...) + random(rng, d.base, concrete_args...) end is_discrete(d::CompiledDistWithArgs{T}) where T = is_discrete(d.base) -(d::CompiledDistWithArgs{T})(args...) where T = random(d, args...) +(d::CompiledDistWithArgs)(args...) = d(default_rng(), args...) +(d::CompiledDistWithArgs{T})(rng::AbstractRNG, args...) where T = random(rng, d, args...) function has_output_grad(d::CompiledDistWithArgs{T}) where T has_output_grad(d.base) diff --git a/src/modeling_library/dist_dsl/relabeled_distribution.jl b/src/modeling_library/dist_dsl/relabeled_distribution.jl index 19c5a31a1..5844bd42c 100644 --- a/src/modeling_library/dist_dsl/relabeled_distribution.jl +++ b/src/modeling_library/dist_dsl/relabeled_distribution.jl @@ -37,13 +37,14 @@ function logpdf_grad(d::WithLabelArg{T, U}, x::T, collection, base_args...) wher (nothing, nothing, base_arg_grads...) end -function random(d::WithLabelArg{T, U}, collection, base_args...)::T where {T, U} - collection[random(d.base, base_args...)] +function random(rng::AbstractRNG, d::WithLabelArg{T, U}, collection, base_args...)::T where {T, U} + collection[random(rng, d.base, base_args...)] end is_discrete(d::WithLabelArg{T, U}) where {T, U} = true -(d::WithLabelArg{T, U})(collection, base_args...) where {T, U} = random(d, collection, base_args...) +(d::WithLabelArg)(collection, base_args...) = d(default_rng(), collection, base_args...) +(d::WithLabelArg{T, U})(rng::AbstractRNG, collection, base_args...) where {T, U} = random(rng, d, collection, base_args...) function has_output_grad(d::WithLabelArg{T, U}) where {T, U} false @@ -91,13 +92,14 @@ function logpdf_grad(d::RelabeledDistribution{T, U}, x::T, base_args...) where { (nothing, base_arg_grads...) end -function random(d::RelabeledDistribution{T, U}, base_args...)::T where {T, U} - d.collection[random(d.base, base_args...)] +function random(rng::AbstractRNG, d::RelabeledDistribution{T, U}, base_args...)::T where {T, U} + d.collection[random(rng, d.base, base_args...)] end is_discrete(d::RelabeledDistribution{T, U}) where {T, U} = true -(d::RelabeledDistribution{T, U})(base_args...) where {T, U} = random(d, base_args...) 
+(d::RelabeledDistribution)(base_args...) = d(default_rng(), base_args...) +(d::RelabeledDistribution{T, U})(rng::AbstractRNG, base_args...) where {T, U} = random(rng, d, base_args...) function has_output_grad(d::RelabeledDistribution{T, U}) where {T, U} false diff --git a/src/modeling_library/dist_dsl/transformed_distribution.jl b/src/modeling_library/dist_dsl/transformed_distribution.jl index afcc08b30..65259c255 100644 --- a/src/modeling_library/dist_dsl/transformed_distribution.jl +++ b/src/modeling_library/dist_dsl/transformed_distribution.jl @@ -16,8 +16,8 @@ struct TransformedDistribution{T, U} <: Distribution{T} backward_grad :: Function end -function random(d::TransformedDistribution{T, U}, args...)::T where {T, U} - d.forward(random(d.base, args[d.nArgs+1:end]...), args[1:d.nArgs]...) +function random(rng::AbstractRNG, d::TransformedDistribution{T, U}, args...)::T where {T, U} + d.forward(random(rng, d.base, args[d.nArgs+1:end]...), args[1:d.nArgs]...) end function logpdf_correction(d::TransformedDistribution{T, U}, x, args) where {T, U} @@ -54,7 +54,8 @@ end is_discrete(d::TransformedDistribution{T, U}) where {T, U} = is_discrete(d.base) -(d::TransformedDistribution{T, U})(args...) where {T, U} = random(d, args...) +(d::TransformedDistribution)(args...) = d(default_rng(), args...) +(d::TransformedDistribution{T, U})(rng::AbstractRNG, args...) where {T, U} = random(rng, d, args...) function has_output_grad(d::TransformedDistribution{T, U}) where {T, U} has_output_grad(d.base) diff --git a/src/modeling_library/distributions/bernoulli.jl b/src/modeling_library/distributions/bernoulli.jl index 495528de1..20de0b336 100644 --- a/src/modeling_library/distributions/bernoulli.jl +++ b/src/modeling_library/distributions/bernoulli.jl @@ -16,11 +16,12 @@ function logpdf_grad(::Bernoulli, x::Bool, prob::Real) (nothing, prob_grad) end -random(::Bernoulli, prob::Real) = rand() < prob +random(rng::AbstractRNG, ::Bernoulli, prob::Real) = rand(rng) < prob is_discrete(::Bernoulli) = true -(::Bernoulli)(prob) = random(Bernoulli(), prob) +(dist::Bernoulli)(prob) = dist(default_rng(), prob) +(::Bernoulli)(rng::AbstractRNG, prob) = random(rng, Bernoulli(), prob) has_output_grad(::Bernoulli) = false has_argument_grads(::Bernoulli) = (true,) diff --git a/src/modeling_library/distributions/beta.jl b/src/modeling_library/distributions/beta.jl index 1503f6d71..9c9d4a97f 100644 --- a/src/modeling_library/distributions/beta.jl +++ b/src/modeling_library/distributions/beta.jl @@ -34,12 +34,13 @@ function logpdf_grad(::Beta, x::Real, alpha::Real, beta::Real) (deriv_x, deriv_alpha, deriv_beta) end -function random(::Beta, alpha::Real, beta::Real) - rand(Distributions.Beta(alpha, beta)) +function random(rng::AbstractRNG, ::Beta, alpha::Real, beta::Real) + rand(rng, Distributions.Beta(alpha, beta)) end is_discrete(::Beta) = false -(::Beta)(alpha, beta) = random(Beta(), alpha, beta) +(dist::Beta)(alpha, beta) = dist(default_rng(), alpha, beta) +(::Beta)(rng::AbstractRNG, alpha, beta) = random(rng, Beta(), alpha, beta) has_output_grad(::Beta) = true has_argument_grads(::Beta) = (true, true) diff --git a/src/modeling_library/distributions/beta_uniform.jl b/src/modeling_library/distributions/beta_uniform.jl index 9cb213da4..d724e8fce 100644 --- a/src/modeling_library/distributions/beta_uniform.jl +++ b/src/modeling_library/distributions/beta_uniform.jl @@ -33,15 +33,16 @@ function logpdf_grad(::BetaUniformMixture, x::Real, theta::Real, alpha::Real, be (x_deriv, theta_deriv, alpha_deriv, beta_deriv) end -function 
random(::BetaUniformMixture, theta::Real, alpha::Real, beta::Real)
+function random(rng::AbstractRNG, ::BetaUniformMixture, theta::Real, alpha::Real, beta::Real)
-    if bernoulli(theta)
+    if bernoulli(rng, theta)
-        random(Beta(), alpha, beta)
+        random(rng, Beta(), alpha, beta)
     else
-        random(uniform_continuous, 0., 1.)
+        random(rng, uniform_continuous, 0., 1.)
     end
 end
 
-(::BetaUniformMixture)(theta, alpha, beta) = random(BetaUniformMixture(), theta, alpha, beta)
+(dist::BetaUniformMixture)(theta, alpha, beta) = dist(default_rng(), theta, alpha, beta)
+(::BetaUniformMixture)(rng::AbstractRNG, theta, alpha, beta) = random(rng, BetaUniformMixture(), theta, alpha, beta)
 
 is_discrete(::BetaUniformMixture) = false
diff --git a/src/modeling_library/distributions/binom.jl b/src/modeling_library/distributions/binom.jl
index 87afdce66..52ba315aa 100644
--- a/src/modeling_library/distributions/binom.jl
+++ b/src/modeling_library/distributions/binom.jl
@@ -15,11 +15,12 @@ function logpdf_grad(::Binomial, x::Integer, n::Integer, p::Real)
     (nothing, nothing, (x / p - (n - x) / (1 - p)))
 end
 
-function random(::Binomial, n::Integer, p::Real)
-    rand(Distributions.Binomial(n, p))
+function random(rng::AbstractRNG, ::Binomial, n::Integer, p::Real)
+    rand(rng, Distributions.Binomial(n, p))
 end
 
-(::Binomial)(n, p) = random(Binomial(), n, p)
+(dist::Binomial)(n, p) = dist(default_rng(), n, p)
+(::Binomial)(rng::AbstractRNG, n, p) = random(rng, Binomial(), n, p)
 
 has_output_grad(::Binomial) = false
 has_argument_grads(::Binomial) = (false, true)
diff --git a/src/modeling_library/distributions/categorical.jl b/src/modeling_library/distributions/categorical.jl
index dd21728c0..c9ce5cf7b 100644
--- a/src/modeling_library/distributions/categorical.jl
+++ b/src/modeling_library/distributions/categorical.jl
@@ -17,12 +17,13 @@ function logpdf_grad(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U
     (nothing, grad)
 end
 
-function random(::Categorical, probs::AbstractArray{U,1}) where {U <: Real}
-    rand(Distributions.Categorical(probs))
+function random(rng::AbstractRNG, ::Categorical, probs::AbstractArray{U,1}) where {U <: Real}
+    rand(rng, Distributions.Categorical(probs))
 end
 
 is_discrete(::Categorical) = true
 
-(::Categorical)(probs) = random(Categorical(), probs)
+(dist::Categorical)(probs) = dist(default_rng(), probs)
+(::Categorical)(rng::AbstractRNG, probs) = random(rng, Categorical(), probs)
 
 has_output_grad(::Categorical) = false
 has_argument_grads(::Categorical) = (true,)
diff --git a/src/modeling_library/distributions/cauchy.jl b/src/modeling_library/distributions/cauchy.jl
index 34e592763..e4d1493da 100644
--- a/src/modeling_library/distributions/cauchy.jl
+++ b/src/modeling_library/distributions/cauchy.jl
@@ -23,9 +23,10 @@ end
 
 is_discrete(::Cauchy) = false
 
-random(::Cauchy, x0::Real, gamma::Real) = rand(Distributions.Cauchy(x0, gamma))
+random(rng::AbstractRNG, ::Cauchy, x0::Real, gamma::Real) = rand(rng, Distributions.Cauchy(x0, gamma))
 
-(::Cauchy)(x0::Real, gamma::Real) = random(Cauchy(), x0, gamma)
+(dist::Cauchy)(x0::Real, gamma::Real) = dist(default_rng(), x0, gamma)
+(::Cauchy)(rng::AbstractRNG, x0::Real, gamma::Real) = random(rng, Cauchy(), x0, gamma)
 
 has_output_grad(::Cauchy) = true
 has_argument_grads(::Cauchy) = (true, true)
diff --git a/src/modeling_library/distributions/dirichlet.jl b/src/modeling_library/distributions/dirichlet.jl
index 63b5aef28..16ef37aca 100644
--- a/src/modeling_library/distributions/dirichlet.jl
+++ b/src/modeling_library/distributions/dirichlet.jl
@@ -27,13 +27,14 @@ function logpdf_grad(::Dirichlet,
x::AbstractVector{T}, alpha::AbstractVector{U} end end -function random(::Dirichlet, alpha::AbstractVector{T}) where {T <: Real} - rand(Distributions.Dirichlet(alpha)) +function random(rng::AbstractRNG, ::Dirichlet, alpha::AbstractVector{T}) where {T <: Real} + rand(rng, Distributions.Dirichlet(alpha)) end is_discrete(::Dirichlet) = false -(::Dirichlet)(alpha) = random(Dirichlet(), alpha) +(dist::Dirichlet)(alpha) = dist(default_rng(), alpha) +(::Dirichlet)(rng::AbstractRNG, alpha) = random(rng, Dirichlet(), alpha) has_output_grad(::Dirichlet) = true has_argument_grads(::Dirichlet) = (true,) diff --git a/src/modeling_library/distributions/exponential.jl b/src/modeling_library/distributions/exponential.jl index de933a4b0..697e300b0 100644 --- a/src/modeling_library/distributions/exponential.jl +++ b/src/modeling_library/distributions/exponential.jl @@ -19,14 +19,15 @@ function Gen.logpdf_grad(::Exponential, x::Real, rate::Real) (x_grad, rate_grad) end -function Gen.random(::Exponential, rate::Real) +function Gen.random(rng::AbstractRNG, ::Exponential, rate::Real) scale = 1/rate - rand(Distributions.Exponential(scale)) + rand(rng, Distributions.Exponential(scale)) end is_discrete(::Exponential) = false -(::Exponential)(rate) = random(Exponential(), rate) +(dist::Exponential)(rate) = dist(default_rng(), rate) +(::Exponential)(rng::AbstractRNG, rate) = random(rng, Exponential(), rate) Gen.has_output_grad(::Exponential) = true Gen.has_argument_grads(::Exponential) = (true,) diff --git a/src/modeling_library/distributions/gamma.jl b/src/modeling_library/distributions/gamma.jl index fb9370f60..8c0b5b4c6 100644 --- a/src/modeling_library/distributions/gamma.jl +++ b/src/modeling_library/distributions/gamma.jl @@ -26,13 +26,14 @@ function logpdf_grad(::Gamma, x::Real, shape::Real, scale::Real) end end -function random(::Gamma, shape::Real, scale::Real) - rand(Distributions.Gamma(shape, scale)) +function random(rng::AbstractRNG, ::Gamma, shape::Real, scale::Real) + rand(rng, Distributions.Gamma(shape, scale)) end is_discrete(::Gamma) = false -(::Gamma)(shape, scale) = random(Gamma(), shape, scale) +(dist::Gamma)(shape, scale) = dist(default_rng(), shape, scale) +(::Gamma)(rng::AbstractRNG, shape, scale) = random(rng, Gamma(), shape, scale) has_output_grad(::Gamma) = true has_argument_grads(::Gamma) = (true, true) diff --git a/src/modeling_library/distributions/geometric.jl b/src/modeling_library/distributions/geometric.jl index fc7fad3cc..9958cc952 100644 --- a/src/modeling_library/distributions/geometric.jl +++ b/src/modeling_library/distributions/geometric.jl @@ -20,13 +20,14 @@ function Gen.logpdf_grad(::Geometric, x::Int, p::Real) (nothing, p_grad) end -function Gen.random(::Geometric, p::Real) - rand(Distributions.Geometric(p)) +function Gen.random(rng::AbstractRNG, ::Geometric, p::Real) + rand(rng, Distributions.Geometric(p)) end is_discrete(::Geometric) = true -(::Geometric)(p) = random(Geometric(), p) +(dist::Geometric)(p) = dist(default_rng(), p) +(::Geometric)(rng::AbstractRNG, p) = random(rng, Geometric(), p) Gen.has_output_grad(::Geometric) = false Gen.has_argument_grads(::Geometric) = (true,) diff --git a/src/modeling_library/distributions/inv_gamma.jl b/src/modeling_library/distributions/inv_gamma.jl index 1861b8a7f..1187dcb9c 100644 --- a/src/modeling_library/distributions/inv_gamma.jl +++ b/src/modeling_library/distributions/inv_gamma.jl @@ -25,13 +25,14 @@ function logpdf_grad(::InverseGamma, x::Real, shape::Real, scale::Real) end -function random(::InverseGamma, shape::Real, 
scale::Real) - rand(Distributions.InverseGamma(shape, scale)) +function random(rng::AbstractRNG, ::InverseGamma, shape::Real, scale::Real) + rand(rng, Distributions.InverseGamma(shape, scale)) end is_discrete(::InverseGamma) = false -(::InverseGamma)(shape, scale) = random(InverseGamma(), shape, scale) +(dist::InverseGamma)(shape, scale) = dist(default_rng(), shape, scale) +(::InverseGamma)(rng::AbstractRNG, shape, scale) = random(rng, InverseGamma(), shape, scale) has_output_grad(::InverseGamma) = true has_argument_grads(::InverseGamma) = (true, true) diff --git a/src/modeling_library/distributions/laplace.jl b/src/modeling_library/distributions/laplace.jl index 71586c6fc..4e3de9855 100644 --- a/src/modeling_library/distributions/laplace.jl +++ b/src/modeling_library/distributions/laplace.jl @@ -25,13 +25,14 @@ function logpdf_grad(::Laplace, x::Real, loc::Real, scale::Real) (deriv_x, deriv_loc, deriv_scale) end -function random(::Laplace, loc::Real, scale::Real) - rand(Distributions.Laplace(loc, scale)) +function random(rng::AbstractRNG, ::Laplace, loc::Real, scale::Real) + rand(rng, Distributions.Laplace(loc, scale)) end is_discrete(::Laplace) = false -(::Laplace)(loc, scale) = random(Laplace(), loc, scale) +(dist::Laplace)(loc, scale) = dist(default_rng(), loc, scale) +(::Laplace)(rng::AbstractRNG, loc, scale) = random(rng, Laplace(), loc, scale) has_output_grad(::Laplace) = true has_argument_grads(::Laplace) = (true, true) diff --git a/src/modeling_library/distributions/mvnormal.jl b/src/modeling_library/distributions/mvnormal.jl index 942320020..10df70bac 100644 --- a/src/modeling_library/distributions/mvnormal.jl +++ b/src/modeling_library/distributions/mvnormal.jl @@ -27,12 +27,13 @@ function logpdf_grad(::MultivariateNormal, x::AbstractVector{T}, mu::AbstractVec (x_deriv, mu_deriv, cov_deriv) end -function random(::MultivariateNormal, mu::AbstractVector{U}, +function random(rng::AbstractRNG, ::MultivariateNormal, mu::AbstractVector{U}, cov::AbstractMatrix{V}) where {U <: Real, V <: Real} - rand(Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))) + rand(rng, Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))) end -(::MultivariateNormal)(mu, cov) = random(MultivariateNormal(), mu, cov) +(dist::MultivariateNormal)(mu, cov) = dist(default_rng(), mu, cov) +(::MultivariateNormal)(rng::AbstractRNG, mu, cov) = random(rng, MultivariateNormal(), mu, cov) has_output_grad(::MultivariateNormal) = true has_argument_grads(::MultivariateNormal) = (true, true) diff --git a/src/modeling_library/distributions/neg_binom.jl b/src/modeling_library/distributions/neg_binom.jl index baf63d4a7..3bfc0f5ac 100644 --- a/src/modeling_library/distributions/neg_binom.jl +++ b/src/modeling_library/distributions/neg_binom.jl @@ -19,13 +19,14 @@ function logpdf_grad(::NegativeBinomial, x::Int, r::Real, p::Real) return (nothing, r_grad, p_grad) end -function random(::NegativeBinomial, r::Real, p::Real) - rand(Distributions.NegativeBinomial(r, p)) +function random(rng::AbstractRNG, ::NegativeBinomial, r::Real, p::Real) + rand(rng, Distributions.NegativeBinomial(r, p)) end is_discrete(::NegativeBinomial) = true -(::NegativeBinomial)(r, p) = random(NegativeBinomial(), r, p) +(dist::NegativeBinomial)(r, p) = dist(default_rng(), r, p) +(::NegativeBinomial)(rng::AbstractRNG, r, p) = random(rng, NegativeBinomial(), r, p) has_output_grad(::NegativeBinomial) = false has_argument_grads(::NegativeBinomial) = (true, true) diff --git a/src/modeling_library/distributions/normal.jl 
b/src/modeling_library/distributions/normal.jl
index d128f97d5..d4fca34ed 100644
--- a/src/modeling_library/distributions/normal.jl
+++ b/src/modeling_library/distributions/normal.jl
@@ -125,18 +125,22 @@ function _unbroadcast_to_shape(target_shape::NTuple{target_ndims, Int},
                   dims=Dims(target_ndims + 1 : full_ndims))
 end
 
-random(::Normal, mu::Real, std::Real) = mu + std * randn()
+random(rng::AbstractRNG, ::Normal, mu::Real, std::Real) = mu + std * randn(rng)
+
 is_discrete(::Normal) = false
 
-function random(::BroadcastedNormal,
+function random(rng::AbstractRNG, ::BroadcastedNormal,
                 mu::Union{AbstractArray{<:Real}, Real},
                 std::Union{AbstractArray{<:Real}, Real})
     broadcast_shape = broadcast_shapes_or_crash(mu, std)
-    mu .+ std .* randn(broadcast_shape)
+    mu .+ std .* randn(rng, broadcast_shape)
 end
 
-(::Normal)(mu, std) = random(Normal(), mu, std)
-(::BroadcastedNormal)(mu, std) = random(BroadcastedNormal(), mu, std)
+(dist::Normal)(mu, std) = dist(default_rng(), mu, std)
+(::Normal)(rng::AbstractRNG, mu, std) = random(rng, Normal(), mu, std)
+
+(dist::BroadcastedNormal)(mu, std) = dist(default_rng(), mu, std)
+(::BroadcastedNormal)(rng::AbstractRNG, mu, std) = random(rng, BroadcastedNormal(), mu, std)
 
 has_output_grad(::Normal) = true
 has_argument_grads(::Normal) = (true, true)
diff --git a/src/modeling_library/distributions/piecewise_uniform.jl b/src/modeling_library/distributions/piecewise_uniform.jl
index e4eb89d12..cbb630720 100644
--- a/src/modeling_library/distributions/piecewise_uniform.jl
+++ b/src/modeling_library/distributions/piecewise_uniform.jl
@@ -42,13 +42,14 @@ function logpdf(::PiecewiseUniform, x::Real, bounds::AbstractVector{T},
     end
 end
 
-function random(::PiecewiseUniform, bounds::Vector{T},
+function random(rng::AbstractRNG, ::PiecewiseUniform, bounds::Vector{T},
                 probs::Vector{U}) where {T <: Real, U <: Real}
-    bin = categorical(probs)
+    bin = categorical(rng, probs)
-    uniform_continuous(bounds[bin], bounds[bin+1])
+    uniform_continuous(rng, bounds[bin], bounds[bin+1])
 end
 
-(::PiecewiseUniform)(bounds, probs) = random(PiecewiseUniform(), bounds, probs)
+(dist::PiecewiseUniform)(bounds, probs) = dist(default_rng(), bounds, probs)
+(::PiecewiseUniform)(rng::AbstractRNG, bounds, probs) = random(rng, PiecewiseUniform(), bounds, probs)
 
 function logpdf_grad(::PiecewiseUniform, x::Real, bounds, probs)
     check_dims(piecewise_uniform, bounds, probs)
diff --git a/src/modeling_library/distributions/poisson.jl b/src/modeling_library/distributions/poisson.jl
index efb723671..b99f6cd04 100644
--- a/src/modeling_library/distributions/poisson.jl
+++ b/src/modeling_library/distributions/poisson.jl
@@ -16,11 +16,13 @@ function logpdf_grad(::Poisson, x::Int, lambda::Real)
 end
 
-function random(::Poisson, lambda::Real)
-    rand(Distributions.Poisson(lambda))
+function random(rng::AbstractRNG, ::Poisson, lambda::Real)
+    rand(rng, Distributions.Poisson(lambda))
 end
 
-(::Poisson)(lambda) = random(Poisson(), lambda)
+(dist::Poisson)(lambda) = dist(default_rng(), lambda)
+(::Poisson)(rng::AbstractRNG, lambda) = random(rng, Poisson(), lambda)
+
 is_discrete(::Poisson) = true
 
 has_output_grad(::Poisson) = false
diff --git a/src/modeling_library/distributions/uniform_continuous.jl b/src/modeling_library/distributions/uniform_continuous.jl
index 919d9a399..9e0cb3cc4 100644
--- a/src/modeling_library/distributions/uniform_continuous.jl
+++ b/src/modeling_library/distributions/uniform_continuous.jl
@@ -18,11 +18,13 @@ function logpdf_grad(::UniformContinuous, x::Real, low::Real, high::Real)
     (0., inv_diff, -inv_diff)
 end
 
-function
random(::UniformContinuous, low::Real, high::Real) - rand() * (high - low) + low +function random(rng::AbstractRNG, ::UniformContinuous, low::Real, high::Real) + rand(rng) * (high - low) + low end -(::UniformContinuous)(low, high) = random(UniformContinuous(), low, high) +(dist::UniformContinuous)(low, high) = dist(default_rng(), low, high) +(::UniformContinuous)(rng::AbstractRNG, low, high) = random(rng, UniformContinuous(), low, high) + is_discrete(::UniformContinuous) = false has_output_grad(::UniformContinuous) = true diff --git a/src/modeling_library/distributions/uniform_discrete.jl b/src/modeling_library/distributions/uniform_discrete.jl index b125b3b2a..1a553dd3a 100644 --- a/src/modeling_library/distributions/uniform_discrete.jl +++ b/src/modeling_library/distributions/uniform_discrete.jl @@ -16,12 +16,13 @@ function logpdf_grad(::UniformDiscrete, x::Int, lower::Integer, high::Integer) (nothing, nothing, nothing) end -function random(::UniformDiscrete, low::Integer, high::Integer) - rand(Distributions.DiscreteUniform(low, high)) +function random(rng::AbstractRNG, ::UniformDiscrete, low::Integer, high::Integer) + rand(rng, Distributions.DiscreteUniform(low, high)) end is_discrete(::UniformDiscrete) = true -(::UniformDiscrete)(low, high) = random(UniformDiscrete(), low, high) +(dist::UniformDiscrete)(low, high) = dist(default_rng(), low, high) +(::UniformDiscrete)(rng::AbstractRNG, low, high) = random(rng, UniformDiscrete(), low, high) has_output_grad(::UniformDiscrete) = false has_argument_grads(::UniformDiscrete) = (false, false) diff --git a/src/modeling_library/map/assess.jl b/src/modeling_library/map/assess.jl index 74538d9d4..bdd74a24b 100644 --- a/src/modeling_library/map/assess.jl +++ b/src/modeling_library/map/assess.jl @@ -3,7 +3,7 @@ mutable struct MapAssessState{T} retvals::Vector{T} end -function process_new!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, +function process_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, key::Int, state::MapAssessState{T}) where {T,U} kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) @@ -16,7 +16,8 @@ function assess(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} len = length(args[1]) state = MapAssessState{T}(0., Vector{T}(undef,len)) for key=1:len - process_new!(gen_fn, args, choices, key, state) + # pass default rng to satisfy the interface; note, however, that it will not be used. + process_new!(default_rng(), gen_fn, args, choices, key, state) end (state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/map/generate.jl b/src/modeling_library/map/generate.jl index 8a5c2e8da..ccab65776 100644 --- a/src/modeling_library/map/generate.jl +++ b/src/modeling_library/map/generate.jl @@ -7,13 +7,13 @@ mutable struct MapGenerateState{T,U} num_nonempty::Int end -function process!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, +function process!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, key::Int, state::MapGenerateState{T,U}) where {T,U} local subtrace::U local retval::T kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap) state.weight += weight state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
diff --git a/src/modeling_library/map/assess.jl b/src/modeling_library/map/assess.jl
index 74538d9d4..bdd74a24b 100644
--- a/src/modeling_library/map/assess.jl
+++ b/src/modeling_library/map/assess.jl
@@ -3,7 +3,7 @@ mutable struct MapAssessState{T}
     retvals::Vector{T}
 end
 
-function process_new!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap,
+function process_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap,
                       key::Int, state::MapAssessState{T}) where {T,U}
     kernel_args = get_args_for_key(args, key)
     submap = get_submap(choices, key)
@@ -16,7 +16,8 @@ function assess(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U}
     len = length(args[1])
     state = MapAssessState{T}(0., Vector{T}(undef,len))
     for key=1:len
-        process_new!(gen_fn, args, choices, key, state)
+        # pass the default RNG to satisfy the interface; assess never samples, so it is unused
+        process_new!(default_rng(), gen_fn, args, choices, key, state)
     end
     (state.weight, PersistentVector{T}(state.retvals))
 end
diff --git a/src/modeling_library/map/generate.jl b/src/modeling_library/map/generate.jl
index 8a5c2e8da..ccab65776 100644
--- a/src/modeling_library/map/generate.jl
+++ b/src/modeling_library/map/generate.jl
@@ -7,13 +7,13 @@ mutable struct MapGenerateState{T,U}
     num_nonempty::Int
 end
 
-function process!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap,
+function process!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap,
                   key::Int, state::MapGenerateState{T,U}) where {T,U}
     local subtrace::U
     local retval::T
     kernel_args = get_args_for_key(args, key)
     submap = get_submap(choices, key)
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap)
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap)
     state.weight += weight
     state.noise += project(subtrace, EmptySelection())
     state.num_nonempty += (isempty(get_choices(subtrace)) ? 0 : 1)
@@ -23,12 +23,12 @@ function process!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap,
     state.retval[key] = retval
 end
 
-function generate(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U}
+function generate(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U}
     len = length(args[1])
     state = MapGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len),
                                   Vector{T}(undef,len), 0)
     # TODO check for keys that aren't valid constraints
     for key=1:len
-        process!(gen_fn, args, choices, key, state)
+        process!(rng, gen_fn, args, choices, key, state)
     end
     trace = VectorTrace{MapType,T,U}(gen_fn,
         PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval),
diff --git a/src/modeling_library/map/generic_update.jl b/src/modeling_library/map/generic_update.jl
index 84e130afd..39425ef7f 100644
--- a/src/modeling_library/map/generic_update.jl
+++ b/src/modeling_library/map/generic_update.jl
@@ -17,7 +17,7 @@ function get_kernel_argdiffs(argdiffs::Tuple)
     kernel_argdiffs
 end
 
-function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple,
+function process_all_retained!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple,
         choices_or_selection, prev_length::Int, new_length::Int,
         retained_and_targeted::Set{Int}, state) where {T,U}
     kernel_no_change_argdiffs = map((_) -> NoChange(), args)
@@ -28,7 +28,7 @@ function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple,
         # only visit retained applications that were targeted
         for key in retained_and_targeted
             @assert key <= min(new_length, prev_length)
-            process_retained!(gen_fn, args, choices_or_selection, key, kernel_no_change_argdiffs, state)
+            process_retained!(rng, gen_fn, args, choices_or_selection, key, kernel_no_change_argdiffs, state)
         end
 
     elseif any(diff == UnknownChange() for diff in argdiffs)
@@ -36,7 +36,7 @@ function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple,
         # visit every retained application
         for key in 1:min(prev_length, new_length)
             @assert key <= min(new_length, prev_length)
-            process_retained!(gen_fn, args, choices_or_selection, key, kernel_unknown_change_argdiffs, state)
+            process_retained!(rng, gen_fn, args, choices_or_selection, key, kernel_unknown_change_argdiffs, state)
         end
 
     else
@@ -51,7 +51,7 @@ function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple,
             else
                 kernel_argdiffs = kernel_no_change_argdiffs
             end
-            process_retained!(gen_fn, args, choices_or_selection, key, kernel_argdiffs, state)
+            process_retained!(rng, gen_fn, args, choices_or_selection, key, kernel_argdiffs, state)
         end
     end
 
@@ -60,9 +60,9 @@ end
 """
 Process all new applications.
 """
-function process_all_new!(gen_fn::Map{T,U}, args::Tuple, choices_or_selection,
+function process_all_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices_or_selection,
         prev_len::Int, new_len::Int, state) where {T,U}
     for key=prev_len+1:new_len
-        process_new!(gen_fn, args, choices_or_selection, key, state)
+        process_new!(rng, gen_fn, args, choices_or_selection, key, state)
     end
 end
diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl
index 1bb695ff7..8d58d9670 100644
--- a/src/modeling_library/map/map.jl
+++ b/src/modeling_library/map/map.jl
@@ -27,8 +27,9 @@ export Map
 has_argument_grads(map_gf::Map) = has_argument_grads(map_gf.kernel)
 accepts_output_grad(map_gf::Map) = accepts_output_grad(map_gf.kernel)
 
-function (gen_fn::Map)(args...)
-    (_, _, retval) = propose(gen_fn, args)
+(gen_fn::Map)(args...) = gen_fn(default_rng(), args...)
+function (gen_fn::Map)(rng::AbstractRNG, args...)
+    (_, _, retval) = propose(rng, gen_fn, args)
     retval
 end
diff --git a/src/modeling_library/map/propose.jl b/src/modeling_library/map/propose.jl
index c0ca14330..1bd04a41d 100644
--- a/src/modeling_library/map/propose.jl
+++ b/src/modeling_library/map/propose.jl
@@ -4,7 +4,7 @@ mutable struct MapProposeState{T}
     retvals::Vector{T}
 end
 
-function process_new!(gen_fn::Map{T,U}, args::Tuple, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, key::Int,
                       state::MapProposeState{T}) where {T,U}
     local subtrace::U
     kernel_args = get_args_for_key(args, key)
@@ -14,12 +14,12 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, key::Int,
     state.retvals[key] = retval
 end
 
-function propose(gen_fn::Map{T,U}, args::Tuple) where {T,U}
+function propose(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple) where {T,U}
     len = length(args[1])
     choices = choicemap()
     state = MapProposeState{T}(choices, 0., Vector{T}(undef,len))
     for key=1:len
-        process_new!(gen_fn, args, key, state)
+        process_new!(rng, gen_fn, args, key, state)
     end
     (state.choices, state.weight, PersistentVector{T}(state.retvals))
 end
diff --git a/src/modeling_library/map/regenerate.jl b/src/modeling_library/map/regenerate.jl
index 5e80ff476..84a7e8686 100644
--- a/src/modeling_library/map/regenerate.jl
+++ b/src/modeling_library/map/regenerate.jl
@@ -8,7 +8,7 @@ mutable struct MapRegenerateState{T,U}
     updated_retdiffs::Dict{Int,Diff}
 end
 
-function process_retained!(gen_fn::Map{T,U}, args::Tuple,
+function process_retained!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple,
         selection::Selection, key::Int, kernel_argdiffs::Tuple,
         state::MapRegenerateState{T,U}) where {T,U}
     local subtrace::U
@@ -21,7 +21,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple,
     # get new subtrace with recursive call to regenerate()
     prev_subtrace = state.subtraces[key]
     (subtrace, weight, retdiff) = regenerate(
-        prev_subtrace, kernel_args, kernel_argdiffs, subselection)
+        rng, prev_subtrace, kernel_args, kernel_argdiffs, subselection)
 
     # retrieve retdiff
     if retdiff != NoChange()
@@ -44,7 +44,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple,
     end
 end
 
-function process_new!(gen_fn::Map{T,U}, args::Tuple, selection::Selection, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, selection::Selection, key::Int,
                       state::MapRegenerateState{T,U}) where {T,U}
     local subtrace::U
     local retval::T
@@ -54,7 +54,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, selection::Selection, key::
     kernel_args = get_args_for_key(args, key)
 
     # get subtrace and weight
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, EmptyChoiceMap())
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, EmptyChoiceMap())
 
     # update state
     state.weight += weight
@@ -70,7 +70,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, selection::Selection, key::
 end
 
-function regenerate(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple,
+function regenerate(rng::AbstractRNG, trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple,
                     selection::Selection) where {T,U}
     gen_fn = trace.gen_fn
     (new_length, prev_length) = get_prev_and_new_lengths(args, trace)
@@ -88,9 +88,9 @@ function regenerate(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tupl
     state = MapRegenerateState{T,U}(-noise_decrement, score, noise, subtraces, retval,
         num_nonempty, Dict{Int,Diff}())
-    process_all_retained!(gen_fn, args, argdiffs, selection,
+    process_all_retained!(rng, gen_fn, args, argdiffs, selection,
                           prev_length, new_length, retained_and_selected, state)
-    process_all_new!(gen_fn, args, selection, prev_length, new_length, state)
+    process_all_new!(rng, gen_fn, args, selection, prev_length, new_length, state)
 
     # retdiff
     retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length)
diff --git a/src/modeling_library/map/simulate.jl b/src/modeling_library/map/simulate.jl
index 216e140c8..c629799d7 100644
--- a/src/modeling_library/map/simulate.jl
+++ b/src/modeling_library/map/simulate.jl
@@ -6,12 +6,12 @@ mutable struct MapSimulateState{T,U}
     num_nonempty::Int
 end
 
-function process!(gen_fn::Map{T,U}, args::Tuple,
+function process!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple,
                   key::Int, state::MapSimulateState{T,U}) where {T,U}
     local subtrace::U
     local retval::T
     kernel_args = get_args_for_key(args, key)
-    subtrace = simulate(gen_fn.kernel, kernel_args)
+    subtrace = simulate(rng, gen_fn.kernel, kernel_args)
     state.noise += project(subtrace, EmptySelection())
     state.num_nonempty += (isempty(get_choices(subtrace)) ? 0 : 1)
     state.score += get_score(subtrace)
@@ -20,11 +20,11 @@ function process!(gen_fn::Map{T,U}, args::Tuple,
     state.retval[key] = retval
 end
 
-function simulate(gen_fn::Map{T,U}, args::Tuple) where {T,U}
+function simulate(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple) where {T,U}
     len = length(args[1])
     state = MapSimulateState{T,U}(0., 0., Vector{U}(undef,len),
                                   Vector{T}(undef,len), 0)
     for key=1:len
-        process!(gen_fn, args, key, state)
+        process!(rng, gen_fn, args, key, state)
     end
     VectorTrace{MapType,T,U}(gen_fn,
         PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval),
diff --git a/src/modeling_library/map/update.jl b/src/modeling_library/map/update.jl
index 8e6dc9825..50ade2beb 100644
--- a/src/modeling_library/map/update.jl
+++ b/src/modeling_library/map/update.jl
@@ -9,7 +9,7 @@ mutable struct MapUpdateState{T,U}
     updated_retdiffs::Dict{Int,Diff}
 end
 
-function process_retained!(gen_fn::Map{T,U}, args::Tuple,
+function process_retained!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple,
         choices::ChoiceMap, key::Int, kernel_argdiffs::Tuple,
         state::MapUpdateState{T,U}) where {T,U}
     local subtrace::U
@@ -22,7 +22,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple,
     # get new subtrace with recursive call to update()
     prev_subtrace = state.subtraces[key]
     (subtrace, weight, retdiff, discard) = update(
-        prev_subtrace, kernel_args, kernel_argdiffs, submap)
+        rng, prev_subtrace, kernel_args, kernel_argdiffs, submap)
 
     # retrieve retdiff
     if retdiff != NoChange()
@@ -46,7 +46,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple,
     end
 end
 
-function process_new!(gen_fn::Map{T,U}, args::Tuple, choices, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Map{T,U}, args::Tuple, choices, key::Int,
                       state::MapUpdateState{T,U}) where {T,U}
     local subtrace::U
     local retval::T
@@ -55,7 +55,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, choices, key::Int,
     kernel_args = get_args_for_key(args, key)
 
     # get subtrace and weight
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap)
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap)
 
     # update state
     state.weight += weight
@@ -71,7 +71,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, choices, key::Int,
 end
 
-function update(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple,
+function update(rng::AbstractRNG, trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple,
                 choices::ChoiceMap) where {T,U}
     gen_fn = trace.gen_fn
     (new_length, prev_length) = get_prev_and_new_lengths(args, trace)
@@ -89,9 +89,9 @@ function update(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple,
     state = MapUpdateState{T,U}(-score_decrement, score, noise, subtraces, retval,
                                 discard, num_nonempty, Dict{Int,Diff}())
-    process_all_retained!(gen_fn, args, argdiffs, choices, prev_length, new_length,
+    process_all_retained!(rng, gen_fn, args, argdiffs, choices, prev_length, new_length,
                           retained_and_constrained, state)
-    process_all_new!(gen_fn, args, choices, prev_length, new_length, state)
+    process_all_new!(rng, gen_fn, args, choices, prev_length, new_length, state)
 
     # retdiff
     retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length)
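The Map methods above thread one RNG into every kernel application, so a whole vectorized model can be simulated reproducibly. A hypothetical sketch under the `simulate(rng, ...)` signature introduced here (the kernel and data are illustrative, not part of this diff):

    using Gen, Random

    @gen function kernel(x::Float64)
        y ~ normal(x, 1.0)
        return y
    end

    model = Map(kernel)
    rng = MersenneTwister(7)
    trace = simulate(rng, model, ([1.0, 2.0, 3.0],))  # every kernel application draws from rng

By contrast, `assess` only evaluates densities, which is why it passes `default_rng()` purely to satisfy the shared `process_new!` signature.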
diff --git a/src/modeling_library/mixture.jl b/src/modeling_library/mixture.jl
index 46763e5c3..2c55892a1 100644
--- a/src/modeling_library/mixture.jl
+++ b/src/modeling_library/mixture.jl
@@ -57,7 +57,8 @@ struct HomogeneousMixture{T} <: Distribution{T}
     dims::Vector{Int}
 end
 
-(dist::HomogeneousMixture)(args...) = random(dist, args...)
+(dist::HomogeneousMixture)(args...) = dist(default_rng(), args...)
+(dist::HomogeneousMixture)(rng::AbstractRNG, args...) = random(rng, dist, args...)
 
 Gen.has_output_grad(dist::HomogeneousMixture) = has_output_grad(dist.base_dist)
 Gen.has_argument_grads(dist::HomogeneousMixture) = (true, has_argument_grads(dist.base_dist)...)
@@ -69,9 +70,9 @@ function args_for_component(dist::HomogeneousMixture, k::Int, args)
             for (arg, dim) in zip(args, dist.dims))
 end
 
-function Gen.random(dist::HomogeneousMixture, weights, args...)
-    k = categorical(weights)
-    return random(dist.base_dist, args_for_component(dist, k, args)...)
+function Gen.random(rng::AbstractRNG, dist::HomogeneousMixture, weights, args...)
+    k = categorical(rng, weights)
+    return random(rng, dist.base_dist, args_for_component(dist, k, args)...)
 end
 
 function Gen.logpdf(dist::HomogeneousMixture, x, weights, args...)
@@ -170,7 +171,8 @@ struct HeterogeneousMixture{T} <: Distribution{T}
     starting_args::Vector{Int}
 end
 
-(dist::HeterogeneousMixture)(args...) = random(dist, args...)
+(dist::HeterogeneousMixture)(args...) = dist(default_rng(), args...)
+(dist::HeterogeneousMixture)(rng::AbstractRNG, args...) = random(rng, dist, args...)
 
 Gen.has_output_grad(dist::HeterogeneousMixture) = dist.has_output_grad
 Gen.has_argument_grads(dist::HeterogeneousMixture) = dist.has_argument_grads
@@ -211,10 +213,11 @@ function extract_args_for_component(dist::HeterogeneousMixture, component_args_f
     return component_args_flat[start_arg:start_arg+n-1]
 end
 
-function Gen.random(dist::HeterogeneousMixture{T}, weights, component_args_flat...) where {T}
+function Gen.random(rng::AbstractRNG, dist::HeterogeneousMixture{T}, weights, component_args_flat...) where {T}
     (length(weights) != dist.K) && error(MIXTURE_WRONG_NUM_COMPONENTS_ERR)
-    k = categorical(weights)
+    k = categorical(rng, weights)
     value::T = random(
+        rng,
         dist.distributions[k],
         extract_args_for_component(dist, component_args_flat, k)...)
     return value
diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl
index 13d6e4880..98d986252 100644
--- a/src/modeling_library/modeling_library.jl
+++ b/src/modeling_library/modeling_library.jl
@@ -3,14 +3,16 @@
 #############################
 
 import Distributions
+
 using SpecialFunctions: loggamma, logbeta, digamma
 
 abstract type Distribution{T} end
 
 """
-    val::T = random(dist::Distribution{T}, args...)
+    val::T = random([rng::AbstractRNG], dist::Distribution{T}, args...)
 
-Sample a random choice from the given distribution with the given arguments.
+Sample a random choice from the given distribution with the given arguments. An RNG
+may optionally be supplied as the first argument; if it is omitted, `Random.default_rng()` is used.
 """
 function random end
 
@@ -40,6 +42,14 @@ Otherwise, this element contains the gradient with respect to the `i`th argument
 """
 function logpdf_grad end
 
+random(dist::Distribution, args...) = random(default_rng(), dist, args...)
+function random(rng::AbstractRNG, dist::Distribution, args...)
+    # TODO: For backwards compatibility only. Remove in next breaking version.
+    @warn "Missing concrete implementation of `random(::AbstractRNG, ::$(typeof(dist)), args...)`, " *
+          "falling back to `random(::$(typeof(dist)), args...)`."
+    return random(dist, args...)
+end
+
 is_discrete(::Distribution) = false # default
 
 # NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl
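The fallback above is what keeps third-party distributions that only implement the old signature working: the rng-first method warns and delegates to the old one. A hypothetical custom distribution showing both paths:

    using Gen, Random
    import Gen: random, logpdf, is_discrete

    # hypothetical distribution, for illustration only
    struct TwoSidedCoin <: Gen.Distribution{Bool} end
    const two_sided_coin = TwoSidedCoin()

    # old-style method only: random(rng, two_sided_coin) would hit the @warn
    # fallback above and delegate here, sampling from the global RNG
    random(::TwoSidedCoin) = rand() < 0.5

    # rng-first method: avoids the warning and makes draws reproducible
    random(rng::AbstractRNG, ::TwoSidedCoin) = rand(rng) < 0.5

    logpdf(::TwoSidedCoin, x::Bool) = log(0.5)
    is_discrete(::TwoSidedCoin) = true

Note that the fallback assumes an old-style concrete method exists; a distribution that implements neither signature would recurse between the two generic methods.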
diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl
index 7b4d1d23d..7e561b6c1 100644
--- a/src/modeling_library/recurse/recurse.jl
+++ b/src/modeling_library/recurse/recurse.jl
@@ -131,8 +131,9 @@ end
 # TODO
 accepts_output_grad(::Recurse) = false
 
-function (gen_fn::Recurse)(args...)
-    (_, _, retval) = propose(gen_fn, args)
+(gen_fn::Recurse)(args...) = gen_fn(default_rng(), args...)
+function (gen_fn::Recurse)(rng::AbstractRNG, args...)
+    (_, _, retval) = propose(rng, gen_fn, args)
     retval
 end
 
@@ -197,7 +198,7 @@ end
 # simulate #
 ############
 
-function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T,U,V,W,X,Y}
+function simulate(rng::AbstractRNG, gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T,U,V,W,X,Y}
     (root_production_input::U, root_idx::Int) = args
     production_traces = PersistentHashMap{Int,S}()
     aggregation_traces = PersistentHashMap{Int,T}()
@@ -213,7 +214,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T
         cur = first(prod_to_visit)
         delete!(prod_to_visit, cur)
         input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input)
-        subtrace = simulate(gen_fn.production_kern, (input,))
+        subtrace = simulate(rng, gen_fn.production_kern, (input,))
         score += get_score(subtrace)
         production_traces = assoc(production_traces, cur, subtrace)
         children_inputs::Vector{U} = get_retval(subtrace).children
@@ -232,7 +233,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T
         local subtrace::T
         local input::Tuple{V,Vector{W}}
         input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces)
-        subtrace = simulate(gen_fn.aggregation_kern, input)
+        subtrace = simulate(rng, gen_fn.aggregation_kern, input)
         score += get_score(subtrace)
         aggregation_traces = assoc(aggregation_traces, cur, subtrace)
         if !isempty(get_choices(subtrace))
@@ -249,7 +250,7 @@ end
 # generate #
 ############
 
-function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int},
+function generate(rng::AbstractRNG, gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int},
                   constraints::ChoiceMap) where {S,T,U,V,W,X,Y}
     (root_production_input::U, root_idx::Int) = args
     production_traces = PersistentHashMap{Int,S}()
@@ -268,7 +269,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int},
         delete!(prod_to_visit, cur)
         input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input)
         subconstraints = get_production_constraints(constraints, cur)
-        (subtrace, subweight) = generate(gen_fn.production_kern, (input,), subconstraints)
+        (subtrace, subweight) = generate(rng, gen_fn.production_kern, (input,), subconstraints)
         score += get_score(subtrace)
         production_traces = assoc(production_traces, cur, subtrace)
         weight += subweight
@@ -289,7 +290,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int},
         local input::Tuple{V,Vector{W}}
         input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces)
         subconstraints = get_aggregation_constraints(constraints, cur)
-        (subtrace, subweight) = generate(gen_fn.aggregation_kern, input, subconstraints)
+        (subtrace, subweight) = generate(rng, gen_fn.aggregation_kern, input, subconstraints)
         score += get_score(subtrace)
         aggregation_traces = assoc(aggregation_traces, cur, subtrace)
         weight += subweight
@@ -438,7 +439,8 @@ function get_aggregation_argdiffs(production_retdiffs::Dict{Int,Diff},
     (dv, VectorDiff(new_num_children, prev_num_children, dws))
 end
 
-function update(trace::RecurseTrace{S,T,U,V,W,X,Y},
+function update(rng::AbstractRNG,
+                trace::RecurseTrace{S,T,U,V,W,X,Y},
                 new_args::Tuple{U,Int},
                 argdiffs::Tuple,
                 constraints::ChoiceMap) where {S,T,U,V,W,X,Y}
@@ -504,7 +506,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y},
             # call update on production kernel
             prev_subtrace = production_traces[cur]
             (subtrace, subweight, subretdiff, subdiscard) = update(
-                prev_subtrace, input, (subargdiff,), subconstraints)
+                rng, prev_subtrace, input, (subargdiff,), subconstraints)
             prev_num_children = get_num_children(production_traces[cur])
             new_num_children = length(get_retval(subtrace).children)
             idx_to_prev_num_children[cur] = prev_num_children
@@ -622,7 +624,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y},
             # call update on aggregation kernel
             prev_subtrace = aggregation_traces[cur]
             (subtrace, subweight, subretdiff, subdiscard) = update(
-                prev_subtrace, input, subargdiffs, subconstraints)
+                rng, prev_subtrace, input, subargdiffs, subconstraints)
 
             # update trace, weight, and score, and discard
             aggregation_traces = assoc(aggregation_traces, cur, subtrace)
@@ -649,7 +651,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y},
         # if the node does not exist (but its children do, since we created them already)
         else
-            (subtrace, _) = generate(gen_fn.aggregation_kern, input, subconstraints)
+            (subtrace, _) = generate(rng, gen_fn.aggregation_kern, input, subconstraints)
 
             # update trace, weight, and score
             aggregation_traces = assoc(aggregation_traces, cur, subtrace)
diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl
index 02f0fa0de..9b5ca184b 100644
--- a/src/modeling_library/switch/generate.jl
+++ b/src/modeling_library/switch/generate.jl
@@ -8,13 +8,14 @@ mutable struct SwitchGenerateState{T}
     SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight)
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
+function process!(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   index::Int,
                   args::Tuple,
                   choices::ChoiceMap,
                   state::SwitchGenerateState{T}) where {C, N, K, T}
-    (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices)
+    (subtrace, weight) = generate(rng, getindex(gen_fn.branches, index), args, choices)
     state.index = index
     state.subtrace = subtrace
     state.weight += weight
@@ -23,13 +24,14 @@ end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)
+@inline process!(rng::AbstractRNG, gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(rng, gen_fn, getindex(gen_fn.cases, index), args, choices, state)
 
-function generate(gen_fn::Switch{C, N, K, T},
+function generate(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   args::Tuple,
                   choices::ChoiceMap) where {C, N, K, T}
     index = args[1]
     state = SwitchGenerateState{T}(0.0, 0.0, 0.0)
-    process!(gen_fn, index, args[2 : end], choices, state)
+    process!(rng, gen_fn, index, args[2 : end], choices, state)
     return SwitchTrace(gen_fn, state.subtrace,
                        state.retval, args,
                        state.score, state.noise), state.weight
diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl
index b4df1d97f..412d73e3e 100644
--- a/src/modeling_library/switch/propose.jl
+++ b/src/modeling_library/switch/propose.jl
@@ -5,12 +5,13 @@ mutable struct SwitchProposeState{T}
     SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight)
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
+function process!(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   index::Int,
                   args::Tuple,
                   state::SwitchProposeState{T}) where {C, N, K, T}
-    (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args)
+    (submap, weight, retval) = propose(rng, getindex(gen_fn.branches, index), args)
     state.choices = submap
     state.weight += weight
     state.retval = retval
@@ -18,12 +19,13 @@ end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)
+@inline process!(rng::AbstractRNG, gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(rng, gen_fn, getindex(gen_fn.cases, index), args, state)
 
-function propose(gen_fn::Switch{C, N, K, T},
+function propose(rng::AbstractRNG,
+                 gen_fn::Switch{C, N, K, T},
                  args::Tuple) where {C, N, K, T}
     index = args[1]
     choices = choicemap()
     state = SwitchProposeState{T}(choices, 0.0)
-    process!(gen_fn, index, args[2:end], state)
+    process!(rng, gen_fn, index, args[2:end], state)
     return state.choices, state.weight, state.retval
 end
diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl
index db78312ac..394156b24 100644
--- a/src/modeling_library/switch/regenerate.jl
+++ b/src/modeling_library/switch/regenerate.jl
@@ -9,7 +9,8 @@ mutable struct SwitchRegenerateState{T}
     SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
+function process!(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   index::Int,
                   index_argdiff::Diff,
                   args::Tuple,
@@ -18,7 +19,7 @@ function process!(gen_fn::Switch{C, N, K, T},
                   state::SwitchRegenerateState{T}) where {C, N, K, T}
     branch_fn = getfield(gen_fn.branches, index)
     merged = get_selected(get_choices(state.prev_trace), complement(selection))
-    new_trace, weight = generate(branch_fn, args, merged)
+    new_trace, weight = generate(rng, branch_fn, args, merged)
     retdiff = UnknownChange()
     weight -= project(state.prev_trace, complement(selection))
     weight += (project(new_trace, selection) - project(state.prev_trace, selection))
@@ -30,14 +31,15 @@ function process!(gen_fn::Switch{C, N, K, T},
     state.retdiff = retdiff
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
+function process!(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   index::Int,
                   index_argdiff::NoChange,
                   args::Tuple,
                   kernel_argdiffs::Tuple,
                   selection::Selection,
                   state::SwitchRegenerateState{T}) where {C, N, K, T}
-    new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection)
+    new_trace, weight, retdiff = regenerate(rng, getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection)
     state.index = index
     state.weight = weight
     state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
@@ -48,14 +50,15 @@ end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state)
+@inline process!(rng::AbstractRNG, gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(rng, gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state)
 
-function regenerate(trace::SwitchTrace{A, T, U},
+function regenerate(rng::AbstractRNG,
+                    trace::SwitchTrace{A, T, U},
                     args::Tuple,
                     argdiffs::Tuple,
                     selection::Selection) where {A, T, U}
     gen_fn = trace.gen_fn
     index, index_argdiff = args[1], argdiffs[1]
     state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace)
-    process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state)
+    process!(rng, gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state)
     return SwitchTrace(gen_fn, state.trace, get_retval(state.trace), args,
                        state.score, state.noise), state.weight, state.retdiff
diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl
index 74eb4811d..bd54cfb2e 100644
--- a/src/modeling_library/switch/simulate.jl
+++ b/src/modeling_library/switch/simulate.jl
@@ -7,12 +7,13 @@ mutable struct SwitchSimulateState{T}
     SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
+function process!(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   index::Int,
                   args::Tuple,
                   state::SwitchSimulateState{T}) where {C, N, K, T}
     local retval::T
-    subtrace = simulate(getindex(gen_fn.branches, index), args)
+    subtrace = simulate(rng, getindex(gen_fn.branches, index), args)
     state.index = index
     state.noise += project(subtrace, EmptySelection())
     state.subtrace = subtrace
@@ -22,12 +23,13 @@ end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)
+@inline process!(rng::AbstractRNG, gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(rng, gen_fn, getindex(gen_fn.cases, index), args, state)
 
-function simulate(gen_fn::Switch{C, N, K, T},
+function simulate(rng::AbstractRNG,
+                  gen_fn::Switch{C, N, K, T},
                   args::Tuple) where {C, N, K, T}
     index = args[1]
     state = SwitchSimulateState{T}(0.0, 0.0)
-    process!(gen_fn, index, args[2 : end], state)
+    process!(rng, gen_fn, index, args[2 : end], state)
     return SwitchTrace(gen_fn, state.subtrace,
                        state.retval, args,
                        state.score, state.noise)
diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl
index 9ef5752e1..e487d3452 100644
--- a/src/modeling_library/switch/switch.jl
+++ b/src/modeling_library/switch/switch.jl
@@ -19,13 +19,15 @@ has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_f
 end
 accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches)
 
-function (gen_fn::Switch)(index::Int, args...)
-    (_, _, retval) = propose(gen_fn, (index, args...))
+(gen_fn::Switch)(index, args...) = gen_fn(default_rng(), index, args...)
+
+function (gen_fn::Switch)(rng::AbstractRNG, index::Int, args...)
+    (_, _, retval) = propose(rng, gen_fn, (index, args...))
     retval
 end
 
-function (gen_fn::Switch{C})(index::C, args...) where C
-    (_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...))
+function (gen_fn::Switch{C})(rng::AbstractRNG, index::C, args...) where C
+    (_, _, retval) = propose(rng, gen_fn, (gen_fn.cases[index], args...))
     retval
 end
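With these methods, a `Switch` can be sampled against an explicit RNG, with the first argument still selecting the branch. A hypothetical sketch (the branches are illustrative, not part of this diff):

    using Gen, Random

    @gen function branch_a()
        x ~ normal(0.0, 1.0)
    end

    @gen function branch_b()
        x ~ uniform_continuous(-1.0, 1.0)
    end

    sw = Switch(branch_a, branch_b)
    trace = simulate(MersenneTwister(11), sw, (2,))  # run branch_b under the supplied RNG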
diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl
index d14af1278..16de72cc2 100644
--- a/src/modeling_library/switch/update.jl
+++ b/src/modeling_library/switch/update.jl
@@ -82,18 +82,21 @@ Returns choices from previous trace that:
 @inline update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) = update_discard(get_choices(prev_trace), choices, get_choices(new_trace))
 
-function process!(gen_fn::Switch{C, N, K, T},
-                  index::Int,
-                  index_argdiff::UnknownChange,
-                  args::Tuple,
-                  kernel_argdiffs::Tuple,
-                  choices::ChoiceMap,
-                  state::SwitchUpdateState{T}) where {C, N, K, T}
+function process!(
+    rng::AbstractRNG,
+    gen_fn::Switch{C, N, K, T},
+    index::Int,
+    index_argdiff::UnknownChange,
+    args::Tuple,
+    kernel_argdiffs::Tuple,
+    choices::ChoiceMap,
+    state::SwitchUpdateState{T}
+) where {C, N, K, T}
 
     # Generate new trace.
     merged = update_recurse_merge(get_choices(state.prev_trace), choices)
     branch_fn = getfield(gen_fn.branches, index)
-    new_trace, weight = generate(branch_fn, args, merged)
+    new_trace, weight = generate(rng, branch_fn, args, merged)
     weight -= get_score(state.prev_trace)
     state.discard = update_discard(state.prev_trace, choices, new_trace)
@@ -106,16 +109,19 @@ function process!(gen_fn::Switch{C, N, K, T},
     state.updated_retdiff = UnknownChange()
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
-                  index::Int,
-                  index_argdiff::NoChange, # TODO: Diffed wrapper?
-                  args::Tuple,
-                  kernel_argdiffs::Tuple,
-                  choices::ChoiceMap,
-                  state::SwitchUpdateState{T}) where {C, N, K, T}
+function process!(
+    rng::AbstractRNG,
+    gen_fn::Switch{C, N, K, T},
+    index::Int,
+    index_argdiff::NoChange, # TODO: Diffed wrapper?
+    args::Tuple,
+    kernel_argdiffs::Tuple,
+    choices::ChoiceMap,
+    state::SwitchUpdateState{T}
+) where {C, N, K, T}
 
     # Update trace.
-    new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices)
+    new_trace, weight, retdiff, discard = update(rng, getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices)
 
     # Set state.
     state.index = index
@@ -127,16 +133,19 @@ function process!(gen_fn::Switch{C, N, K, T},
     state.discard = discard
 end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state)
+@inline process!(rng::AbstractRNG, gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(rng, gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state)
 
-function update(trace::SwitchTrace{A, T, U},
-                args::Tuple,
-                argdiffs::Tuple,
-                choices::ChoiceMap) where {A, T, U}
+function update(
+    rng::AbstractRNG,
+    trace::SwitchTrace{A, T, U},
+    args::Tuple,
+    argdiffs::Tuple,
+    choices::ChoiceMap
+) where {A, T, U}
     gen_fn = trace.gen_fn
     index, index_argdiff = args[1], argdiffs[1]
     state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace)
-    process!(gen_fn, index, index_argdiff,
+    process!(rng, gen_fn, index, index_argdiff,
              args[2 : end], argdiffs[2 : end], choices, state)
     return SwitchTrace(gen_fn, state.trace, get_retval(state.trace), args,
diff --git a/src/modeling_library/unfold/assess.jl b/src/modeling_library/unfold/assess.jl
index 4199f77da..aa83de3a6 100644
--- a/src/modeling_library/unfold/assess.jl
+++ b/src/modeling_library/unfold/assess.jl
@@ -4,7 +4,7 @@ mutable struct UnfoldAssessState{T}
     state::T
 end
 
-function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap,
+function process_new!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap,
                       key::Int, state::UnfoldAssessState{T}) where {T,U}
     local new_state::T
     kernel_args = (key, state.state, params...)
@@ -21,7 +21,8 @@ function assess(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U
     params = args[3:end]
     state = UnfoldAssessState{T}(0., Vector{T}(undef,len), init_state)
     for key=1:len
-        process_new!(gen_fn, params, choices, key, state)
+        # pass the default RNG to satisfy the interface; assess never samples, so it is unused
+        process_new!(default_rng(), gen_fn, params, choices, key, state)
     end
     (state.weight, PersistentVector{T}(state.retvals))
 end
diff --git a/src/modeling_library/unfold/generate.jl b/src/modeling_library/unfold/generate.jl
index 3ef9a78b9..d6c9a4cdd 100644
--- a/src/modeling_library/unfold/generate.jl
+++ b/src/modeling_library/unfold/generate.jl
@@ -8,13 +8,13 @@ mutable struct UnfoldGenerateState{T,U}
     state::T
 end
 
-function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap,
+function process!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap,
                   key::Int, state::UnfoldGenerateState{T,U}) where {T,U}
     local subtrace::U
     local new_state::T
     kernel_args = (key, state.state, params...)
     submap = get_submap(choices, key)
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap)
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap)
     state.weight += weight
     state.noise += project(subtrace, EmptySelection())
     state.num_nonempty += (isempty(get_choices(subtrace)) ? 0 : 1)
@@ -25,14 +25,14 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap,
     state.retval[key] = new_state
 end
 
-function generate(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U}
+function generate(rng::AbstractRNG, gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U}
     len = args[1]
     init_state = args[2]
     params = args[3:end]
     state = UnfoldGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len),
                                      Vector{T}(undef,len), 0, init_state)
     for key=1:len
-        process!(gen_fn, params, choices, key, state)
+        process!(rng, gen_fn, params, choices, key, state)
     end
     trace = VectorTrace{UnfoldType,T,U}(gen_fn,
         PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval),
diff --git a/src/modeling_library/unfold/generic_update.jl b/src/modeling_library/unfold/generic_update.jl
index 1d693277e..1f8c93a53 100644
--- a/src/modeling_library/unfold/generic_update.jl
+++ b/src/modeling_library/unfold/generic_update.jl
@@ -1,4 +1,4 @@
-function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tuple,
+function process_all_retained!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tuple,
         choices_or_selection, prev_length::Int, new_length::Int,
         retained_and_targeted::Set{Int}, state) where {T,U}
 
@@ -11,7 +11,7 @@ function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tup
         # visit every retained kernel application
         state_diff = init_state_diff
         for key=1:min(prev_length,new_length)
-            state_diff = process_retained!(gen_fn, params, choices_or_selection,
+            state_diff = process_retained!(rng, gen_fn, params, choices_or_selection,
                 key, (NoChange(), state_diff, param_diffs...), state)
         end
 
@@ -26,7 +26,7 @@ function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tup
         key = 1
         visit = true
         while visit && key <= min(prev_length, new_length)
-            state_diff = process_retained!(gen_fn, params, choices_or_selection,
+            state_diff = process_retained!(rng, gen_fn, params, choices_or_selection,
                 key, (NoChange(), state_diff, param_diffs...), state)
             key += 1
             visit = (state_diff != NoChange())
@@ -40,7 +40,7 @@ function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tup
             key = to_visit[i]
             visit = true
             while visit && key <= min(prev_length, new_length)
-                state_diff = process_retained!(gen_fn, params, choices_or_selection,
+                state_diff = process_retained!(rng, gen_fn, params, choices_or_selection,
                     key, (NoChange(), state_diff, param_diffs...), state)
                 key += 1
                 visit = (state_diff != NoChange())
@@ -52,9 +52,9 @@ end
 """
 Process all new applications.
 """
-function process_all_new!(gen_fn::Unfold{T,U}, params::Tuple, choices_or_selection,
+function process_all_new!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, choices_or_selection,
         prev_len::Int, new_len::Int, state) where {T,U}
     for key=prev_len+1:new_len
-        process_new!(gen_fn, params, choices_or_selection, key, state)
+        process_new!(rng, gen_fn, params, choices_or_selection, key, state)
     end
 end
diff --git a/src/modeling_library/unfold/propose.jl b/src/modeling_library/unfold/propose.jl
index 8863fbd4a..8608aca22 100644
--- a/src/modeling_library/unfold/propose.jl
+++ b/src/modeling_library/unfold/propose.jl
@@ -5,7 +5,7 @@ mutable struct UnfoldProposeState{T}
     state::T
 end
 
-function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, key::Int,
                       state::UnfoldProposeState{T}) where {T,U}
     local new_state::T
     kernel_args = (key, state.state, params...)
@@ -16,14 +16,14 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int,
     state.state = new_state
 end
 
-function propose(gen_fn::Unfold{T,U}, args::Tuple) where {T,U}
+function propose(rng::AbstractRNG, gen_fn::Unfold{T,U}, args::Tuple) where {T,U}
     len = args[1]
     init_state = args[2]
     params = args[3:end]
     choices = choicemap()
     state = UnfoldProposeState{T}(choices, 0., Vector{T}(undef,len), init_state)
     for key=1:len
-        process_new!(gen_fn, params, key, state)
+        process_new!(rng, gen_fn, params, key, state)
     end
     (state.choices, state.weight, PersistentVector{T}(state.retvals))
 end
diff --git a/src/modeling_library/unfold/regenerate.jl b/src/modeling_library/unfold/regenerate.jl
index 6b4f1eb99..aae2a0c39 100644
--- a/src/modeling_library/unfold/regenerate.jl
+++ b/src/modeling_library/unfold/regenerate.jl
@@ -9,7 +9,7 @@ mutable struct UnfoldRegenerateState{T,U}
     updated_retdiffs::Dict{Int,Diff}
 end
 
-function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
+function process_retained!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple,
         selection::Selection, key::Int, kernel_argdiffs::Tuple,
         state::UnfoldRegenerateState{T,U}) where {T,U}
     local subtrace::U
@@ -24,7 +24,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
     # get new subtrace with recursive call to regenerate()
     prev_subtrace = state.subtraces[key]
     (subtrace, weight, retdiff) = regenerate(
-        prev_subtrace, kernel_args, kernel_argdiffs, subselection)
+        rng, prev_subtrace, kernel_args, kernel_argdiffs, subselection)
 
     # retrieve retdiff
     if retdiff != NoChange()
@@ -49,7 +49,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
     retdiff
 end
 
-function process_new!(gen_fn::Unfold{T,U}, params::Tuple, selection::Selection, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, selection::Selection, key::Int,
         state::UnfoldRegenerateState{T,U}) where {T,U}
     local subtrace::U
     local prev_state::T
@@ -62,7 +62,7 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, selection::Selection,
     kernel_args = (key, prev_state, params...)
 
     # get subtrace and weight
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, EmptyChoiceMap())
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, EmptyChoiceMap())
 
     # update state
     state.weight += weight
@@ -77,7 +77,7 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, selection::Selection,
     end
 end
 
-function regenerate(trace::VectorTrace{UnfoldType,T,U},
+function regenerate(rng::AbstractRNG, trace::VectorTrace{UnfoldType,T,U},
         args::Tuple, argdiffs::Tuple,
         selection::Selection) where {T,U}
     gen_fn = trace.gen_fn
@@ -98,9 +98,9 @@ function regenerate(trace::VectorTrace{UnfoldType,T,U},
     # handle retained and new applications
     state = UnfoldRegenerateState{T,U}(init_state, -noise_decrement, score, noise,
         subtraces, retval, num_nonempty, Dict{Int,Diff}())
-    process_all_retained!(gen_fn, params, argdiffs, selection, prev_length, new_length,
+    process_all_retained!(rng, gen_fn, params, argdiffs, selection, prev_length, new_length,
                           retained_and_selected, state)
-    process_all_new!(gen_fn, params, selection, prev_length, new_length, state)
+    process_all_new!(rng, gen_fn, params, selection, prev_length, new_length, state)
 
     # retdiff
     retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length)
diff --git a/src/modeling_library/unfold/simulate.jl b/src/modeling_library/unfold/simulate.jl
index e161e64a2..592598048 100644
--- a/src/modeling_library/unfold/simulate.jl
+++ b/src/modeling_library/unfold/simulate.jl
@@ -7,12 +7,12 @@ mutable struct UnfoldSimulateState{T,U}
     state::T
 end
 
-function process!(gen_fn::Unfold{T,U}, params::Tuple,
+function process!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple,
                   key::Int, state::UnfoldSimulateState{T,U}) where {T,U}
     local subtrace::U
     local new_state::T
     kernel_args = (key, state.state, params...)
-    subtrace = simulate(gen_fn.kernel, kernel_args)
+    subtrace = simulate(rng, gen_fn.kernel, kernel_args)
     state.noise += project(subtrace, EmptySelection())
     state.num_nonempty += (isempty(get_choices(subtrace)) ? 0 : 1)
     state.score += get_score(subtrace)
@@ -22,14 +22,14 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple,
     state.retval[key] = new_state
 end
 
-function simulate(gen_fn::Unfold{T,U}, args::Tuple) where {T,U}
+function simulate(rng::AbstractRNG, gen_fn::Unfold{T,U}, args::Tuple) where {T,U}
     len = args[1]
     init_state = args[2]
     params = args[3:end]
     state = UnfoldSimulateState{T,U}(0., 0., Vector{U}(undef,len),
                                      Vector{T}(undef,len), 0, init_state)
     for key=1:len
-        process!(gen_fn, params, key, state)
+        process!(rng, gen_fn, params, key, state)
     end
     VectorTrace{UnfoldType,T,U}(gen_fn,
         PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval),
diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl
index 44238e3b7..87c2551e0 100644
--- a/src/modeling_library/unfold/unfold.jl
+++ b/src/modeling_library/unfold/unfold.jl
@@ -42,8 +42,9 @@ end
 # TODO
 accepts_output_grad(gen_fn::Unfold) = false
 
-function (gen_fn::Unfold)(args...)
-    (_, _, retval) = propose(gen_fn, args)
+(gen_fn::Unfold)(args...) = gen_fn(default_rng(), args...)
+function (gen_fn::Unfold)(rng::AbstractRNG, args...)
+    (_, _, retval) = propose(rng, gen_fn, args)
     retval
 end
diff --git a/src/modeling_library/unfold/update.jl b/src/modeling_library/unfold/update.jl
index 0090aaa86..f190509cc 100644
--- a/src/modeling_library/unfold/update.jl
+++ b/src/modeling_library/unfold/update.jl
@@ -10,7 +10,7 @@ mutable struct UnfoldUpdateState{T,U}
     updated_retdiffs::Dict{Int,Diff}
 end
 
-function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
+function process_retained!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple,
         choices::ChoiceMap, key::Int, kernel_argdiffs::Tuple,
         state::UnfoldUpdateState{T,U}) where {T,U}
     local subtrace::U
@@ -25,7 +25,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
     # get new subtrace with recursive call to update()
     prev_subtrace = state.subtraces[key]
     (subtrace, weight, retdiff, discard) = update(
-        prev_subtrace, kernel_args, kernel_argdiffs, submap)
+        rng, prev_subtrace, kernel_args, kernel_argdiffs, submap)
 
     # retrieve retdiff
     if retdiff != NoChange()
@@ -51,7 +51,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple,
     retdiff
 end
 
-function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices, key::Int,
+function process_new!(rng::AbstractRNG, gen_fn::Unfold{T,U}, params::Tuple, choices, key::Int,
         state::UnfoldUpdateState{T,U}) where {T,U}
     local subtrace::U
     local prev_state::T
@@ -62,7 +62,7 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices, key::Int,
     kernel_args = (key, prev_state, params...)
 
     # get subtrace and weight
-    (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap)
+    (subtrace, weight) = generate(rng, gen_fn.kernel, kernel_args, submap)
 
     # update state
     state.weight += weight
@@ -77,7 +77,7 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices, key::Int,
     end
 end
 
-function update(trace::VectorTrace{UnfoldType,T,U},
+function update(rng::AbstractRNG, trace::VectorTrace{UnfoldType,T,U},
         args::Tuple, argdiffs::Tuple,
         choices::ChoiceMap) where {T,U}
     gen_fn = trace.gen_fn
@@ -98,9 +98,9 @@ function update(trace::VectorTrace{UnfoldType,T,U},
     # handle retained and new applications
     state = UnfoldUpdateState{T,U}(init_state, -score_decrement, score, noise,
         subtraces, retval, discard, num_nonempty, Dict{Int,Diff}())
-    process_all_retained!(gen_fn, params, argdiffs, choices, prev_length, new_length,
+    process_all_retained!(rng, gen_fn, params, argdiffs, choices, prev_length, new_length,
                           retained_and_constrained, state)
-    process_all_new!(gen_fn, params, choices, prev_length, new_length, state)
+    process_all_new!(rng, gen_fn, params, choices, prev_length, new_length, state)
 
     # retdiff
     retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length)
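For `Unfold`, the argument tuple is still `(len, init_state, params...)`; only the rng-first methods are new, so a state-space model of any length can be driven from one seeded stream. A hypothetical sketch:

    using Gen, Random

    @gen function step(t::Int, x_prev::Float64, drift::Float64)
        x ~ normal(x_prev + drift, 0.1)
        return x
    end

    chain = Unfold(step)
    trace = simulate(MersenneTwister(3), chain, (10, 0.0, 0.5))  # 10 steps, init state 0.0, drift 0.5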
diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl
index 643686766..22ffa507e 100644
--- a/src/static_ir/generate.jl
+++ b/src/static_ir/generate.jl
@@ -33,7 +33,7 @@ function process!(state::StaticIRGenerateState, node::RandomChoiceNode, options)
         push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...))))
         push!(state.stmts, :($weight += $incr))
     else
-        push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...))))
+        push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($STATIC_RNG, $dist, $(args...))))
         push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...))))
     end
     push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name)))
@@ -54,9 +54,9 @@ function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode
     subconstraints = gensym("subconstraints")
     if isa(schema, StaticAddressSchema) && (node.addr in keys(schema))
         push!(state.stmts, :($subconstraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr))))
-        push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $subconstraints)))
+        push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($STATIC_RNG, $gen_fn, $args_tuple, $subconstraints)))
     else
-        push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $(GlobalRef(Gen, :EmptyChoiceMap))())))
+        push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($STATIC_RNG, $gen_fn, $args_tuple, $(GlobalRef(Gen, :EmptyChoiceMap))())))
     end
     push!(state.stmts, :($weight += $incr))
     push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 1 : 0))
@@ -72,7 +72,7 @@ function codegen_generate(gen_fn_type::Type{T}, args,
 
     # convert the constraints to a static assignment if it is not already one
     if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema))
-        return quote $(GlobalRef(Gen, :generate))(gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end
+        return quote $(GlobalRef(Gen, :generate))($STATIC_RNG, gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end
     end
 
     ir = get_ir(gen_fn_type)
diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl
index 02fb78800..a1d049f66 100644
--- a/src/static_ir/simulate.jl
+++ b/src/static_ir/simulate.jl
@@ -25,7 +25,7 @@ function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options)
     incr = gensym("logpdf")
     addr = QuoteNode(node.addr)
     dist = QuoteNode(node.dist)
-    push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...))))
+    push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($STATIC_RNG, $dist, $(args...))))
     push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...))))
     push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name)))
     push!(state.stmts, :($(get_score_fieldname(node)) = $incr))
@@ -40,7 +40,7 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode
     gen_fn = QuoteNode(node.generative_function)
     subtrace = get_subtrace_fieldname(node)
    incr = gensym("weight")
-    push!(state.stmts, :($subtrace = $(QuoteNode(simulate))($gen_fn, $args_tuple)))
+    push!(state.stmts, :($subtrace = $(QuoteNode(simulate))($STATIC_RNG, $gen_fn, $args_tuple)))
     push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 1 : 0))
     push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace)))
     push!(state.stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace)))
diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl
index 3e27810f1..cb8efaf8b 100644
--- a/src/static_ir/static_ir.jl
+++ b/src/static_ir/static_ir.jl
@@ -17,6 +17,9 @@ end
 # trace code generation
 include("trace.jl")
 
+"Global reference to the RNG variable for the static modeling language."
+const STATIC_RNG = gensym("rng")
+
 """
     StaticIRGenerativeFunction{T,U} <: GenerativeFunction{T,U}
 
@@ -63,18 +66,18 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
         $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options))
         # Generate GFI definitions
         (gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3]
-        @generated function $(GlobalRef(Gen, :simulate))(gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
+        @generated function $(GlobalRef(Gen, :simulate))($STATIC_RNG::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
            $(QuoteNode(codegen_simulate))(gen_fn, args)
        end
-        @generated function $(GlobalRef(Gen, :generate))(gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)),
+        @generated function $(GlobalRef(Gen, :generate))($STATIC_RNG::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)),
                                                         constraints::$(QuoteNode(ChoiceMap)))
            $(QuoteNode(codegen_generate))(gen_fn, args, constraints)
        end
-        @generated function $(GlobalRef(Gen, :update))(trace::$trace_type, args::$(QuoteNode(Tuple)),
+        @generated function $(GlobalRef(Gen, :update))($STATIC_RNG::$AbstractRNG, trace::$trace_type, args::$(QuoteNode(Tuple)),
                                                       argdiffs::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)))
            $(QuoteNode(codegen_update))(trace, args, argdiffs, constraints)
        end
-        @generated function $(GlobalRef(Gen, :regenerate))(trace::$trace_type, args::$(QuoteNode(Tuple)),
+        @generated function $(GlobalRef(Gen, :regenerate))($STATIC_RNG::$AbstractRNG, trace::$trace_type, args::$(QuoteNode(Tuple)),
                                                           argdiffs::$(QuoteNode(Tuple)), selection::$(QuoteNode(Selection)))
            $(QuoteNode(codegen_regenerate))(trace, args, argdiffs, selection)
        end
diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl
index 5768791c1..c4e2f5331 100644
--- a/src/static_ir/update.jl
+++ b/src/static_ir/update.jl
@@ -256,7 +256,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState,
     output_value = Expr(:call, (GlobalRef(Gen, :strip_diff)), node.name)
     if node in fwd.constrained_or_selected_choices
         # the choice was selected, it does not contribute to the weight
-        push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :random))($dist, $(arg_values...)), UnknownChange())))
+        push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :random))($STATIC_RNG, $dist, $(arg_values...)), UnknownChange())))
         push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $output_value, $(arg_values...))))
     else
         # the choice was not selected, and the input to the choice changed
@@ -282,7 +282,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState,
     if node in fwd.constrained_or_selected_choices || node in fwd.input_changed
         if node in fwd.constrained_or_selected_choices
             # the choice was selected, it does not contribute to the weight
-            push!(stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(arg_values...))))
+            push!(stmts, :($(node.name) = $(GlobalRef(Gen, :random))($STATIC_RNG, $dist, $(arg_values...))))
             push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...))))
         else
             # the choice was not selected, and the input to the choice changed
@@ -323,7 +323,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState,
         push!(stmts, :($call_constraints = $(GlobalRef(Gen, :EmptyChoiceMap))()))
     end
     push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) =
-        $(GlobalRef(Gen, :update))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints)))
+        $(GlobalRef(Gen, :update))($STATIC_RNG, $prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints)))
     push!(stmts, :($weight += $call_weight))
     push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace)))
     push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) -
         $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))())))
@@ -471,7 +471,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ
 
     # convert the constraints to a static assignment if it is not already one
     if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema))
-        return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end
+        return quote $(GlobalRef(Gen, :update))($STATIC_RNG, trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end
     end
 
     ir = get_ir(gen_fn_type)
@@ -519,7 +519,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type:
 
     # convert a hierarchical selection to a static selection if it is not already one
    if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema))
-        return quote $(GlobalRef(Gen, :regenerate))(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end
+        return quote $(GlobalRef(Gen, :regenerate))($STATIC_RNG, trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end
     end
 
     ir = get_ir(gen_fn_type)
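Because `STATIC_RNG` is a `gensym`, the RNG argument spliced into the generated GFI methods cannot collide with user-chosen names in the IR. End to end, a static-DSL model then participates in the same seeded workflow as the dynamic DSL; a hypothetical sketch (assuming, as in released Gen versions, that `load_generated_functions()` must run before a `(static)` function is called):

    using Gen, Random

    @gen (static) function model(x::Float64)
        slope ~ normal(0.0, 1.0)
        y ~ normal(slope * x, 0.1)
        return y
    end

    load_generated_functions()

    tr1 = simulate(MersenneTwister(0), model, (2.0,))
    tr2 = simulate(MersenneTwister(0), model, (2.0,))
    tr1[:slope] == tr2[:slope]   # identical draws under identical seeds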