Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variable naming / destructuring #2465

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,49 +429,49 @@
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi, states=(); initial_params=nothing, kwargs...
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
)
# End recursion
if isempty(varnames) && isempty(samplers)
if isempty(varname_vecs) && isempty(samplers)

Check warning on line 435 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L435

Added line #L435 was not covered by tests
return vi, states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers

Check warning on line 440 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L439-L440

Added lines #L439 - L440 were not covered by tests

# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
DynamicPPL.subset(vi, varnames)[:]

Check warning on line 446 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L446

Added line #L446 was not covered by tests
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
conditioned_model, context = make_conditional(model, varnames, vi)

Check warning on line 450 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L450

Added line #L450 was not covered by tests

# Take initial step.
_, new_state_local = AbstractMCMC.step(
# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(

Check warning on line 453 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L453

Added line #L453 was not covered by tests
rng,
model_local,
sampler_local;
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)

Check warning on line 462 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L462

Added line #L462 was not covered by tests
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
vi = merge(vi, get_global_varinfo(context))

Check warning on line 465 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L465

Added line #L465 was not covered by tests
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state_local)
states = (states..., new_state)

Check warning on line 469 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L469

Added line #L469 was not covered by tests
return gibbs_initialstep_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
varname_vecs_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
Expand Down Expand Up @@ -624,26 +624,26 @@
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames,
varname_vecs,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varnames) && isempty(samplers) && isempty(states)
if isempty(varname_vecs) && isempty(samplers) && isempty(states)

Check warning on line 635 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L635

Added line #L635 was not covered by tests
return global_vi, new_states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
state_local = first(states)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers
state, states_tail... = states

Check warning on line 641 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L639-L641

Added lines #L639 - L641 were not covered by tests

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

Check warning on line 646 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L644-L646

Added lines #L644 - L646 were not covered by tests

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -654,27 +654,27 @@
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
state = setparams_varinfo!!(

Check warning on line 657 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L657

Added line #L657 was not covered by tests
conditioned_model, sampler, state, vi
)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
new_state = last(

Check warning on line 662 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L662

Added line #L662 was not covered by tests
AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...)
)

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)

Check warning on line 666 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L666

Added line #L666 was not covered by tests
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)

Check warning on line 668 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L668

Added line #L668 was not covered by tests
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))

new_states = (new_states..., new_state_local)
new_states = (new_states..., new_state)

Check warning on line 671 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L671

Added line #L671 was not covered by tests
return gibbs_step_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
states[2:end],
varname_vecs_tail,
samplers_tail,
states_tail,
new_global_vi,
new_states;
kwargs...,
Expand Down
Loading