Skip to content

Commit

Permalink
Merge branch 'master' into patch-3
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Dec 17, 2024
2 parents 0811bef + 8972b98 commit 169dda2
Show file tree
Hide file tree
Showing 31 changed files with 861 additions and 1,098 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ${{ matrix.runner.os }}
strategy:
fail-fast: false

matrix:
runner:
# Current stable version
Expand Down Expand Up @@ -58,6 +65,9 @@ jobs:
os: macos-latest
arch: aarch64
num_threads: 2
test_group:
- Group1
- Group2

steps:
- uses: actions/checkout@v4
Expand All @@ -73,7 +83,7 @@ jobs:

- uses: julia-actions/julia-runtest@v1
env:
GROUP: All
GROUP: ${{ matrix.test_group }}
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}

- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'
2 changes: 0 additions & 2 deletions .github/workflows/JuliaPre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,3 @@ jobs:
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: DynamicPPL
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.4"
version = "0.32.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
Expand All @@ -55,6 +57,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -18,6 +19,7 @@ Documenter = "1"
DocumenterMermaid = "0.1"
FillArrays = "0.13, 1"
ForwardDiff = "0.10"
JET = "0.9"
LogDensityProblems = "2"
MCMCChains = "5, 6"
StableRNGs = "1"
29 changes: 23 additions & 6 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,20 +265,24 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

#### `VarInfo`
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:

```@docs
VarInfo
TypedVarInfo
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
#### `VarInfo`

```@docs
link!
invlink!
VarInfo
TypedVarInfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md).
The [Transformations section below](#Transformations) describes the methods used for this.
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.

```@docs
set_flag!
unset_flag!
Expand Down Expand Up @@ -425,6 +429,19 @@ DynamicPPL.loadstate
DynamicPPL.initialsampler
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

```@docs
DynamicPPL.default_varinfo
```

There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:

```@docs
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
Expand Down
4 changes: 1 addition & 3 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
typeof(varinfo_untyped.metadata)
```

Expand Down
53 changes: 53 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
else
JET.report_call(f, argtypes)
end
return length(JET.get_reports(result)) == 0, result
end

function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
model, context, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
end
end

end
41 changes: 22 additions & 19 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,35 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
length(argtypes) >= 2 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext
requires_jet |=
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
length(argtypes) >= 3 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext &&
argtypes[3] <: AbstractVarInfo
if requires_jet
print(
io,
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
)
end
end
end
end

Expand Down
104 changes: 104 additions & 0 deletions src/experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module Experimental

using DynamicPPL: DynamicPPL

# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
"""
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
!!! warning
Loading JET.jl is required before calling this function.
# Arguments
- `model`: The model to verify the support for.
- `context`: The context to use for the model evaluation.
- `varinfo`: The varinfo to verify the support for.
# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
# Returns
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
- `report`: The result of `report_call` from JET.jl.
"""
function is_suitable_varinfo end

# Internal hook for JET.jl to overload.
function _determine_varinfo_jet end

"""
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
Return a suitable varinfo for the given `model`.
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
!!! warning
For full functionality, this requires JET.jl to be loaded.
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.
# Arguments
- `model`: The model for which to determine the varinfo.
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
# Examples
```jldoctest
julia> using DynamicPPL.Experimental: determine_suitable_varinfo
julia> using JET: JET # needs to be loaded for full functionality
julia> @model function model_with_random_support()
x ~ Bernoulli()
if x
y ~ Normal()
else
z ~ Normal()
end
end
model_with_random_support (generic function with 2 methods)
julia> model = model_with_random_support();
julia> # Typed varinfo cannot handle this random support model properly
# as using a single execution of the model will not see all random variables.
# Hence, this this model requires untyped varinfo.
vi = determine_suitable_varinfo(model);
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48
julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
true
julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
@model model_with_static_support() = x ~ Normal()
model_with_static_support (generic function with 2 methods)
julia> vi = determine_suitable_varinfo(model_with_static_support());
julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
true
```
"""
function determine_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
only_ddpl::Bool=true,
)
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
_determine_varinfo_jet(model, context; only_ddpl)
else
# Warn the user.
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
DynamicPPL.typed_varinfo(model, context)
end
end

end
Loading

0 comments on commit 169dda2

Please sign in to comment.