diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index a5c6f99f7..0f4cd2ecf 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -429,49 +429,49 @@ recursively on the remaining samplers, until no samplers remain. Return the glob and a tuple of initial states for all component samplers. """ function gibbs_initialstep_recursive( - rng, model, varnames, samplers, vi, states=(); initial_params=nothing, kwargs... + rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs... ) # End recursion - if isempty(varnames) && isempty(samplers) + if isempty(varname_vecs) && isempty(samplers) return vi, states end - varnames_local = first(varnames) - sampler_local = first(samplers) + varnames, varname_vecs_tail... = varname_vecs + sampler, samplers_tail... = samplers # Get the initial values for this component sampler. initial_params_local = if initial_params === nothing nothing else - DynamicPPL.subset(vi, varnames_local)[:] + DynamicPPL.subset(vi, varnames)[:] end # Construct the conditioned model. - model_local, context_local = make_conditional(model, varnames_local, vi) + conditioned_model, context = make_conditional(model, varnames, vi) - # Take initial step. - _, new_state_local = AbstractMCMC.step( + # Take initial step with the current sampler. + _, new_state = AbstractMCMC.step( rng, - model_local, - sampler_local; + conditioned_model, + sampler; # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. initial_params=initial_params_local, kwargs..., ) - new_vi_local = varinfo(new_state_local) + new_vi_local = varinfo(new_state) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. - vi = merge(vi, get_global_varinfo(context_local)) + vi = merge(vi, get_global_varinfo(context)) # Merge the new values for all the variables sampled by the current sampler. vi = merge(vi, new_vi_local) - states = (states..., new_state_local) + states = (states..., new_state) return gibbs_initialstep_recursive( rng, model, - varnames[2:end], - samplers[2:end], + varname_vecs_tail, + samplers_tail, vi, states; initial_params=initial_params, @@ -624,7 +624,7 @@ function on the tail, until there are no more samplers left. function gibbs_step_recursive( rng::Random.AbstractRNG, model::DynamicPPL.Model, - varnames, + varname_vecs, samplers, states, global_vi, @@ -632,18 +632,18 @@ function gibbs_step_recursive( kwargs..., ) # End recursion. - if isempty(varnames) && isempty(samplers) && isempty(states) + if isempty(varname_vecs) && isempty(samplers) && isempty(states) return global_vi, new_states end - varnames_local = first(varnames) - sampler_local = first(samplers) - state_local = first(states) + varnames, varname_vecs_tail... = varname_vecs + sampler, samplers_tail... = samplers + state, states_tail... = states # Construct the conditional model and the varinfo that this sampler should use. - model_local, context_local = make_conditional(model, varnames_local, global_vi) - varinfo_local = subset(global_vi, varnames_local) - varinfo_local = match_linking!!(varinfo_local, state_local, model) + conditioned_model, context = make_conditional(model, varnames, global_vi) + vi = subset(global_vi, varnames) + vi = match_linking!!(vi, state, model) # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not # sampled by other samplers, we don't need to `setparams`, but could rather simply @@ -654,27 +654,27 @@ function gibbs_step_recursive( # going to be a significant expense anyway. # Set the state of the current sampler, accounting for any changes made by other # samplers. - state_local = setparams_varinfo!!( - model_local, sampler_local, state_local, varinfo_local + state = setparams_varinfo!!( + conditioned_model, sampler, state, vi ) # Take a step with the local sampler. - new_state_local = last( - AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) + new_state = last( + AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...) ) - new_vi_local = varinfo(new_state_local) + new_vi_local = varinfo(new_state) # Merge the latest values for all the variables in the current sampler. - new_global_vi = merge(get_global_varinfo(context_local), new_vi_local) + new_global_vi = merge(get_global_varinfo(context), new_vi_local) new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) - new_states = (new_states..., new_state_local) + new_states = (new_states..., new_state) return gibbs_step_recursive( rng, model, - varnames[2:end], - samplers[2:end], - states[2:end], + varname_vecs_tail, + samplers_tail, + states_tail, new_global_vi, new_states; kwargs...,