Skip to content

Commit

Permalink
Report percentage of divergent transitions in progress bar (#297)
Browse files Browse the repository at this point in the history
* Update hamiltonian.jl

* Update src/hamiltonian.jl

* some tweaks + print percentage of numerical error per samples

* report percentage of divergent transitions in progress bar

* add warning msg if percentage of divergent transitions is above a threshhold

* minor fixes

* minor fixes

* add missing numerical error check for static HMC

* Apply suggestions from code review

* Seprate divergent transitions counting for warmup

* Fixed typos.

* Some minor tweaks + bugfixes.

* rm redundant is_adapt msg.

* Update Project.toml

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fixed formatting.

* Update src/trajectory.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
yebai and github-actions[bot] authored Jan 27, 2023
1 parent c6bb734 commit 967d8a1
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 19 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/Format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Format

on:
pull_request:
push:
branches:
- master
- main

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1
- name: Format code
run: |
using Pkg
Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899")
using JuliaFormatter
format("."; verbose=true)
shell: julia --color=yes {0}
- uses: reviewdog/action-suggester@v1
with:
tool_name: JuliaFormatter
fail_on_error: true
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.4.2"
version = "0.4.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
34 changes: 32 additions & 2 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,16 @@ struct HMCProgressCallback{P}
progress::Bool
"If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
verbose::Bool
"Number of divergent transitions fo far."
num_divergent_transitions::Ref{Int}
num_divergent_transitions_during_adaption::Ref{Int}
end

function HMCProgressCallback(n_samples; progress = true, verbose = false)
pm =
progress ? ProgressMeter.Progress(n_samples, desc = "Sampling", barlen = 31) :
nothing
HMCProgressCallback(pm, progress, verbose)
HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0))
end

function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kwargs...)
Expand All @@ -266,11 +269,38 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
κ = state.κ
tstat = t.stat
isadapted = tstat.is_adapt
if isadapted
cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error
else
cb.num_divergent_transitions[] += tstat.numerical_error
end

# Update progress meter
if progress
percentage_divergent_transitions = cb.num_divergent_transitions[] / i
percentage_divergent_transitions_during_adaption =
cb.num_divergent_transitions_during_adaption[] / i
if percentage_divergent_transitions > 0.25
@warn "The level of numerical errors is high. Please check the model carefully." maxlog =
3
end
# Do include current iteration and mass matrix
pm_next!(pm, (iterations = i, tstat..., mass_matrix = metric))
pm_next!(
pm,
(
iterations = i,
ratio_divergent_transitions = round(
percentage_divergent_transitions;
digits = 2,
),
ratio_divergent_transitions_during_adaption = round(
percentage_divergent_transitions_during_adaption;
digits = 2,
),
tstat...,
mass_matrix = metric,
),
)
# Report finish of adapation
elseif verbose && isadapted && i == nadapts
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
Expand Down
11 changes: 5 additions & 6 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
function PhasePoint::T, r::T, ℓπ::V, ℓκ::V) where {T,V}
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient)
if any(isfinite.((θ, r, ℓπ, ℓκ)) .== false)
@warn "The current proposal will be rejected due to numerical error(s)." isfinite.((
θ,
r,
ℓπ,
ℓκ,
))
# @warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ))
# NOTE eltype has to be inlined to avoid type stability issue; see #267
ℓπ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
Expand Down Expand Up @@ -176,8 +171,12 @@ refresh(
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic))

"""
$(TYPEDEF)
Partial momentum refreshment with refresh rate `α`.
# Fields
$(TYPEDFIELDS)
See equation (5.19) [1]
r' = α⋅r + sqrt(1-α²)⋅G
Expand Down
31 changes: 30 additions & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ function sample(
# Prepare containers to store sampling results
n_keep = n_samples - (drop_warmup ? n_adapts : 0)
θs, stats = Vector{T}(undef, n_keep), Vector{NamedTuple}(undef, n_keep)
num_divergent_transitions = 0
num_divergent_transitions_during_adaption = 0
# Initial sampling
h, t = sample_init(rng, h, θ)
# Progress meter
Expand All @@ -184,11 +186,38 @@ function sample(
tstat = stat(t)
h, κ, isadapted =
adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
if isadapted
num_divergent_transitions_during_adaption += tstat.numerical_error
else
num_divergent_transitions += tstat.numerical_error
end
tstat = merge(tstat, (is_adapt = isadapted,))
# Update progress meter
if progress
percentage_divergent_transitions = num_divergent_transitions / i
percentage_divergent_transitions_during_adaption =
num_divergent_transitions_during_adaption / i
if percentage_divergent_transitions > 0.25
@warn "The level of numerical errors is high. Please check the model carefully." maxlog =
3
end
# Do include current iteration and mass matrix
pm_next!(pm, (iterations = i, tstat..., mass_matrix = h.metric))
pm_next!(
pm,
(
iterations = i,
ratio_divergent_transitions = round(
percentage_divergent_transitions;
digits = 2,
),
ratio_divergent_transitions_during_adaption = round(
percentage_divergent_transitions_during_adaption;
digits = 2,
),
tstat...,
mass_matrix = h.metric,
),
)
# Report finish of adapation
elseif verbose && isadapted && i == n_adapts
@info "Finished $n_adapts adapation steps" adaptor κ.τ.integrator h.metric
Expand Down
5 changes: 4 additions & 1 deletion src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ function transition(
z = accept_phasepoint!(z, z′, is_accept) # NOTE: this function changes `z′` in place in matrix-parallel mode
# Reverse momentum variable to preserve reversibility
z = PhasePoint(z.θ, -z.r, z.ℓπ, z.ℓκ)
H = energy(z)
# Get cached hamiltonian energy
H, H′ = energy(z), energy(z′)
tstat = merge(
(
n_steps = nsteps(τ),
Expand All @@ -257,6 +258,8 @@ function transition(
log_density = z.ℓπ.value,
hamiltonian_energy = H,
hamiltonian_energy_error = H - H0,
# check numerical error in proposed phase point.
numerical_error = isfinite(H′),
),
stat.integrator),
)
Expand Down
11 changes: 3 additions & 8 deletions test/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ end
DualValue(zero(T), [zero(T)]),
)

@test_logs (
:warn,
"The current proposal will be rejected due to numerical error(s).",
) init_z1()
@test_logs (
:warn,
"The current proposal will be rejected due to numerical error(s).",
) init_z2()
# (HongGe) we no longer throw warning messages for numerical errors.
# @test_logs (:warn, "The current proposal will be rejected due to numerical error(s).") init_z1()
# @test_logs (:warn, "The current proposal will be rejected due to numerical error(s).") init_z2()

z1 = init_z1()
z2 = init_z2()
Expand Down

0 comments on commit 967d8a1

Please sign in to comment.