Skip to content

Commit

Permalink
update implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 28, 2024
1 parent b9dfa36 commit c7f3163
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@ function AbstractMCMC.getparams(state::HMCState)
return state.transition.z.θ
end

# Using @set to update state.transition.z.θ can lead to inconsistencies:
# - It retains cached log-joint and gradient computations that become invalid
# - This can cause incorrect evaluations in subsequent steps (e.g. MH)
#
# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17
# if in the future the interface provides access to the log density function
function AbstractMCMC.setparams!!(state::HMCState, params)
return @set state.transition.z.θ = params
function AbstractMCMC.setparams!!(model, state::HMCState, params)
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end

"""
Expand Down

0 comments on commit c7f3163

Please sign in to comment.