Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
efmanu committed Feb 26, 2021
1 parent 433b9c7 commit 6823736
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 38 deletions.
8 changes: 4 additions & 4 deletions docs/src/compare.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ chm = sample(mdl, spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
#define MCMC sampling algorithm
alg = [MH()]
sample_alg =Dict(
1 => [1, Normal(2.0,3.0)],
2 => [1, Normal(3.0,3.0)]
1 => [1, 1, Normal(2.0,3.0)],
2 => [1, 1, Normal(3.0,3.0)]
)

# Sample from the posterior using Gibbs sampler.
Expand Down Expand Up @@ -84,8 +84,8 @@ chm = sample(mdl1, spl1, 100000; param_names=["μ", "σ"], chain_type=Chains)
#define MCMC sampling algorithm
alg = [MH()]
sample_alg =Dict(
1 => [1, Normal(1.0,5.0)],
2 => [1, Normal(0.0,5.0)]
1 => [1, 1, Normal(1.0,5.0)],
2 => [1, 1, Normal(0.0,5.0)]
)

# Sample from the posterior using Gibbs sampler.
Expand Down
12 changes: 6 additions & 6 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ function logJoint(params)
end
alg = [MH()]
sample_alg =Dict(
1 => [1, Normal(2.0,3.0)],
2 => [1, Normal(3.0,3.0)]
1 => [1, 1, Normal(2.0,3.0)],
2 => [1, 1, Normal(3.0,3.0)]
)
# Sample from the posterior using Gibbs sampler.
chn = GibbsSampler.gibbs(alg, sample_alg, logJoint;itr = 10000, chain_type = :mcmcchain)
Expand All @@ -51,8 +51,8 @@ The `adNUTS()` struct defined with [GibbsSampler.jl](https://github.com/efmanu/G
#select MCMC sampler as vector with adNUTS() struct with same length of proposal distribution
alg = [adNUTS()]
sample_alg =Dict(
1 => [1],
2 => [1]
1 => [1,1],
2 => [1,1]
)

# Sample from the posterior using Gibbs sampler.
Expand All @@ -65,8 +65,8 @@ chn = GibbsSampler.gibbs(alg, sample_alg, logJoint;itr = 10000, chain_type = :mc
#select MCMC sampler as vector the same length of proposal distribution
alg = [adNUTS(), MH()]
sample_alg =Dict(
1 => [1],
2 => [2, Normal(0.0,1.0)]
1 => [1,1],
2 => [2, 1, Normal(0.0,1.0)]
)

# Sample from the posterior using Gibbs sampler.
Expand Down
46 changes: 22 additions & 24 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,64 +8,62 @@
To generate posterior samples using Gibbs sampling algorithm
# Inputs
- alg :MCMC algorithms based on structs defined in this package as vector eg: alg = [MH()]
- sample_alg :A dictionary maps `alg` to parameter groups index and it contains proposal distribution
- `alg` :MCMC algorithms based on structs defined in this package as vector eg: `alg = [MH()]`
- `sample_alg` :A dictionary maps `alg` to parameter groups index and it contains proposal distribution
if required by the sampling algorithm.
Eg: sample_alg =Dict(
1 => [1, Normal(2.0,3.0)],
2 => [1, Normal(3.0,3.0)]
)
Eg: `sample_alg =Dict(1 => [1, Normal(2.0,3.0)],2 => [1, Normal(3.0,3.0)])`
Here key is the paramter group and first index in the value (Vector) maps to `alg` index. Second index in the `sample_alg`
is the proposal distribution. This not mandatory.
- logJoint :Log PDF as a function
- `logJoint` :Log PDF as a function
# Keyword Arguments
- itr :Number of iterations
- burn_in :Burn in from samples
- chain_type :Sample chain type. default value is `:default`. Samples chains formated using `MCMCChain.jl`
- `itr` :Number of iterations
- `burn_in` :Burn in from samples
- `chain_type` :Sample chain type. default value is `:default`. Samples chains formated using `MCMCChain.jl`
by choosing `chain_type` as `:mcmcchain`
- progress :To show the sampling progress. Default value is `true`.
- `progress` :To show the sampling progress. Default value is `true`.
# Output
- chn :Generated samples
- `chn ` :Generated samples
"""
function gibbs(alg, sample_alg, logJoint::Function;
revt = [reverse_transform for _ in 1:length(sample_alg)],
itr = 100, burn_in = Int(round(itr*0.2)),
chain_type=:default, progress = true
) where {T <: Distribution}
states = Dict()
lens = length(sample_alg)
param_val = copy(rand(lens))
if progress
prog = Progress(itr, dt=0.5,
barglyphs=BarGlyphs('|','', ['' ,'' ,'' ,'' ,'' ,'', ''],' ','|',),
barlen=50)
end
val = check_sample_alg(alg, sample_alg)
states = Dict()
lens = length(sample_alg)
param_val = copy(rand.(val))
states["itr_$(0)"] = copy(param_val)
for i in 1:itr
if progress
ProgressMeter.next!(prog, showvalues = [(:iter,i), (:samples, param_val)])
end
states["itr_$i"] = copy(param_val)
end
for idx in 1:lens
itr_loc = lens*(i-1)+idx
function step_wrapper(new_param)
nw_param_val = [param_val[1:idx-1]..., revt[idx](new_param), param_val[idx+1:end]...]
return logJoint(nw_param_val)
end
initial_θ = states["itr_$(itr_loc-1)"][idx]

if i == 1
initial_θ = rand(val[idx])
else
initial_θ = states["itr_$(i-1)"][idx]
end
param_val[idx] = revt[idx](proposal_sampling(step_wrapper, initial_θ, val[idx], alg[sample_alg[idx][1]]))

states["itr_$i"][idx] = param_val[idx]
states["itr_$(itr_loc)"] = copy(param_val)
end
end
if progress
ProgressMeter.finish!(prog)
end
return format_chain(states, burn_in, itr, chain_type=chain_type)
delete!(states, 0)
return format_chain(states, Int(round(itr*lens*0.2)), itr*lens, chain_type=chain_type)
end

function proposal_sampling(step_wrapper::Function, initial_θ,
Expand Down
12 changes: 8 additions & 4 deletions src/libs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ function check_sample_alg(alg, sample_alg)
len_s = length(sample_alg)
val = Array{Any}(undef,len_s)
for loc in 1:len_s
if (!isassigned(sample_alg[loc],2)) && (alg[sample_alg[loc][1]] isa MH)
if (!isassigned(sample_alg[loc],3)) && (alg[sample_alg[loc][1]] isa MH)
throw("Error: MH sampler requires a proposal distribution")
elseif ((alg[sample_alg[loc][1]] isa adHMC) || (alg[sample_alg[loc][1]] isa adNUTS)) && ((!isassigned(sample_alg[loc],2)))
val[loc] = Normal(0.0,1.0)
elseif ((alg[sample_alg[loc][1]] isa adHMC) || (alg[sample_alg[loc][1]] isa adNUTS)) && ((!isassigned(sample_alg[loc],3)))
if sample_alg[loc][2] == 1
val[loc] = Normal(0.0,1.0)
else
val[loc] = MvNormal(zeros(sample_alg[loc][2]),1.0)
end
else
val[loc] = sample_alg[loc][2]
val[loc] = sample_alg[loc][3]
end
end
return val
Expand Down

0 comments on commit 6823736

Please sign in to comment.