Add convenience forms for update and regenerate with optional args and argdiffs #236

Open · wants to merge 8 commits into master
src/gen_fn_interface.jl (37 changes: 34 additions & 3 deletions)

@@ -244,8 +244,8 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
end

"""
-    (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
-        constraints::ChoiceMap)
+    (new_trace, weight, retdiff, discard) = update(
+        trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)

Update a trace by changing the arguments and/or providing new values for some
existing random choice(s) and values for some newly introduced random choice(s).
@@ -272,10 +272,25 @@ that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the updated trace.
"""
-function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
+function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
error("Not implemented")
end

"""
update(trace;
args::Tuple=get_args(trace),
argdiffs::Tuple=map((_) -> NoChange(), args),
constraints::ChoiceMap=EmptyChoiceMap())

Form of `update` with keyword arguments providing common defaults.
"""
function update(trace;
args::Tuple=get_args(trace),
argdiffs::Tuple=map((_) -> NoChange(), args),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the default be NoChange or UnknownChange?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now there are two convenience forms:

If you provide args, then argdiffs is an optional keyword argument that defaults to UnknownChange. (This is so you don't need to construct the unknown change argdiffs).

If you do not provide args then argdiffs is automatically set to NoChanges. (This is a common case e.g. for top-level generative functions on which we are doing MCMC or optimization).

constraints::ChoiceMap=EmptyChoiceMap())
update(trace, args, argdiffs, constraints)
end
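As a usage illustration (not part of the diff; `model` and the address `:x` are assumed names), the two convenience forms described in the thread above would be called roughly like this:

```julia
using Gen

# Obtain an initial trace (model, arguments, and addresses are illustrative).
trace, _ = generate(model, (0.5,), choicemap((:x, 1.0)))

# Common MCMC/optimization case: args unchanged, so argdiffs defaults to NoChange().
new_trace, weight, retdiff, discard = update(trace; constraints=choicemap((:x, 2.0)))

# Changing args: pass argdiffs explicitly, which stays valid under either
# proposed default (NoChange or UnknownChange).
new_trace, weight, retdiff, discard = update(
    trace; args=(0.7,), argdiffs=(UnknownChange(),), constraints=choicemap((:x, 2.0)))
```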

"""
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
selection::Selection)
@@ -307,6 +322,22 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
error("Not implemented")
end

"""
regenerate(trace;
args::Tuple=get_args(trace),
argdiffs::Tuple=map((_) -> NoChange(), args),
selection::Selection=EmptySelection())

Form of `regenerate` with keyword arguments providing common defaults.
"""
function regenerate(trace;
args::Tuple=get_args(trace),
argdiffs::Tuple=map((_) -> NoChange(), args),
selection::Selection=EmptySelection())
regenerate(trace, args, argdiffs, selection)
end
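A minimal sketch of the regenerate convenience form in an MCMC setting (assuming a trace with an address :x; this mirrors the metropolis_hastings change further below):

```julia
# Propose new values for :x under unchanged args (argdiffs defaults to NoChange()).
new_trace, weight, retdiff = regenerate(trace; selection=select(:x))
if log(rand()) < weight
    trace = new_trace  # accept
end
```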


"""
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)

src/inference/elliptical_slice.jl (6 changes: 2 additions & 4 deletions)

@@ -12,8 +12,6 @@ Also takes the mean vector and covariance matrix of the prior.
"""
function elliptical_slice(
trace, addr, mu, cov; check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)

# sample nu
nu = mvnormal(zeros(length(mu)), cov)
@@ -29,7 +27,7 @@ function elliptical_slice(
f = trace[addr] .- mu

new_f = f * cos(theta) + nu * sin(theta)
-    new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
+    new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu)))
while weight <= log(u)
if theta < 0
theta_min = theta
@@ -38,7 +36,7 @@ end
end
theta = uniform(theta_min, theta_max)
new_f = f * cos(theta) + nu * sin(theta)
-        new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
+        new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu)))
end
check && check_observations(get_choices(new_trace), observations)
return new_trace
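For context (not part of the diff), a hedged sketch of how this sampler is invoked; the trace address and prior parameters are illustrative:

```julia
using LinearAlgebra

# :f is assumed to hold a length-10 multivariate normal draw with prior mean mu
# and covariance cov (all names here are illustrative).
mu = zeros(10)
cov = Matrix{Float64}(I, 10, 10)
new_trace = elliptical_slice(trace, :f, mu, cov)
```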
src/inference/hmc.jl (4 changes: 1 addition & 3 deletions)

@@ -25,9 +25,7 @@ function hmc(
trace::U, selection::Selection; L=10, eps=0.1,
check=false, observations=EmptyChoiceMap()) where {T,U}
prev_model_score = get_score(trace)
-    args = get_args(trace)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
-    argdiffs = map((_) -> NoChange(), args)

# run leapfrog dynamics
new_trace = trace
@@ -46,7 +44,7 @@

# get new gradient
values_trie = from_array(values_trie, values)
-    (new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
+    (new_trace, _, _) = update(new_trace; constraints=values_trie)
(_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
gradient = to_array(gradient_trie, Float64)

src/inference/involution_dsl.jl (2 changes: 1 addition & 1 deletion)

@@ -299,7 +299,7 @@ function apply_involution(involution::InvolutionDSLProgram, trace, u, proposal_a

# update model trace
(new_trace, model_weight, _, discard) = update(
-        trace, get_args(trace), map((_) -> NoChange(), get_args(trace)), first_pass_state.constraints)
+        trace; constraints=first_pass_state.constraints)

# create input array and mappings input addresses that are needed for Jacobian
# exclude addresses that were moved to another address
src/inference/mala.jl (5 changes: 1 addition & 4 deletions)

@@ -11,8 +11,6 @@ Apply a Metropolis-Adjusted Langevin Algorithm (MALA) update.
function mala(
trace, selection::Selection, tau::Real;
check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
std = sqrt(2 * tau)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

@@ -30,8 +28,7 @@

# evaluate model weight
constraints = from_array(values_trie, proposed_values)
-    (new_trace, weight, _, discard) = update(trace,
-        args, argdiffs, constraints)
+    (new_trace, weight, _, discard) = update(trace; constraints=constraints)
check && check_observations(get_choices(new_trace), observations)

# backward proposal
src/inference/map_optimize.jl (4 changes: 1 addition & 3 deletions)

@@ -8,8 +8,6 @@ Selected random choices must have support on the entire real line.
"""
function map_optimize(trace, selection::Selection;
max_step_size=0.1, tau=0.5, min_step_size=1e-16, verbose=false)
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing

(_, values, gradient) = choice_gradients(trace, selection, retval_grad)
@@ -21,7 +19,7 @@
new_values_vec = values_vec + gradient_vec * step_size
values = from_array(values, new_values_vec)
# TODO discard and weight are not actually needed, there should be a more specialized variant
-    (new_trace, _, _, discard) = update(trace, args, argdiffs, values)
+    (new_trace, _, _, discard) = update(trace; constraints=values)
new_score = get_score(new_trace)
change = new_score - score
if verbose
src/inference/mh.jl (9 changes: 2 additions & 7 deletions)

@@ -14,9 +14,7 @@ Perform a Metropolis-Hastings update that proposes new values for the selected a
function metropolis_hastings(
trace, selection::Selection;
check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
-    (new_trace, weight) = regenerate(trace, args, argdiffs, selection)
+    (new_trace, weight) = regenerate(trace; selection=selection)
check && check_observations(get_choices(new_trace), observations)
if log(rand()) < weight
# accept
@@ -41,12 +39,9 @@ If the proposal modifies addresses that determine the control flow in the model,
function metropolis_hastings(
trace, proposal::GenerativeFunction, proposal_args::Tuple;
check=false, observations=EmptyChoiceMap())
-    model_args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), model_args)
     proposal_args_forward = (trace, proposal_args...,)
     (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward)
-    (new_trace, weight, _, discard) = update(trace,
-        model_args, argdiffs, fwd_choices)
+    (new_trace, weight, _, discard) = update(trace; constraints=fwd_choices)
proposal_args_backward = (new_trace, proposal_args...,)
(bwd_weight, _) = assess(proposal, proposal_args_backward, discard)
alpha = weight - fwd_weight + bwd_weight
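To make the custom-proposal path concrete, a hedged sketch (drift_proposal, :x, and the sigma value are illustrative, not from the PR):

```julia
# A random-walk proposal around the current value of :x.
@gen function drift_proposal(trace, sigma)
    @trace(normal(trace[:x], sigma), :x)
end

# The model args and argdiffs no longer need to be threaded through by hand.
trace, accepted = metropolis_hastings(trace, drift_proposal, (0.2,))
```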
src/inference/particle_filter.jl (5 changes: 3 additions & 2 deletions)

@@ -142,7 +142,8 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a
for i=1:num_particles
(prop_choices, prop_weight, _) = propose(proposal, (state.traces[i], proposal_args...))
constraints = merge(observations, prop_choices)
-        (state.new_traces[i], up_weight, _, disc) = update(state.traces[i], new_args, argdiffs, constraints)
+        (state.new_traces[i], up_weight, _, disc) = update(
+            state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=constraints)
@assert isempty(disc)
state.log_weights[i] += up_weight - prop_weight
end
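A sketch of driving this step from the outside (the model, time-indexed observation addresses, and particle count are assumptions for illustration):

```julia
# Initialize with 100 particles at time step 1.
state = initialize_particle_filter(model, (1,), choicemap(((:y => 1), 0.0)), 100)

# Advance to time step 2; each particle is updated with the new args and
# argdiffs via the keyword form of update shown above.
particle_filter_step!(state, (2,), (UnknownChange(),), choicemap(((:y => 2), 0.1)))
```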
@@ -166,7 +167,7 @@
num_particles = length(state.traces)
for i=1:num_particles
(state.new_traces[i], increment, _, discard) = update(
-            state.traces[i], new_args, argdiffs, observations)
+            state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=observations)
if !isempty(discard)
error("Choices were updated or deleted inside particle filter step: $discard")
end