Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/abstractmcmc-initial-step-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Dec 10, 2024
2 parents 56775b2 + 646202c commit c0d3b71
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.6.3"
version = "0.6.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
13 changes: 5 additions & 8 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,11 @@ MassMatrixAdaptor(m::DiagEuclideanMetric{T}) where {T} =
MassMatrixAdaptor(m::DenseEuclideanMetric{T}) where {T} =
WelfordCov{T}(size(m); cov = copy(m.M⁻¹))

MassMatrixAdaptor(m::Type{TM}, sz::Tuple{Vararg{Int}} = (2,)) where {TM<:AbstractMetric} =
MassMatrixAdaptor(Float64, m, sz)

MassMatrixAdaptor(
::Type{T},
::Type{TM},
sz::Tuple{Vararg{Int}} = (2,),
) where {T,TM<:AbstractMetric} = MassMatrixAdaptor(TM(T, sz))
MassMatrixAdaptor(::Type{TM}, sz::Dims = (2,)) where {TM<:AbstractMetric} =
MassMatrixAdaptor(Float64, TM, sz)

MassMatrixAdaptor(::Type{T}, ::Type{TM}, sz::Dims = (2,)) where {T,TM<:AbstractMetric} =
MassMatrixAdaptor(TM(T, sz))

# Deprecations

Expand Down
51 changes: 32 additions & 19 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,12 @@ function make_initial_params(
initial_params,
)
T = sampler_eltype(spl)
if initial_params == nothing
if initial_params === nothing
d = LogDensityProblems.dimension(logdensity)
initial_params = randn(rng, d)
return randn(rng, T, d)
else
return T.(initial_params)
end
return T.(initial_params)
end

#########
Expand Down Expand Up @@ -345,10 +346,10 @@ end
function make_step_size(
rng::Random.AbstractRNG,
integrator::AbstractIntegrator,
T::Type,
::Type{T},
hamiltonian::Hamiltonian,
initial_params,
)
) where {T}
if integrator.ϵ > 0
ϵ = integrator.ϵ
else
Expand All @@ -361,10 +362,10 @@ end
function make_step_size(
rng::Random.AbstractRNG,
integrator::Symbol,
T::Type,
::Type{T},
hamiltonian::Hamiltonian,
initial_params,
)
) where {T}
ϵ = find_good_stepsize(rng, hamiltonian, initial_params)
@info string("Found initial step size ", ϵ)
return T(ϵ)
Expand All @@ -373,21 +374,33 @@ end
make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where {T<:Real} =
JitteredLeapfrog(ϵ, T(0.1ϵ))
make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where {T<:Real} = TemperedLeapfrog(ϵ, T(1))
function make_integrator(i::Symbol, ϵ::Real)
float_ϵ = AbstractFloat(ϵ)
if i === :leapfrog
return Leapfrog(float_ϵ)
elseif i === :jitteredleapfrog
return JitteredLeapfrog(float_ϵ, float_ϵ / 10)
elseif i === :temperedleapfrog
return TemperedLeapfrog(float_ϵ, oneunit(float_ϵ))
else
error("Integrator $i not supported.")
end
end

#########

make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.")
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
make_metric(i::AbstractMetric, T::Type, d::Int) = i
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
make_metric(i::AbstractMetric, ::Type, ::Int) = i
function make_metric(i::Symbol, ::Type{T}, d::Int) where {T}
if i === :diagonal
return DiagEuclideanMetric(T, d)
elseif i === :unit
return UnitEuclideanMetric(T, d)
elseif i === :dense
return DenseEuclideanMetric(T, d)
else
error("Metric $i not supported.")
end
end

function make_metric(spl::AbstractHMCSampler, logdensity)
d = LogDensityProblems.dimension(logdensity)
Expand Down
6 changes: 0 additions & 6 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,6 @@ Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
Base.show(io::IO, dem::DenseEuclideanMetric) =
print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")

# getname functions
for T in (UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric)
@eval getname(::Type{<:$T}) = $T
end
getname(m::T) where {T<:AbstractMetric} = getname(T)

# `rand` functions for `metric` types.

function _rand(
Expand Down

0 comments on commit c0d3b71

Please sign in to comment.