Skip to content

Commit

Permalink
Move optional features to extensions (#320)
Browse files Browse the repository at this point in the history
* Move optional features to extensions

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Add optional dependencies to `[extras]` section

* Load `stat`

* Restrict definition of `step` fallback

* Update src/integrator.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Run tests with Julia 1.6

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] authored Mar 13, 2023
1 parent 5eb1dd9 commit ab0b078
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 92 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
Expand Down
19 changes: 17 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.4.4"
version = "0.4.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -19,20 +19,35 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

[extensions]
AdvancedHMCCUDAExt = "CUDA"
AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "4.2"
ArgCheck = "1, 2"
CUDA = "3, 4"
DocStringExtensions = "0.8, 0.9"
InplaceOps = "0.3"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "5, 6"
OrdinaryDiffEq = "6"
ProgressMeter = "1"
Requires = "0.5, 1"
Setfield = "0.7, 0.8, 1"
SimpleUnPack = "1.1"
StatsBase = "0.31, 0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
julia = "1"
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
30 changes: 20 additions & 10 deletions src/contrib/cuda.jl → ext/AdvancedHMCCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
import .CUDA
module AdvancedHMCCUDAExt

CUDA.allowscalar(false)
if isdefined(Base, :get_extension)
import AdvancedHMC
import CUDA
import Random
else
import ..AdvancedHMC
import ..CUDA
import ..Random
end

function refresh(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
::FullMomentumRefreshment,
h::Hamiltonian,
z::PhasePoint{TA},
function AdvancedHMC.refresh(
rng::Union{Random.AbstractRNG,AbstractVector{<:Random.AbstractRNG}},
::AdvancedHMC.FullMomentumRefreshment,
h::AdvancedHMC.Hamiltonian,
z::AdvancedHMC.PhasePoint{TA},
) where {T<:AbstractFloat,TA<:CUDA.CuArray{<:T}}
r = CUDA.CuArray{T,2}(undef, size(h.metric)...)
CUDA.CURAND.randn!(r)
return phasepoint(h, z.θ, r)
return AdvancedHMC.phasepoint(h, z.θ, r)
end

# TODO: Ideally this should be merged with the CPU version. The function is
# essentially the same but sampling requires a custom call to CUDA.
function mh_accept_ratio(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
function AdvancedHMC.mh_accept_ratio(
rng::Union{Random.AbstractRNG,AbstractVector{<:Random.AbstractRNG}},
Horiginal::TA,
Hproposal::TA,
) where {T<:AbstractFloat,TA<:CUDA.CuArray{<:T}}
Expand All @@ -31,3 +39,5 @@ function mh_accept_ratio(
accept = r .< α
return accept, α
end

end # module
12 changes: 11 additions & 1 deletion src/mcmcchains-connect.jl → ext/AdvancedHMCMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import .MCMCChains: Chains
module AdvancedHMCMCMCChainsExt

if isdefined(Base, :get_extension)
using AdvancedHMC: AbstractMCMC, Transition, stat
using MCMCChains: Chains
else
using ..AdvancedHMC: AbstractMCMC, Transition, stat
using ..MCMCChains: Chains
end

# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
Expand Down Expand Up @@ -36,3 +44,5 @@ function AbstractMCMC.bundle_samples(
thin = thinning,
)
end

end # module
33 changes: 18 additions & 15 deletions src/contrib/diffeq.jl → ext/AdvancedHMCOrdinaryDiffEqExt.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
import .OrdinaryDiffEq
module AdvancedHMCOrdinaryDiffEqExt

struct DiffEqIntegrator{T<:AbstractScalarOrVec{<:AbstractFloat},DiffEqSolver} <:
AbstractLeapfrog{T}
ϵ::T
solver::DiffEqSolver
if isdefined(Base, :get_extension)
import AdvancedHMC
import OrdinaryDiffEq
else
import ..AdvancedHMC
import ..OrdinaryDiffEq
end

function step(
integrator::DiffEqIntegrator,
h::Hamiltonian,
function AdvancedHMC.step(
integrator::AdvancedHMC.DiffEqIntegrator,
h::AdvancedHMC.Hamiltonian,
z::P,
n_steps::Int = 1;
fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
res::Union{Vector{P},P} = z,
) where {P<:PhasePoint}

@unpack θ, r = z
) where {P<:AdvancedHMC.PhasePoint}

AdvancedHMC.@unpack θ, r = z
# For DynamicalODEProblem `u` is `θ` and `v` is `r`
# f1 is dr/dt RHS function
# f2 is dθ/dt RHS function
v0, u0 = r, θ

f1(v, u, p, t) = -∂H∂θ(h, u).gradient
f2(v, u, p, t) = ∂H∂r(h, v)
f1(v, u, p, t) = -AdvancedHMC.∂H∂θ(h, u).gradient
f2(v, u, p, t) = AdvancedHMC.∂H∂r(h, v)

ϵ = fwd ? step_size(integrator) : -step_size(integrator)
ϵ = fwd ? AdvancedHMC.step_size(integrator) : -AdvancedHMC.step_size(integrator)
tspan = (0.0, sign(n_steps))
problem = OrdinaryDiffEq.DynamicalODEProblem(f1, f2, v0, u0, tspan)
diffeq_integrator = OrdinaryDiffEq.init(
Expand All @@ -40,7 +41,7 @@ function step(
for i = 1:abs(n_steps)
OrdinaryDiffEq.step!(diffeq_integrator)
solution = diffeq_integrator.u.x # (r, θ) at the proposed step
z = phasepoint(h, solution[2], solution[1])
z = AdvancedHMC.phasepoint(h, solution[2], solution[1])
!isfinite(z) && break
if res isa Vector
res[i] = z
Expand All @@ -50,3 +51,5 @@ function step(
end
return res
end

end # module
44 changes: 33 additions & 11 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,20 +224,42 @@ end

### Init

using Requires
struct DiffEqIntegrator{T<:AbstractScalarOrVec{<:AbstractFloat},DiffEqSolver} <:
AbstractLeapfrog{T}
ϵ::T
solver::DiffEqSolver
end
export DiffEqIntegrator

if !isdefined(Base, :get_extension)
using Requires
end
function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
export DiffEqIntegrator
include("contrib/diffeq.jl")
# Better error message if users forgot to load OrdinaryDiffEq
Base.Experimental.register_error_hint(MethodError) do io, exc, arg_types, kwargs
n = length(arg_types)
if exc.f === step &&
(n == 3 || n == 4) &&
arg_types[1] <: DiffEqIntegrator &&
arg_types[2] <: Hamiltonian &&
arg_types[3] <: PhasePoint &&
(n == 3 || arg_types[4] === Int)

print(io, "\\nDid you forget to load OrdinaryDiffEq?")
end
end

@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
include("contrib/cuda.jl")
end

@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" begin
include("mcmcchains-connect.jl")
@static if !isdefined(Base, :get_extension)
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
include("../ext/AdvancedHMCOrdinaryDiffEqExt.jl")
end

@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
include("../ext/AdvancedHMCCUDAExt.jl")
end

@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" begin
include("../ext/AdvancedHMCMCMCChainsExt.jl")
end
end
end

Expand Down
110 changes: 57 additions & 53 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,59 +56,6 @@ stat(lf::AbstractLeapfrog) = (step_size = step_size(lf), nom_step_size = nom_ste

update_nom_step_size(lf::AbstractLeapfrog, ϵ) = @set lf.ϵ = ϵ

function step(
lf::AbstractLeapfrog{T},
h::Hamiltonian,
z::P,
n_steps::Int = 1;
fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj} = Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
ϵ = ϵ'

res = if FullTraj
Vector{P}(undef, n_steps)
else
z
end

@unpack θ, r = z
@unpack value, gradient = z.ℓπ
for i = 1:n_steps
# Tempering
r = temper(lf, r, (i = i, is_half = true), n_steps)
# Take a half leapfrog step for momentum variable
r = r - ϵ / 2 .* gradient
# Take a full leapfrog step for position variable
∇r = ∂H∂r(h, r)
θ = θ + ϵ .* ∇r
# Take a half leapfrog step for momentum variable
@unpack value, gradient = ∂H∂θ(h, θ)
r = r - ϵ / 2 .* gradient
# Tempering
r = temper(lf, r, (i = i, is_half = false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ, r; ℓπ = DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
if !isfinite(z)
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end

"""
$(TYPEDEF)
Expand Down Expand Up @@ -243,3 +190,60 @@ function temper(
i_temper = 2(step.i - 1) + 1 + !step.is_half # counter for half temper steps
return i_temper <= n_steps ? r * sqrt(lf.α) : r / sqrt(lf.α)
end

# `step` method for integrators above
# method for `DiffEqIntegrator` is defined in the OrdinaryDiffEq extension
const DefaultLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} =
Union{Leapfrog{T},JitteredLeapfrog{FT,T},TemperedLeapfrog{FT,T}}
function step(
lf::DefaultLeapfrog{FT,T},
h::Hamiltonian,
z::P,
n_steps::Int = 1;
fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj} = Val(false),
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT},P<:PhasePoint,FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
ϵ = ϵ'

res = if FullTraj
Vector{P}(undef, n_steps)
else
z
end

@unpack θ, r = z
@unpack value, gradient = z.ℓπ
for i = 1:n_steps
# Tempering
r = temper(lf, r, (i = i, is_half = true), n_steps)
# Take a half leapfrog step for momentum variable
r = r - ϵ / 2 .* gradient
# Take a full leapfrog step for position variable
∇r = ∂H∂r(h, r)
θ = θ + ϵ .* ∇r
# Take a half leapfrog step for momentum variable
@unpack value, gradient = ∂H∂θ(h, θ)
r = r - ϵ / 2 .* gradient
# Tempering
r = temper(lf, r, (i = i, is_half = false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ, r; ℓπ = DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
if !isfinite(z)
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end

2 comments on commit ab0b078

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/79481

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.5 -m "<description of version>" ab0b078cff023db85ca96ddd8a3ccdd76c92ca82
git push origin v0.4.5

Please sign in to comment.