Skip to content

Commit

Permalink
Remove selector stuff from SGHMC
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Jan 8, 2025
1 parent 934ebef commit 02f251d
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
SGHMC{AD,space}
SGHMC{AD}
Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) sampler.e
Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) sampler.
# Fields
$(TYPEDFIELDS)
Expand All @@ -12,15 +12,14 @@ Tianqi Chen, Emily Fox, & Carlos Guestrin (2014). Stochastic Gradient Hamiltonia
Carlo. In: Proceedings of the 31st International Conference on Machine Learning
(pp. 1683–1691).
"""
struct SGHMC{AD,space,T<:Real} <: StaticHamiltonian
struct SGHMC{AD,T<:Real} <: StaticHamiltonian
learning_rate::T
momentum_decay::T
adtype::AD
end

"""
SGHMC(
space::Symbol...;
SGHMC(;
learning_rate::Real,
momentum_decay::Real,
adtype::ADTypes.AbstractADType = AutoForwardDiff(),
Expand All @@ -37,21 +36,18 @@ Tianqi Chen, Emily Fox, & Carlos Guestrin (2014). Stochastic Gradient Hamiltonia
Carlo. In: Proceedings of the 31st International Conference on Machine Learning
(pp. 1683–1691).
"""
function SGHMC(
space::Symbol...;
function SGHMC(;
learning_rate::Real,
momentum_decay::Real,
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
)
_learning_rate, _momentum_decay = promote(learning_rate, momentum_decay)
return SGHMC{typeof(adtype),space,typeof(_learning_rate)}(
return SGHMC{typeof(adtype),typeof(_learning_rate)}(
_learning_rate, _momentum_decay, adtype
)
end

function drop_space(alg::SGHMC{AD,space,T}) where {AD,space,T}
return SGHMC{AD,(),T}(alg.learning_rate, alg.momentum_decay, alg.adtype)
end
drop_space(alg::SGHMC) = alg

Check warning on line 50 in src/mcmc/sghmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/sghmc.jl#L50

Added line #L50 was not covered by tests

struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}}
logdensity::L
Expand Down Expand Up @@ -128,15 +124,13 @@ Max Welling & Yee Whye Teh (2011). Bayesian Learning via Stochastic Gradient Lan
Dynamics. In: Proceedings of the 28th International Conference on Machine Learning
(pp. 681–688).
"""
struct SGLD{AD,space,S} <: StaticHamiltonian
struct SGLD{AD,S} <: StaticHamiltonian
"Step size function."
stepsize::S
adtype::AD
end

function drop_space(alg::SGLD{AD,space,S}) where {AD,space,S}
return SGLD{AD,(),S}(alg.stepsize, alg.adtype)
end
drop_space(alg::SGLD) = alg

Check warning on line 133 in src/mcmc/sghmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/sghmc.jl#L133

Added line #L133 was not covered by tests

struct PolynomialStepsize{T<:Real}
"Constant scale factor of the step size."
Expand Down Expand Up @@ -172,8 +166,7 @@ end
(f::PolynomialStepsize)(t::Int) = f.a / (t + f.b)^f.γ

"""
SGLD(
space::Symbol...;
SGLD(;
stepsize = PolynomialStepsize(0.01),
adtype::ADTypes.AbstractADType = AutoForwardDiff(),
)
Expand All @@ -193,12 +186,10 @@ Dynamics. In: Proceedings of the 28th International Conference on Machine Learni
See also: [`PolynomialStepsize`](@ref)
"""
function SGLD(
space::Symbol...;
stepsize=PolynomialStepsize(0.01),
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
function SGLD(;
stepsize=PolynomialStepsize(0.01), adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE
)
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
return SGLD{typeof(adtype),typeof(stepsize)}(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real} <: AbstractTransition
Expand Down

0 comments on commit 02f251d

Please sign in to comment.