Transpilation of pure WinBUGS code when reimplementing Prior and Posterior Prediction #2148

Closed
CMoebus opened this issue Dec 19, 2023 · 6 comments

CMoebus commented Dec 19, 2023

Hi,
I am a newbie to Turing.jl and am reimplementing the WinBUGS scripts from Lee & Wagenmakers' book BAYESIAN COGNITIVE MODELING. Now I am stuck on the problem 'Prior and Posterior Prediction' in ch. 3.4. I want to stay close to the WinBUGS code, which has only 5 code lines, and I tried not to export any values out of the model macro. But my chain and the plots give completely misleading results. I looked around and found two people (semihcanaktepe, quangtiencs) doing similar reimplementations, but they circumvented the pure WinBUGS equivalent and wrote more verbose code.
For documentation purposes I attach the code below:

begin  # cf. BAYESIAN COGNITIVE MODELING 
    #     Lee & Wagenmakers, 2013, ch.3.4, Prior and Posterior Prediction
    #---------------------------------------------------------------------------
    using Turing, MCMCChains
    using LaTeXStrings
    using StatsPlots, Random
    #---------------------------------------------------------------------------
    @model function priorPosteriorPredictive(n; k=missing)
        #----------------------------------------------------
        # prior on rate θ
        θ ~ Beta(1, 1)
        #----------------------------------------------------
        # likelihood of observed data
        k ~ Binomial(n, θ)
        #----------------------------------------------------
        # prior predictive
        θPriorPred ~ Beta(1, 1)
        kPriorPred ~ Binomial(n, θPriorPred)
        #----------------------------------------------------
        # posterior predictive
        return kPostPred ~ Binomial(n, θ)
        #----------------------------------------------------
    end # function priorPosteriorPredictive
    #---------------------------------------------------------------------------
    modelPriorPredictive = let k = 1
        n = 15
        priorPosteriorPredictive(n)
    end # let
    #---------------------------------------------------------------------------
    chainPriorPredictive =                          # chain is ok
        let iterations = 3000
            sampler = Prior()
            sample(modelPriorPredictive, sampler, iterations)
        end # let       
    #---------------------------------------------------------------------------
    describe(chainPriorPredictive)                  # results are ok
    #---------------------------------------------------------------------------
    plot(chainPriorPredictive; normalize=true)      # plots are ok
    #---------------------------------------------------------------------------
    modelPosteriorPredictive = let k = 1
        datum = k
        n = 15
        # priorPosteriorPredictive(n)           # prior predictive without datum
        priorPosteriorPredictive(n; k=datum)    # posterior predictive including datum
    end # let
    #---------------------------------------------------------------------------
    chainPosteriorPredictive =                      # completely misleading
        let iterations = 3000
            nBurnIn = 1000
            δ = 0.65
            init_ϵ = 0.3
            sampler = NUTS(nBurnIn, δ; init_ϵ=init_ϵ)
            sample(modelPosteriorPredictive, sampler, iterations)
        end # let
    #---------------------------------------------------------------------------
    describe(chainPosteriorPredictive)              # completely misleading
    #---------------------------------------------------------------------------
    plot(chainPosteriorPredictive; normalize=true)  # completely misleading
    #---------------------------------------------------------------------------
end # begin

sunxd3 commented Dec 20, 2023

I think the issue here is that θPriorPred, kPriorPred, and kPostPred throw NUTS off quite a bit: NUTS is gradient-based, so the discrete predictive variables kPriorPred and kPostPred in particular are a poor fit for it.

If the model is written as

using Turing, MCMCChains, StatsPlots

@model function priorPosteriorPredictive(n; k=missing)
    #----------------------------------------------------
    # prior on rate θ
    θ ~ Beta(1, 1)
    #----------------------------------------------------
    # likelihood of observed data
    k ~ Binomial(n, θ)
    #----------------------------------------------------
    # prior predictive
    # θPriorPred ~ Beta(1, 1)
    # kPriorPred ~ Binomial(n, θPriorPred)
    #----------------------------------------------------
    # posterior predictive
    # return kPostPred ~ Binomial(n, θ)
    #----------------------------------------------------
end # function priorPosteriorPredictive

modelPosteriorPredictive = let k = 1
    datum = k
    n = 15
    # priorPosteriorPredictive(n),          # prior predictive without datum
    priorPosteriorPredictive(n; k=datum)    # posterior predictive including datum
end # let

chainPosteriorPredictive =                      # now behaves as expected
    let iterations = 3000
        nBurnIn = 1000
        δ = 0.65
        init_ϵ = 0.3
        sampler = NUTS(nBurnIn, δ; init_ϵ=init_ϵ)
        sample(modelPosteriorPredictive, sampler, iterations)
    end # let

plot(chainPosteriorPredictive)

(where I commented out the three predictive variables) the plot looks much better:

[plot_31: trace and density of θ from the sampled chain]
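
As an aside, the predictive quantities can still be recovered from this slimmed-down model after sampling; a rough sketch using Turing's `predict` (variable names here are only illustrative):

# instantiate the model with k left missing, so `predict` will draw it
modelMissingK = priorPosteriorPredictive(15)

# posterior predictive: for each posterior draw of θ, draw k ~ Binomial(n, θ)
kPostPredChain = predict(modelMissingK, chainPosteriorPredictive)

# prior predictive: sampling the same model from the prior already yields k draws
kPriorPredChain = sample(modelMissingK, Prior(), 3000)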

Alternatively, importance-sampling based samplers will likely perform better here

@model function priorPosteriorPredictive(n)
    #----------------------------------------------------
    # prior on rate θ
    θ ~ Beta(1, 1)
    #----------------------------------------------------
    # likelihood of observed data
    k ~ Binomial(n, θ)
    #----------------------------------------------------
    # prior predictive
    θPriorPred ~ Beta(1, 1)
    kPriorPred ~ Binomial(n, θPriorPred)
    #----------------------------------------------------
    # posterior predictive
    return kPostPred ~ Binomial(n, θ)
    #----------------------------------------------------
end # function priorPosteriorPredictive

modelPosteriorPredictive = let k = 1
    datum = k
    n = 15
    # priorPosteriorPredictive(n),          # prior predictive without datum
    priorPosteriorPredictive(n)  | (; k=datum)  # posterior predictive including datum
end # let

chainPosteriorPredictive = sample(modelPosteriorPredictive, PG(10), 3000)

plot(chainPosteriorPredictive)

(PG doesn't support models with keyword arguments, so I did a simple rewrite with the DynamicPPL.condition syntax)
[plot_32: trace and density of θ, θPriorPred, kPriorPred, and kPostPred from the PG chain]
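
For reference, the `|` operator is shorthand for `DynamicPPL.condition`, so the following build the same conditioned model (a minimal sketch, assuming DynamicPPL is loaded alongside Turing):

using DynamicPPL

m1 = priorPosteriorPredictive(15) | (; k = 1)                    # conditioning via |
m2 = DynamicPPL.condition(priorPosteriorPredictive(15); k = 1)   # same conditioned model, explicit call
m3 = DynamicPPL.decondition(m1)                                  # removes the conditioning again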

CMoebus commented Dec 21, 2023

Hi sunxd3,
thank you for the advice to use the 'PG' sampler and the conditioning bar '|'. I always wondered what the error message "...does not support keyword arguments" meant; with '|' this error disappeared, and 'PG' solves the sampling problem. Before your comment I read (https://turinglang.org/v0.30/tutorials/04-hidden-markov-model/) that one possibility is to sample continuous variables with 'HMC' and discrete ones with 'PG'. So I tried the combined sampler 'Gibbs(HMC(...), PG(...))', but I had some difficulties getting it to run. Do you have experience with e.g. 'Gibbs(HMC(0.01, 50, ...), PG(120, ...))'?
All the best, C.

sunxd3 commented Dec 22, 2023

You can specify which samplers are in charge of which variable(s) like

@model function priorPosteriorPredictive(n)
   θ ~ Beta(1, 1)
   k ~ Binomial(n, θ)
   θPriorPred ~ Beta(1, 1)
   kPriorPred ~ Binomial(n, θPriorPred)
   kPostPred ~ Binomial(n, θ)
   return θPriorPred, kPriorPred, kPostPred
end

model = priorPosteriorPredictive(15) | (; k=1) # creating the model

chn = sample(model, Gibbs(HMC(0.05, 10, :θ), NUTS(-1, 0.65, :θPriorPred), PG(100, :k, :kPriorPred, :kPostPred)), 1000) # use HMC for `θ`, NUTS for `θPriorPred`, and PG for the rest.

gives

Chains MCMC chain (1000×5×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 51.93 seconds
Compute duration  = 51.93 seconds
parameters        = θ, θPriorPred, kPriorPred, kPostPred
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

           θ    0.1121    0.0758    0.0068   100.5133   141.6987    0.9997        1.9354
  θPriorPred    0.5095    0.2863    0.0687    19.3442   164.4584    1.0147        0.3725
  kPriorPred    7.6210    4.6185    1.0676    19.8468        NaN    1.0166        0.3822
   kPostPred    1.6550    1.5743    0.1143   169.2209   195.1175    1.0025        3.2584

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           θ    0.0170    0.0537    0.0989    0.1555    0.2934
  θPriorPred    0.0377    0.2466    0.5016    0.7709    0.9697
  kPriorPred    0.0000    3.0000    8.0000   12.0000   15.0000
   kPostPred    0.0000    0.0000    1.0000    2.0000    5.0000
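
As a quick sanity check on θ: with a Beta(1, 1) prior and k = 1 successes out of n = 15 trials, conjugacy gives the posterior θ | k ~ Beta(1 + k, 1 + n - k) = Beta(2, 15), whose mean is 2/17 ≈ 0.118, in line with the sampled mean of 0.1121 above. For example:

using Distributions

k, n = 1, 15
post = Beta(1 + k, 1 + n - k)   # Beta(2, 15) by Beta-Binomial conjugacy
mean(post)                      # ≈ 0.1176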

sunxd3 commented Dec 22, 2023

@CMoebus we also have a package within the Turing ecosystem that supports the BUGS language directly, https://github.com/TuringLang/JuliaBUGS.jl, but it is currently in development and not feature complete.

We'd appreciate it if you gave it a try and reported issues; there are definitely a lot, but we'll try to fix them ASAP.
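
Very roughly, the ch. 3.4 model could look like this in JuliaBUGS (a sketch only; the package is still in development, so the exact macro syntax and `compile` signature may differ):

using JuliaBUGS

model_def = @bugs begin
    θ ~ dbeta(1, 1)
    k ~ dbin(θ, n)
    θprior ~ dbeta(1, 1)
    priorpredk ~ dbin(θprior, n)
    postpredk ~ dbin(θ, n)
end

data = (k = 1, n = 15)
model = compile(model_def, data, (;))   # observed data plus (empty) initial values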

CMoebus commented Dec 22, 2023

@sunxd3: Thank you again, and thank you for inviting me to become a JuliaBUGS.jl tester. Just a few weeks ago I started transpiling WinBUGS scripts into pure Turing.jl. I liked the declarative, math-oriented style of BUGS, but at the same time it is tedious if you need some calculations outside the BUGS language. A few years ago I switched to WebPPL and liked its functional style. But Turing.jl and its embedding in Julia seem more promising.

sunxd3 commented Jan 8, 2024

@CMoebus sorry for the late reply.

> I liked the declarative, math-oriented style of BUGS. But at the same time, it is tedious if you need some calculations outside the BUGS language.

One of the goals of the JuliaBUGS project is to make this much easier and give users access to other Julia packages.

> But Turing.jl and its embedding in Julia seem to be more promising.

As a maintainer and user of Turing, thanks for the support!

sunxd3 closed this as completed Jan 8, 2024