Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use setparams!! rather than reset_state!!
Browse files Browse the repository at this point in the history
mhauru committed Nov 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 15ee270 commit d52af52
Showing 2 changed files with 70 additions and 140 deletions.
205 changes: 67 additions & 138 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
@@ -50,9 +50,16 @@ function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
)
end

get_global_varinfo(context::GibbsContext) = context.global_varinfo[]

function set_global_varinfo!(context::GibbsContext, new_global_varinfo)
context.global_varinfo[] = new_global_varinfo
return nothing

Check warning on line 57 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L55-L57

Added lines #L55 - L57 were not covered by tests
end

# has and get
function has_conditioned_gibbs(context::GibbsContext, vn::VarName)
return DynamicPPL.haskey(context.global_varinfo[], vn)
return DynamicPPL.haskey(get_global_varinfo(context), vn)
end
function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
num_conditioned = count(Iterators.map(Base.Fix1(has_conditioned_gibbs, context), vns))
@@ -66,7 +73,7 @@ function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
end

function get_conditioned_gibbs(context::GibbsContext, vn::VarName)
return context.global_varinfo[][vn]
return get_global_varinfo(context)[vn]
end
function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
return map(Base.Fix1(get_conditioned_gibbs, context), vns)

Check warning on line 79 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L78-L79

Added lines #L78 - L79 were not covered by tests
@@ -110,9 +117,9 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
DynamicPPL.SampleFromPrior(),
right,
vn,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 123 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L122-L123

Added lines #L122 - L123 were not covered by tests
end
end
@@ -137,9 +144,9 @@ function DynamicPPL.tilde_assume(
DynamicPPL.SampleFromPrior(),
right,
vn,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 150 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
end
end
@@ -181,9 +188,9 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
right,
left,
vns,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 194 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L193-L194

Added lines #L193 - L194 were not covered by tests
end
end
@@ -209,9 +216,9 @@ function DynamicPPL.dot_tilde_assume(
right,
left,
vns,
context.global_varinfo[],
get_global_varinfo(context),
)
context.global_varinfo[] = new_global_vi
set_global_varinfo!(context, new_global_vi)
return value, lp, vi

Check warning on line 222 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L221-L222

Added lines #L221 - L222 were not covered by tests
end
end
@@ -468,139 +475,71 @@ function DynamicPPL.setlogp!!(state::TuringState, logp)
return TuringState(setlogp!!(state.state, logp), logp)

Check warning on line 475 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
end

# TODO(mhauru) In the general case, which arguments are really needed for reset_state!!?
# The current list is a guess, and I think some are unnecessary.
"""
reset_state!!(rng, model, sampler, state, varinfo, sampler_previous, state_previous)
Return an updated state for a Gibbs component sampler.
This takes into account changes caused by other Gibbs components. The default implementation
is to try to set the `vi` field of `state` to `varinfo`. If this is not the right thing to
do, a method should be implemented for the specific type of `state`.
# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `setparams!!`.
function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector)
return DynamicPPL.unflatten(state, params)
end

# Arguments
- `model::DynamicPPL.Model`: The model as seen by this component sampler. Variables not
sampled by this component sampler have been conditioned with a `GibbsContext`.
- `sampler::DynamicPPL.Sampler`: The current component sampler.
- `state`: The state of this component sampler from its previous iteration.
- `varinfo::DynamicPPL.AbstractVarInfo`: The current `VarInfo`, subsetted to the variables
sampled by this component sampler.
- `sampler_previous::DynamicPPL.Sampler`: The previous sampler in the Gibbs chain.
- `state_previous`: The state returned by the previous sampler.
function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo)
return params

Check warning on line 485 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L484-L485

Added lines #L484 - L485 were not covered by tests
end

# Returns
An updated state of the same type as `state`. It should have variables set to the values in
`varinfo`, and any other relevant updates done.
"""
function reset_state!!(
model, sampler, state, varinfo::AbstractVarInfo, sampler_previous, state_previous
function AbstractMCMC.setparams!!(

Check warning on line 488 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L488

Added line #L488 was not covered by tests
model::DynamicPPL.Model,
state::TuringState,
params::Union{AbstractVector,AbstractVarInfo},
)
# In the fallback implementation we guess that `state` has a field called `vi` we can
# set. Fingers crossed!
try
return Accessors.set(state, Accessors.PropertyLens{:vi}(), varinfo)
catch
error(
"Unable to set `state.vi` for a $(typeof(state)). " *
"Consider writing a method for reset_state!! for this type.",
)
end
new_inner_state = AbstractMCMC.setparams!!(model, state.state, params)
return TuringState(new_inner_state, state.logdensity)

Check warning on line 494 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L493-L494

Added lines #L493 - L494 were not covered by tests
end

# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `reset_state!!`.
function reset_state!!(
model,
sampler,
state::AbstractVarInfo,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
return varinfo
# Unless some other treatment has been specified for this state type, just flatten the
# AbstractVarInfo. This method exists because some sampler types need to override this
# behavior.
function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo)
return AbstractMCMC.setparams!!(model, state, params[:])
end

function reset_state!!(
model,
sampler,
state::TuringState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(

Check warning on line 504 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L504

Added line #L504 was not covered by tests
model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo
)
new_inner_state = reset_state!!(
model, sampler, state.state, varinfo, sampler_previous, state_previous
θ_new = params[:]
hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new))

Check warning on line 508 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L507-L508

Added lines #L507 - L508 were not covered by tests

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(

Check warning on line 515 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L512-L515

Added lines #L512 - L515 were not covered by tests
params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler
)
return TuringState(new_inner_state, state.logdensity)
end

function reset_state!!(
model,
sampler,
state::HMCState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(

Check warning on line 520 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L520

Added line #L520 was not covered by tests
model::DynamicPPL.Model, state::HMCState, params::AbstractVector
)
θ_new = varinfo[:]
hamiltonian = get_hamiltonian(model, sampler, varinfo, state, length(θ_new))
θ_new = params
vi = DynamicPPL.unflatten(state.vi, params)
hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new))

Check warning on line 525 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L523-L525

Added lines #L523 - L525 were not covered by tests

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler)

Check warning on line 532 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L529-L532

Added lines #L529 - L532 were not covered by tests
end

function reset_state!!(
model,
sampler,
state::AdvancedHMC.HMCState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo
)
hamiltonian = AdvancedHMC.Hamiltonian(
state.metric, DynamicPPL.LogDensityFunction(model)
)
θ_new = varinfo[:]
# Set the momentum to some arbitrary value, making sure it has the right number of
# components. We could try to do something clever here to only reset momenta related to
# new variables, but it'll be resampled in the next iteration anyway.
# TODO(mhauru) Would prefer to set it to zeros rather than ones, but that makes
# ForwardDiff crash for some reason. Should investigate and report as a ForwardDiff bug.
momenta_old = state.transition.z.r
momenta_new = ones(eltype(momenta_old), length(θ_new))
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, θ_new, momenta_new
)
return PGState(params, state.rng)
end

function reset_state!!(
model,
sampler,
state::AdvancedMH.Transition,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
# TODO(mhauru) Setting the last argument like this seems a bit suspect, since the
# current values for the parameters might not have come from this sampler at all.
# I don't see a better way though.
return AdvancedMH.Transition(varinfo[:], varinfo.logp[], state.accepted)
end

function reset_state!!(
model,
sampler,
state::PGState,
varinfo::AbstractVarInfo,
sampler_previous,
state_previous,
)
return PGState(varinfo, state.rng)
function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector)
return PGState(DynamicPPL.unflatten(state.vi, params), state.rng)

Check warning on line 542 in src/mcmc/gibbs.jl

Codecov / codecov/patch

src/mcmc/gibbs.jl#L541-L542

Added lines #L541 - L542 were not covered by tests
end

function gibbs_step_inner(
@@ -609,7 +548,7 @@ function gibbs_step_inner(
varnames,
samplers,
states,
vi,
global_vi,
index;
kwargs...,
)
@@ -618,23 +557,16 @@ function gibbs_step_inner(
varnames_local = _maybevec(varnames[index])

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, vi)
varinfo_local = subset(vi, varnames_local)
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)

# Extract the previous sampler and state.
sampler_previous = samplers[index == 1 ? length(samplers) : index - 1]
state_previous = states[index == 1 ? length(states) : index - 1]

# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = reset_state!!(
model_local,
sampler_local,
state_local,
varinfo_local,
sampler_previous,
state_previous,
)
state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local)
if gibbs_requires_recompute_logprob(
model_local, sampler_local, sampler_previous, state_local, state_previous
)
@@ -647,11 +579,8 @@ function gibbs_step_inner(
)

new_vi_local = varinfo(new_state_local)
# This merges in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
new_vi = merge(vi, context_local.global_varinfo[])
# This merges the latest values for all the variables in the current sampler.
new_vi = merge(new_vi, new_vi_local)
new_vi = setlogp!!(new_vi, new_vi_local.logp[])
return new_vi, new_state_local
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
return new_global_vi, new_state_local
end
5 changes: 3 additions & 2 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ struct HMCState{
hamiltonian::THam
z::PhType
adaptor::TAdapt
sampler::Sampler{<:Hamiltonian}
end

###
@@ -229,7 +230,7 @@ function DynamicPPL.initialstep(
end

transition = Transition(model, vi, t)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl)

return transition, state
end
@@ -275,7 +276,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = Transition(model, vi, t)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl)

return transition, newstate
end

0 comments on commit d52af52

Please sign in to comment.