Skip to content

Commit

Permalink
Using MCMCDebugging for Geweke test (#208)
Browse files Browse the repository at this point in the history
* using MCMCDebugging for Geweke test

* add MCMCDebugging to test dep

* add Plots to test dep

* add Plots to test dep

* only enable Geweke test for GitHub Action

* bugfix for GitHub Action yml

* log ENV

* bugfix for ENV check

* Geweke not on 1.0

* Geweke not on <=1.1

* Geweke not on <=1.2

* excluding all 32bit machine due to OOM in CI

* turn off progress for Geweke test
  • Loading branch information
xukai92 authored Aug 5, 2020
1 parent fb5632a commit 07236bf
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 6 deletions.
14 changes: 10 additions & 4 deletions .github/workflows/AHMC-CI.yml → .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: AdvancedHMC-CI
name: CI

on:
push:
Expand All @@ -13,8 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.0'
- '1'
- '1.3'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -24,13 +23,20 @@ jobs:
- x86
- x64
exclude:
- os: ubuntu-latest
arch: x86
- os: macOS-latest
arch: x86
- os: windows-latest
arch: x86
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- name: Run tests
uses: julia-actions/julia-runtest@latest
env:
GEWEKE_TEST: 1
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vscode
.history
Manifest.toml
test/Project.toml
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCDebugging = "6d524b87-5f90-4494-b601-374a5b87a94b"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "Distributions", "ForwardDiff", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
test = ["Distributed", "Distributions", "ForwardDiff", "Plots", "MCMCDebugging", "Test", "Turing", "UnicodePlots", "Bijectors", "OrdinaryDiffEq", "Zygote"]
39 changes: 39 additions & 0 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,42 @@ function ℓπ_gdemo(θ)
loglikelihood = logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
return logprior + loglikelihood
end

using Distributions: MvNormal
import Turing

Turing.@model function mvntest(θ, x)
θ ~ MvNormal(zeros(D), 2)
x ~ Normal(sum(θ), 1)
return θ, x
end

function get_primitives(x, modelgen)
spl_prior = Turing.SampleFromPrior()
function ℓπ(θ)
vi = Turing.VarInfo(model)
vi[spl_prior] = θ
model(vi, spl_prior)
Turing.getlogp(vi)
end
adbackend = Turing.Core.ForwardDiffAD{40}
alg_ad = Turing.HMC{adbackend}(0.1, 1)
model = modelgen(missing, x)
vi = Turing.VarInfo(model)
spl = Turing.Sampler(alg_ad, model)
Turing.Core.link!(vi, spl)
∂ℓπ∂θ = θ -> Turing.Core.gradient_logp(adbackend(), θ, vi, model, spl)
θ₀ = Turing.VarInfo(model)[Turing.SampleFromPrior()]
return ℓπ, ∂ℓπ∂θ, θ₀
end

function rand_θ_given(x, modelgen, metric, τ; n_samples=20)
ℓπ, ∂ℓπ∂θ, θ₀ = get_primitives(x, modelgen)
h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
samples, stats = sample(h, τ, θ₀, n_samples; verbose=false, progress=false)
s = samples[end]
return length(s) == 1 ? s[1] : s
end

# Test function
g(θ, x) = cat(θ, x; dims=1)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using Distributed, Test

println("Envronment variables for testing")
println(ENV)

@testset "AdvancedHMC" begin
tests = [
"metric",
Expand Down
9 changes: 8 additions & 1 deletion test/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Allow pass --progress when running this script individually to turn on progress meter
const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test, AdvancedHMC, LinearAlgebra, Random
using Test, AdvancedHMC, LinearAlgebra, Random, MCMCDebugging, Plots
using Parameters: reconstruct
using Statistics: mean, var, cov
unicodeplots()
include("common.jl")

θ_init = rand(MersenneTwister(1), D)
Expand Down Expand Up @@ -59,6 +60,12 @@ end
Random.seed!(1)
samples, stats = sample(h, τ, θ_init, n_samples; verbose=false, progress=PROGRESS)
@test mean(samples[n_adapts+1:end]) zeros(D) atol=RNDATOL
if "GEWEKE_TEST" in keys(ENV) && ENV["GEWEKE_TEST"] == "1"
res = perform(GewekeTest(5_000), mvntest, x -> rand_θ_given(x, mvntest, metric, τ), g; progress=false)
p = plot(res, mvntest)
display(p)
println()
end
end

# Skip adaptation tests with tempering
Expand Down

0 comments on commit 07236bf

Please sign in to comment.