Skip to content

Commit

Permalink
first try
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed May 14, 2024
1 parent 48487cc commit 0de9e5c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.26.1"
version = "0.25.1"


[deps]
Expand Down
55 changes: 38 additions & 17 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,15 @@ function AbstractMCMC.step(
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
state=nothing;
use_simplevarinfo=false,
kwargs...,
)
vi = VarInfo()
model(rng, vi, sampler)
if !use_simplevarinfo
vi = VarInfo()
model(rng, vi, sampler)
else
vi = model(rng, SimpleVarInfo{Float64}(OrderedDict()), sampler)
end
return vi, nothing
end

Expand Down Expand Up @@ -97,23 +102,39 @@ end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
initial_params=nothing,
use_simplevarinfo=false,
kwargs...,
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# to avoid that existing variables are resampled
vi = last(evaluate!!(model, vi, DefaultContext()))
end
if !use_simplevarinfo
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# to avoid that existing variables are resampled
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
else
vi = last(DynamicPPL.evaluate!!(model, SimpleVarInfo{Float64}(OrderedDict()), SamplingContext(rng, SampleFromPrior(), DefaultContext())))

if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end
end

"""
Expand Down
4 changes: 4 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,10 @@ function dot_assume(
return value, lp, vi
end

function updategid!(vi::SimpleOrThreadSafeSimple, vn::VarName, spl::Sampler)
return nothing
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
Expand Down

0 comments on commit 0de9e5c

Please sign in to comment.