diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index d3a8fe98..67d4ec9e 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -34,14 +34,12 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -# Using @set to update state.transition.z.θ can lead to inconsistencies: -# - It retains cached log-joint and gradient computations that become invalid -# - This can cause incorrect evaluations in subsequent steps (e.g. MH) -# -# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17 -# if in the future the interface provides access to the log density function -function AbstractMCMC.setparams!!(state::HMCState, params) - return @set state.transition.z.θ = params +function AbstractMCMC.setparams!!(model, state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) end """