Skip to content

Commit

Permalink
fix docs
Browse files Browse the repository at this point in the history
marius311 committed Apr 23, 2024
1 parent d7a7ff6 commit 67dac79
Showing 3 changed files with 10 additions and 33 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/tests_and_docs.yml
Original file line number Diff line number Diff line change
@@ -12,21 +12,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
julia-version: ['1.7', '1.8', '1.9', '~1.10.0-0']
julia-version: ['1.7', '1.8', '1.9', '1.10']
threads: ['1', '2']
fail-fast: false
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
- uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install matplotlib
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
32 changes: 8 additions & 24 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -33,9 +33,8 @@ First, load up the packages we'll need:

```@example 1
using MuseInference, Turing
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, PyPlot, Random, Zygote
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
Turing.setadbackend(:zygote)
PyPlot.ioff() # hide
using Logging # hide
Logging.disable_logging(Logging.Info) # hide
Turing.AdvancedVI.PROGRESS[] = false # hide
@@ -85,7 +84,7 @@ nothing # hide
We next compute the MUSE estimate for the same problem. To reach the same Monte Carlo error as HMC, the number of MUSE simulations should be the same as the effective sample size of the chain we just ran. This is:

```@example 1
nsims = round(Int, ess_rhat(chain)[:θ,:ess])
nsims = round(Int, ess(chain)[:θ,:ess])
```

Running the MUSE estimate,
@@ -97,29 +96,14 @@ muse_result = @time muse(model, 0; nsims, get_covariance=true)
nothing # hide
```

Lets also try mean-field variational inference (MFVI) to compare to another approximate method.
Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer in a fraction of the time.

```@example 1
Random.seed!(4)
vi(model, ADVI(10, 10)) # warmup # hide
t_vi = @time @elapsed vi_result = vi(model, ADVI(10, 1000))
nothing # hide
```

Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer at a fraction of the computational cost. MFVI struggles in both speed and accuracy by comparison.

```@example 1
figure(figsize=(6,5)) # hide
axvline(0, c="k", ls="--", alpha=0.5)
hist(collect(chain["θ"][:]), density=true, bins=15, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time))
histogram(collect(chain["θ"][:]), normalize=:pdf, bins=10, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time))
θs = range(-0.5,0.5,length=1000)
plot(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000))))
plot(θs, pdf.(Normal(vi_result.dist.m[1], vi_result.dist.σ[1]), θs), label=@sprintf("MFVI (%.1f seconds)", t_vi))
legend()
xlabel(L"\theta")
ylabel(L"\mathcal{P}(\theta\,|\,x)")
title("2048-dimensional noisy funnel")
gcf() # hide
plot!(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000))), lw=2)
vline!([0], c=:black, ls=:dash, alpha=0.5, label=nothing)
plot!(xlabel="θ", ylabel="P(θ|x)", title="2048-dimensional noisy funnel")
```

The timing[^1] difference is indicative of the speedups over HMC that are possible. These get even more dramatic as we increase dimensionality, which is why MUSE really excels on high-dimensional problems.
@@ -180,7 +164,7 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = AD.ZygoteBackend()
autodiff = AbstractDifferentiation.ZygoteBackend()
)
nothing # hide
```
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

0 comments on commit 67dac79

Please sign in to comment.