Skip to content

Commit

Permalink
Further improvements to recently introduced constructors (#340)
Browse files Browse the repository at this point in the history
* remove type parameter from AbstractHMCSampler, and added eltype for metrics

* Update src/constructors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added determine_sampler_eltype to unify handling of different argument types

* fixed issues with conversion of arguments

* added test for type promotion in the case of HMCDA

* removed unnecessary float calls

* make sampler types concretely typed

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* remove now-redundant type parameter from HMC

* removed unused argument to HMCDA

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/constructors.jl

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: Jaime RZ <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
4 people authored Jul 27, 2023
1 parent 37481ac commit 8429077
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 38 deletions.
2 changes: 2 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ using AbstractMCMC: LogDensityModel

import StatsBase: sample

const DEFAULT_FLOAT_TYPE = typeof(float(0))

include("utilities.jl")

# Notations
Expand Down
7 changes: 0 additions & 7 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,6 @@ end
#############
### Utils ###
#############

function sampler_eltype(::AbstractHMCSampler{T}) where {T<:Real}
return T
end

#########

function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
T = sampler_eltype(spl)
if init_params == nothing
Expand Down
89 changes: 59 additions & 30 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,34 @@
abstract type AbstractHMCSampler{T<:Real} <: AbstractMCMC.AbstractSampler end
"""
determine_sampler_eltype(xs...)
Determine the element type to use for the given arguments.
Symbols are either resolved to the default float type or simply dropped
in favour of determined types from the other arguments.
"""
determine_sampler_eltype(xs...) = float(_determine_sampler_eltype(xs...))
# NOTE: We want to defer conversion to `float` until the very "end" of the
# process, so as to allow `promote_type` to do it's job properly.
# For example, in the scenario `determine_sampler_eltype(::Int64, ::Float32)`
# we want to return `Float32`, not `Float64`. The latter would occur
# if we did `float(eltype(x))` instead of just `eltype(x)`.
_determine_sampler_eltype(x) = eltype(x)
_determine_sampler_eltype(x::AbstractIntegrator) = integrator_eltype(x)
_determine_sampler_eltype(::Symbol) = DEFAULT_FLOAT_TYPE
function _determine_sampler_eltype(xs...)
xs_not_symbol = filter(!Base.Fix2(isa, Symbol), xs)
isempty(xs_not_symbol) && return DEFAULT_FLOAT_TYPE
return promote_type(map(_determine_sampler_eltype, xs_not_symbol)...)
end

abstract type AbstractHMCSampler <: AbstractMCMC.AbstractSampler end

"""
sampler_eltype(sampler)
Return the element type of the sampler.
"""
function sampler_eltype end

##############
### Custom ###
Expand All @@ -21,19 +51,17 @@ and `adaptor` after sampling.
To access the updated fields use the resulting [`HMCState`](@ref).
"""
struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
struct HMCSampler{K<:AbstractMCMCKernel,M<:AbstractMetric,A<:AbstractAdaptor} <:
AbstractHMCSampler
"[`AbstractMCMCKernel`](@ref)."
κ::AbstractMCMCKernel
κ::K
"Choice of initial metric [`AbstractMetric`](@ref). The metric type will be preserved during adaption."
metric::AbstractMetric
metric::M
"[`AbstractAdaptor`](@ref)."
adaptor::AbstractAdaptor
adaptor::A
end

function HMCSampler(κ, metric, adaptor)
T = collect(typeof(metric).parameters)[1]
return HMCSampler{T}(κ, metric, adaptor)
end
sampler_eltype(sampler::HMCSampler) = eltype(sampler.metric)

############
### NUTS ###
Expand All @@ -53,24 +81,27 @@ $(FIELDS)
NUTS(δ=0.65) # Use target accept ratio 0.65.
```
"""
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
struct NUTS{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
AbstractHMCSampler
"Target acceptance rate for dual averaging."
δ::T
"Maximum doubling tree depth."
max_depth::Int
"Maximum divergence during doubling tree."
Δ_max::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
integrator::I
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
metric::M
end

function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
T = typeof)
return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
T = determine_sampler_eltype(δ, integrator, metric)
return NUTS(T(δ), max_depth, T(Δ_max), integrator, metric)
end

sampler_eltype(::NUTS{T}) where {T} = T

###########
### HMC ###
###########
Expand All @@ -89,23 +120,20 @@ $(FIELDS)
HMC(10, integrator = Leapfrog(0.05), metric = :diagonal)
```
"""
struct HMC{T<:Real} <: AbstractHMCSampler{T}
struct HMC{I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
AbstractHMCSampler
"Number of leapfrog steps."
n_leapfrog::Int
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
integrator::I
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
metric::M
end

function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
if integrator isa Symbol
T = typeof(0.0) # current default float type
else
T = integrator_eltype(integrator)
end
return HMC{T}(n_leapfrog, integrator, metric)
end
HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal) =
HMC(n_leapfrog, integrator, metric)

sampler_eltype(sampler::HMC) = determine_sampler_eltype(sampler.metric, sampler.integrator)

#############
### HMCDA ###
Expand All @@ -131,7 +159,7 @@ For more information, please view the following paper ([arXiv link](https://arxi
setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning
Research 15, no. 1 (2014): 1593-1623.
"""
struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
struct HMCDA{T<:Real} <: AbstractHMCSampler
"Target acceptance rate for dual averaging."
δ::T
"Target leapfrog length."
Expand All @@ -142,8 +170,9 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
metric::Union{Symbol,AbstractMetric}
end

function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
δ, λ = promote(δ, λ)
T = typeof(δ)
return HMCDA(δ, T(λ), integrator, metric)
function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal)
T = determine_sampler_eltype(δ, λ, integrator, metric)
return HMCDA(T(δ), T(λ), integrator, metric)
end

sampler_eltype(::HMCDA{T}) where {T} = T
3 changes: 3 additions & 0 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ UnitEuclideanMetric(dim::Int) = UnitEuclideanMetric(Float64, (dim,))

renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)

Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
Base.size(e::UnitEuclideanMetric) = e.size
Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]
Base.show(io::IO, uem::UnitEuclideanMetric) =
Expand All @@ -47,6 +48,7 @@ DiagEuclideanMetric(dim::Int) = DiagEuclideanMetric(Float64, dim)

renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)

Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)
Base.show(io::IO, dem::DiagEuclideanMetric) =
print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")
Expand Down Expand Up @@ -80,6 +82,7 @@ DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)

renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)

Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
Base.show(io::IO, dem::DenseEuclideanMetric) =
print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
Expand Down
35 changes: 34 additions & 1 deletion test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

@testset "Constructors" begin
θ_init = randn(2)
d = 2
θ_init = randn(d)
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)

@testset "$T" for T in [Float32, Float64]
Expand All @@ -15,6 +16,38 @@ include("common.jl")
integrator_type = Leapfrog{T},
),
),
(
HMC(25, metric = DiagEuclideanMetric(ones(T, 2))),
(
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
),
),
(
HMC(25, integrator = Leapfrog(T(0.1)), metric = :unit),
(
adaptor_type = NoAdaptation,
metric_type = UnitEuclideanMetric{T},
integrator_type = Leapfrog{T},
),
),
(
HMC(25, integrator = Leapfrog(T(0.1)), metric = :dense),
(
adaptor_type = NoAdaptation,
metric_type = DenseEuclideanMetric{T},
integrator_type = Leapfrog{T},
),
),
(
HMCDA(T(0.8), one(T), integrator = Leapfrog(T(0.1))),
(
adaptor_type = NesterovDualAveraging,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
),
),
# This should perform the correct promotion for the 2nd argument.
(
HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),
Expand Down

2 comments on commit 8429077

@yebai
Copy link
Member

@yebai yebai commented on 8429077 Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JaimeRZP, you can comment on a commit "@JuliaRegistrator register" to trigger a release.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88461

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 84290771ada04536c5247ac226cb1baed043dbc4
git push origin v0.5.0

Please sign in to comment.