Skip to content

Commit

Permalink
targets rework
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimeRZP committed Jan 15, 2024
1 parent ab72fc8 commit 870d9a7
Showing 1 changed file with 24 additions and 44 deletions.
68 changes: 24 additions & 44 deletions src/targets.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
NoTransform(x) = x

mutable struct Target
d::Int
vsyms::Any
h::Hamiltonian
transform::Function
inv_transform::Function
prior_draw::Function
end

#=
mutable struct TuringTarget <: Target
model::DynamicPPL.Model
Expand Down Expand Up @@ -88,58 +77,49 @@ TuringTarget(model; kwargs...) = begin
end
=#

CustomTarget(nlogp, grad_nlogp, priors; kwargs...) = begin
d = length(priors)
vsyms = [DynamicPPL.VarName(Symbol("d_", i)) for i = 1:d]
NoTransform(x) = x

function prior_draw()
x = [rand(dist) for dist in priors]
xt = transform(x)
return xt
mutable struct Target
d::Int
h::Hamiltonian
transform::Function
inv_transform::Function
θ_start::Vector{Float64}
θ_names::Vector{String}
end

function CustomTarget(nlogp, grad_nlogp, θ_start::Vector{Float64};
θ_names=nothing,
transform=NoTransform,
inv_transform=NoTransform)
d = length(θ_start)
if θ_names==nothing
θ_names = [String("θ_", i) for i=1:d]
end
hamiltonian = Hamiltonian(nlogp, grad_nlogp)
Target(d, hamiltonian, NoTransform, NoTransform, prior_draw)
return Target(d, Hamiltonian(nlogp, grad_nlogp), transform, inv_transform, θ_start, θ_names)
end

GaussianTarget(_mean::AbstractVector, _cov::AbstractMatrix) = begin
function GaussianTarget(_mean::AbstractVector, _cov::AbstractMatrix)
d = length(_mean)
vsyms = [DynamicPPL.VarName(Symbol("d_", i)) for i = 1:d]

_gaussian = MvNormal(_mean, _cov)
ℓπ::AbstractVector) = logpdf(_gaussian, θ)
∂lπ∂θ::AbstractVector) = (logpdf(_gaussian, θ), gradlogpdf(_gaussian, θ))
hamiltonian = Hamiltonian(ℓπ, ∂lπ∂θ)

function prior_draw()
xt = rand(MvNormal(zeros(d), ones(d)))
return xt
end

Target(d, vsyms, hamiltonian, NoTransform, NoTransform, prior_draw)
θ_start = rand(MvNormal(zeros(d), ones(d)))
return CustomTarget(ℓπ, ∂lπ∂θ, θ_start)
end

RosenbrockTarget(a, b; kwargs...) = begin
function RosenbrockTarget(a, b; kwargs...)
kwargs = Dict(kwargs)
d = kwargs[:d]
vsyms = [DynamicPPL.VarName(Symbol("d_", i)) for i = 1:d]

function ℓπ(x; a = a, b = b)
x1 = x[1:Int(d / 2)]
x2 = x[Int(d / 2)+1:end]
m = @.((a - x1)^2 + b * (x2 - x1^2)^2)
return -0.5 * sum(m)
end

function ∂lπ∂θ(x)
return ℓπ(x), ForwardDiff.gradient(ℓπ, x)
end

hamiltonian = Hamiltonian(ℓπ, ∂lπ∂θ)

function prior_draw()
x = rand(MvNormal(zeros(d), ones(d)))
return x
end

Target(d, vsyms, hamiltonian, NoTransform, NoTransform, prior_draw)
θ_start = rand(MvNormal(zeros(d), ones(d)))
return CustomTarget(ℓπ, ∂lπ∂θ, θ_start)
end

0 comments on commit 870d9a7

Please sign in to comment.