diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl
index 17f2bdac..42d52767 100644
--- a/src/AdvancedHMC.jl
+++ b/src/AdvancedHMC.jl
@@ -26,6 +26,8 @@ using AbstractMCMC: LogDensityModel
 
 import StatsBase: sample
 
+const DEFAULT_FLOAT_TYPE = typeof(float(0))
+
 include("utilities.jl")
 
 # Notations
diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl
index b6a58131..ff0cf2a9 100644
--- a/src/abstractmcmc.jl
+++ b/src/abstractmcmc.jl
@@ -251,13 +251,6 @@ end
 #############
 ### Utils ###
 #############
-
-function sampler_eltype(::AbstractHMCSampler{T}) where {T<:Real}
-    return T
-end
-
-#########
-
 function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
     T = sampler_eltype(spl)
     if init_params == nothing
diff --git a/src/constructors.jl b/src/constructors.jl
index f2238224..546887b9 100644
--- a/src/constructors.jl
+++ b/src/constructors.jl
@@ -1,4 +1,34 @@
-abstract type AbstractHMCSampler{T<:Real} <: AbstractMCMC.AbstractSampler end
+"""
+    determine_sampler_eltype(xs...)
+
+Determine the element type to use for the given arguments.
+
+Symbols are either resolved to the default float type or simply dropped
+in favour of determined types from the other arguments.
+"""
+determine_sampler_eltype(xs...) = float(_determine_sampler_eltype(xs...))
+# NOTE: We want to defer conversion to `float` until the very "end" of the
+# process, so as to allow `promote_type` to do its job properly.
+# For example, in the scenario `determine_sampler_eltype(::Int64, ::Float32)`
+# we want to return `Float32`, not `Float64`. The latter would occur
+# if we did `float(eltype(x))` instead of just `eltype(x)`.
+_determine_sampler_eltype(x) = eltype(x)
+_determine_sampler_eltype(x::AbstractIntegrator) = integrator_eltype(x)
+_determine_sampler_eltype(::Symbol) = DEFAULT_FLOAT_TYPE
+function _determine_sampler_eltype(xs...)
+    xs_not_symbol = filter(!Base.Fix2(isa, Symbol), xs)
+    isempty(xs_not_symbol) && return DEFAULT_FLOAT_TYPE
+    return promote_type(map(_determine_sampler_eltype, xs_not_symbol)...)
+end
+
+abstract type AbstractHMCSampler <: AbstractMCMC.AbstractSampler end
+
+"""
+    sampler_eltype(sampler)
+
+Return the element type of the sampler.
+"""
+function sampler_eltype end
 
 ##############
 ### Custom ###
@@ -21,19 +51,17 @@ and `adaptor` after sampling.
 
 To access the updated fields use the resulting [`HMCState`](@ref).
 """
-struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
+struct HMCSampler{K<:AbstractMCMCKernel,M<:AbstractMetric,A<:AbstractAdaptor} <:
+       AbstractHMCSampler
     "[`AbstractMCMCKernel`](@ref)."
-    κ::AbstractMCMCKernel
+    κ::K
     "Choice of initial metric [`AbstractMetric`](@ref). The metric type will be preserved during adaption."
-    metric::AbstractMetric
+    metric::M
     "[`AbstractAdaptor`](@ref)."
-    adaptor::AbstractAdaptor
+    adaptor::A
 end
 
-function HMCSampler(κ, metric, adaptor)
-    T = collect(typeof(metric).parameters)[1]
-    return HMCSampler{T}(κ, metric, adaptor)
-end
+sampler_eltype(sampler::HMCSampler) = eltype(sampler.metric)
 
 ############
 ### NUTS ###
@@ -53,7 +81,8 @@ $(FIELDS)
 NUTS(δ=0.65) # Use target accept ratio 0.65.
 ```
 """
-struct NUTS{T<:Real} <: AbstractHMCSampler{T}
+struct NUTS{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
+       AbstractHMCSampler
     "Target acceptance rate for dual averaging."
     δ::T
     "Maximum doubling tree depth."
@@ -61,16 +90,18 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
     "Maximum divergence during doubling tree."
     Δ_max::T
     "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
-    integrator::Union{Symbol,AbstractIntegrator}
+    integrator::I
     "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
-    metric::Union{Symbol,AbstractMetric}
+    metric::M
 end
 
 function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
-    T = typeof(δ)
-    return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
+    T = determine_sampler_eltype(δ, integrator, metric)
+    return NUTS(T(δ), max_depth, T(Δ_max), integrator, metric)
 end
+
+sampler_eltype(::NUTS{T}) where {T} = T
+
 ###########
 ### HMC ###
 ###########
@@ -89,23 +120,20 @@ $(FIELDS)
 HMC(10, integrator = Leapfrog(0.05), metric = :diagonal)
 ```
 """
-struct HMC{T<:Real} <: AbstractHMCSampler{T}
+struct HMC{I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
+       AbstractHMCSampler
     "Number of leapfrog steps."
     n_leapfrog::Int
     "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
-    integrator::Union{Symbol,AbstractIntegrator}
+    integrator::I
     "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
-    metric::Union{Symbol,AbstractMetric}
+    metric::M
 end
 
-function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
-    if integrator isa Symbol
-        T = typeof(0.0) # current default float type
-    else
-        T = integrator_eltype(integrator)
-    end
-    return HMC{T}(n_leapfrog, integrator, metric)
-end
+HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal) =
+    HMC(n_leapfrog, integrator, metric)
+
+sampler_eltype(sampler::HMC) = determine_sampler_eltype(sampler.metric, sampler.integrator)
 
 #############
 ### HMCDA ###
@@ -131,7 +159,7 @@ For more information, please view the following paper ([arXiv link](https://arxi
 setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning
 Research 15, no. 1 (2014): 1593-1623.
 """
-struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
+struct HMCDA{T<:Real} <: AbstractHMCSampler
     "Target acceptance rate for dual averaging."
     δ::T
     "Target leapfrog length."
@@ -142,8 +170,9 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
     metric::Union{Symbol,AbstractMetric}
 end
 
-function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
-    δ, λ = promote(δ, λ)
-    T = typeof(δ)
-    return HMCDA(δ, T(λ), integrator, metric)
+function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal)
+    T = determine_sampler_eltype(δ, λ, integrator, metric)
+    return HMCDA(T(δ), T(λ), integrator, metric)
 end
+
+sampler_eltype(::HMCDA{T}) where {T} = T
diff --git a/src/metric.jl b/src/metric.jl
index a7976f40..5660c140 100644
--- a/src/metric.jl
+++ b/src/metric.jl
@@ -23,6 +23,7 @@ UnitEuclideanMetric(dim::Int) = UnitEuclideanMetric(Float64, (dim,))
 
 renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)
 
+Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
 Base.size(e::UnitEuclideanMetric) = e.size
 Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]
 Base.show(io::IO, uem::UnitEuclideanMetric) =
@@ -47,6 +48,7 @@ DiagEuclideanMetric(dim::Int) = DiagEuclideanMetric(Float64, dim)
 
 renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)
 
+Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
 Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)
 Base.show(io::IO, dem::DiagEuclideanMetric) =
     print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")
@@ -80,6 +82,7 @@ DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)
 
 renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
 
+Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
 Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
 Base.show(io::IO, dem::DenseEuclideanMetric) =
     print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
diff --git a/test/constructors.jl b/test/constructors.jl
index ea0be912..3d0594d9 100644
--- a/test/constructors.jl
+++ b/test/constructors.jl
@@ -2,7 +2,8 @@ using AdvancedHMC, AbstractMCMC, Random
 include("common.jl")
 
 @testset "Constructors" begin
-    θ_init = randn(2)
+    d = 2
+    θ_init = randn(d)
     model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
 
     @testset "$T" for T in [Float32, Float64]
@@ -15,6 +16,38 @@ include("common.jl")
                     integrator_type = Leapfrog{T},
                 ),
             ),
+            (
+                HMC(25, metric = DiagEuclideanMetric(ones(T, 2))),
+                (
+                    adaptor_type = NoAdaptation,
+                    metric_type = DiagEuclideanMetric{T},
+                    integrator_type = Leapfrog{T},
+                ),
+            ),
+            (
+                HMC(25, integrator = Leapfrog(T(0.1)), metric = :unit),
+                (
+                    adaptor_type = NoAdaptation,
+                    metric_type = UnitEuclideanMetric{T},
+                    integrator_type = Leapfrog{T},
+                ),
+            ),
+            (
+                HMC(25, integrator = Leapfrog(T(0.1)), metric = :dense),
+                (
+                    adaptor_type = NoAdaptation,
+                    metric_type = DenseEuclideanMetric{T},
+                    integrator_type = Leapfrog{T},
+                ),
+            ),
+            (
+                HMCDA(T(0.8), one(T), integrator = Leapfrog(T(0.1))),
+                (
+                    adaptor_type = NesterovDualAveraging,
+                    metric_type = DiagEuclideanMetric{T},
+                    integrator_type = Leapfrog{T},
+                ),
+            ),
             # This should perform the correct promotion for the 2nd argument.
             (
                 HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),