Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup formatter #17

Merged
merged 6 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

style = "blue"
align_assignment = true
align_struct_field = true
align_pair_arrow = true
align_matrix = true
align_conditional = true
26 changes: 26 additions & 0 deletions .github/workflows/Format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Format suggestions

on:
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: 1
- run: |
julia -e 'using Pkg; Pkg.add("JuliaFormatter")'
julia -e 'using JuliaFormatter; format("."; verbose=true)'
- uses: reviewdog/action-suggester@v1
with:
tool_name: JuliaFormatter
fail_on_error: true
7 changes: 2 additions & 5 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ makedocs(;
"Univariate Slice Sampling" => "univariate_slice.md",
"Meta Multivariate Samplers" => "meta_multivariate.md",
"Latent Slice Sampling" => "latent_slice.md",
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md"
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md",
],
)

deploydocs(;
repo="github.com/TuringLang/SliceSampling.jl",
push_preview=true
)
deploydocs(; repo="github.com/TuringLang/SliceSampling.jl", push_preview=true)
59 changes: 30 additions & 29 deletions ext/SliceSamplingTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Random
using SliceSampling
using Turing
# using Turing: Turing, Experimental
# using Turing: Turing, Experimental
else
using ..LogDensityProblemsAD
using ..Random
Expand All @@ -17,46 +17,47 @@

# Required for using the slice samplers as `externalsampler`s in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
sample::SliceSampling.Transition
) = sample.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
)
return sample.params
end
# end

# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.UnivariateSliceState
) = state.transition.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState
)
return state.transition.params
end

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.GibbsState
) = state.transition.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
)
return state.transition.params
end

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.HitAndRunState
) = state.transition.params
function Turing.Inference.getparams(

Check warning on line 41 in ext/SliceSamplingTuringExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SliceSamplingTuringExt.jl#L41

Added line #L41 was not covered by tests
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
)
return state.transition.params

Check warning on line 44 in ext/SliceSamplingTuringExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SliceSamplingTuringExt.jl#L44

Added line #L44 was not covered by tests
end

Turing.Experimental.gibbs_requires_recompute_logprob(
function Turing.Experimental.gibbs_requires_recompute_logprob(
model_dst,
::Turing.DynamicPPL.Sampler{
<: Turing.Inference.ExternalSampler{
<: SliceSampling.AbstractSliceSampling, A, U
}
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U}
},
sampler_src,
state_dst,
state_src
) where {A,U} = false
state_src,
) where {A,U}
return false
end
# end

function SliceSampling.initial_sample(
rng::Random.AbstractRNG,
ℓ ::Turing.LogDensityFunction
)
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
model = ℓ.model
spl = Turing.SampleFromUniform()
vi = Turing.VarInfo(rng, model, spl)
Expand All @@ -67,14 +68,14 @@
if init_attempt_count == 10
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
end

# NOTE: This will sample in the unconstrained space.
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
θ = vi[spl]

init_attempt_count += 1
end
θ
return θ
end

end
55 changes: 26 additions & 29 deletions src/SliceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
- `lp::Real`: Log-target density of the samples.
- `info::NamedTuple`: Named tuple containing information about the transition.
"""
struct Transition{P, L <: Real, I <: NamedTuple}
struct Transition{P,L<:Real,I<:NamedTuple}
"current state of the slice sampling chain"
params::P

Expand All @@ -53,47 +53,44 @@
- `model`: The target `LogDensityProblem`.
"""
function initial_sample(::Random.AbstractRNG, ::Any)
error(
return error(

Check warning on line 56 in src/SliceSampling.jl

View check run for this annotation

Codecov / codecov/patch

src/SliceSampling.jl#L56

Added line #L56 was not covered by tests
"`initial_sample` is not implemented but an initialization wasn't provided. ",
"Consider supplying an initialization to `initial_params`."
"Consider supplying an initialization to `initial_params`.",
)
end

# If target is from `LogDensityProblemsAD`, unwrap target before calling `initial_sample`.
# This is necessary since Turing wraps `DynamicPPL.Model`s when passed to an `externalsampler`.
initial_sample(
rng::Random.AbstractRNG,
wrap::LogDensityProblemsAD.ADGradientWrapper
) = initial_sample(rng, parent(wrap))
function initial_sample(
rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper
)
return initial_sample(rng, parent(wrap))
end

function exceeded_max_prop(max_prop::Int)
error("Exceeded maximum number of proposal $(max_prop), ",
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
"A quick fix is to increase `max_prop`, ",
"but an acceptance rate that is too low often indicates that there is a problem. ",
"Here are some possible causes:\n",
"- The model might be broken or degenerate (most likely cause).\n",
"- The tunable parameters of the sampler are suboptimal.\n",
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n"
)
return error(
"Exceeded maximum number of proposal $(max_prop), ",
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
"A quick fix is to increase `max_prop`, ",
"but an acceptance rate that is too low often indicates that there is a problem. ",
"Here are some possible causes:\n",
"- The model might be broken or degenerate (most likely cause).\n",
"- The tunable parameters of the sampler are suboptimal.\n",
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n",
)
end

## Univariate Slice Sampling Algorithms
export Slice, SliceSteppingOut, SliceDoublingOut

abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end

accept_slice_proposal(
::AbstractSliceSampling,
::Any,
::Real,
::Real,
::Real,
::Real,
::Real,
::Real,
) = true
function accept_slice_proposal(
::AbstractSliceSampling, ::Any, ::Real, ::Real, ::Real, ::Real, ::Real, ::Real
)
return true
end

function find_interval end

Expand All @@ -103,7 +100,7 @@
include("univariate/doublingout.jl")

## Multivariate slice sampling algorithms
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end

# Meta Multivariate Samplers
export RandPermGibbs, HitAndRun
Expand Down
Loading
Loading