Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using JET.jl to determine if typed varinfo is okay #728

Merged
merged 65 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
361c45e
fixed calls to `to_linked_internal_transform`
torfjelde Nov 27, 2024
545cfab
fixed incorrect call to `acclogp_assume!!`
torfjelde Nov 28, 2024
abd432f
added `determine_varinfo` and an implementation using JET for this
torfjelde Nov 28, 2024
5cd9009
Merge remote-tracking branch 'origin/torfjelde/minor-bugfixes' into t…
torfjelde Nov 28, 2024
d503c3c
made filtering for errors only in the tilde pipeline optional
torfjelde Nov 28, 2024
acb2cb0
formatting
torfjelde Nov 28, 2024
902641f
fixed incorrect comment
torfjelde Nov 28, 2024
d93006b
added test for the branch we were currently imssing
torfjelde Nov 28, 2024
64ff18a
formatting
torfjelde Nov 28, 2024
90c2df0
Merge branch 'master' into torfjelde/minor-bugfixes
torfjelde Nov 28, 2024
a94dbd5
Merge branch 'torfjelde/minor-bugfixes' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
67723d6
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
3d8ad44
renamed `determine_varinfo` to `determine_suitable_varinfo` with
torfjelde Nov 29, 2024
c06b080
removed now-redundant init used with Requires.jl, since this is no
torfjelde Nov 29, 2024
d1a5bab
`determine_suitable_varinfo` now only performs checks using the
torfjelde Nov 29, 2024
5370e55
formatting
torfjelde Nov 29, 2024
dd408ee
updated error hint
torfjelde Nov 29, 2024
c253e9b
added def of `untyped_varinfo` which takes just `model` and `context`
torfjelde Nov 29, 2024
891b46a
fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet`
torfjelde Nov 29, 2024
686ed9f
explicitly call `typed_varinfo` when we want such a thing rather than
torfjelde Nov 29, 2024
d7d785a
`typed_varinfo` and `untyped_varinfo` handles wrapping passed context
torfjelde Nov 29, 2024
dda56ec
use `determine_suitable_varinfo` in `LogDensityFunction` when not con…
torfjelde Nov 29, 2024
46ea18c
formatting
torfjelde Nov 29, 2024
c20ede3
formatting
torfjelde Nov 29, 2024
0b3c36e
fixed a bug in `DynamicPPLJETExt.is_tilde_instance`
torfjelde Nov 29, 2024
f76658a
updated docs
torfjelde Nov 29, 2024
690b017
Update docs/src/internals/varinfo.md
torfjelde Nov 29, 2024
97258f3
added back def of `untyped_varinfo` that shouldn't have been removed +
torfjelde Nov 29, 2024
95bb3a9
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 29, 2024
155ce66
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 29, 2024
4998d08
minor codestyle improvement
torfjelde Nov 29, 2024
5c27677
temporary hack to debug what's happening
torfjelde Nov 30, 2024
99d4df7
more debugging
torfjelde Nov 30, 2024
3b9a9eb
use the `target_modules` kwarg in `report_call` instead of manually
torfjelde Nov 30, 2024
99fb153
formatting
torfjelde Nov 30, 2024
3588597
more debugging
torfjelde Nov 30, 2024
7a302e5
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
040cb54
more debugging
torfjelde Nov 30, 2024
889c370
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 30, 2024
c98fe49
more debugging: try with new bijectors.jl
torfjelde Nov 30, 2024
123b644
formatting
torfjelde Nov 30, 2024
37fabb0
removed the hacky debugging stuff used for the CI
torfjelde Nov 30, 2024
33e5b98
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
7ddec2c
removed now-redudant filtering methods since we use JET's own filters
torfjelde Nov 30, 2024
b6b4bff
bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass
torfjelde Nov 30, 2024
e07ecdb
moved the JET.jl-dependent experimental `determine_varinfo` into a
torfjelde Dec 3, 2024
8ba8f82
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Dec 3, 2024
9ec1556
forgot to add the experimenta.jl file in previous commit
torfjelde Dec 3, 2024
599488b
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Dec 3, 2024
fa155a4
reverted changes to `default_varinfo` and `LogDensityFunction`
torfjelde Dec 3, 2024
8496968
added a bunch of docs for introduced and existing methods
torfjelde Dec 4, 2024
fd82871
added doctests to `determine_suitable_varinfo`
torfjelde Dec 4, 2024
bb87ba0
added JET.jl as a dep to docs
torfjelde Dec 4, 2024
62c5cd1
fixed referencing in docs
torfjelde Dec 4, 2024
55dc91e
fixed docstring
torfjelde Dec 4, 2024
ae51778
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Dec 4, 2024
a692ec3
fixed doctest
torfjelde Dec 5, 2024
d5eb404
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Dec 5, 2024
17b6ec9
Update Project.toml
torfjelde Dec 5, 2024
bfa88b2
applied suggestions from @mhauru
torfjelde Dec 5, 2024
82578cf
fixed doctests
torfjelde Dec 5, 2024
3aad34f
finally fixed doctests
torfjelde Dec 6, 2024
da3eefe
removed unnecessary `typed_varinfo` and `untyped_varinfo` methods
torfjelde Dec 6, 2024
325c5f9
added filter to ignore source of warnings in doctest
torfjelde Dec 6, 2024
4a17e82
Merge branch 'master' into torfjelde/determine-varinfo
Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
Expand All @@ -55,6 +57,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -18,6 +19,7 @@ Documenter = "1"
DocumenterMermaid = "0.1"
FillArrays = "0.13, 1"
ForwardDiff = "0.10"
JET = "0.9"
LogDensityProblems = "2"
MCMCChains = "5, 6"
StableRNGs = "1"
20 changes: 20 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:

```@docs
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
```

#### `VarInfo`

```@docs
Expand Down Expand Up @@ -425,6 +432,19 @@ DynamicPPL.loadstate
DynamicPPL.initialsampler
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

```@docs
DynamicPPL.default_varinfo
```

There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:

```@docs
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
Expand Down
4 changes: 1 addition & 3 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
typeof(varinfo_untyped.metadata)
```

Expand Down
53 changes: 53 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
else
JET.report_call(f, argtypes)
end
return length(JET.get_reports(result)) == 0, result
end

function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
model, context, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
end
end

end
41 changes: 22 additions & 19 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,35 @@
include("debug_utils.jl")
using .DebugUtils

include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
length(argtypes) >= 2 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext
requires_jet |=
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
length(argtypes) >= 3 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext &&
argtypes[3] <: AbstractVarInfo
if requires_jet
print(

Check warning on line 225 in src/DynamicPPL.jl

View check run for this annotation

Codecov / codecov/patch

src/DynamicPPL.jl#L225

Added line #L225 was not covered by tests
io,
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
)
end
end
Comment on lines +212 to +230
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could there be some way to test this? I do see that it's tricky. I'm a bit uncomfortable having this in without any testing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah was thinking the same. We could put in a test strictly before loading JET.jl ofc. It's a bit messy, but seems like the best way 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does wrapping the tests in separate modules save us?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah. AFAIK extensions trigger if the package is loaded at any point, e.g. even if a dep loads it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also a thing where it doesn't seem like we can nicely get the resulting error message (the error hint is not in the msg of the error or something). So I think we just leave this for now 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it does seem nasty to test for. Have you tried locally that it does what you expect?

end
end

Expand Down
104 changes: 104 additions & 0 deletions src/experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module Experimental

using DynamicPPL: DynamicPPL

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
"""
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)

Check if the `model` supports evaluation using the provided `context` and `varinfo`.

!!! warning
Loading JET.jl is required before calling this function.

# Arguments
- `model`: The model to verify the support for.
- `context`: The context to use for the model evaluation.
- `varinfo`: The varinfo to verify the support for.

# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
mhauru marked this conversation as resolved.
Show resolved Hide resolved

# Returns
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
- `report`: The result of `report_call` from JET.jl.
"""
function is_suitable_varinfo end

# Internal hook for JET.jl to overload.
function _determine_varinfo_jet end

"""
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)

Return a suitable varinfo for the given `model`.

See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).

!!! warning
For full functionality, this requires JET.jl to be loaded.
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.

# Arguments
- `model`: The model for which to determine the varinfo.
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.

# Keyword Arguments
mhauru marked this conversation as resolved.
Show resolved Hide resolved
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.

# Examples

```jldoctest
julia> using DynamicPPL.Experimental: determine_suitable_varinfo

julia> using JET: JET # needs to be loaded for full functionality

julia> @model function model_with_random_support()
x ~ Bernoulli()
if x
y ~ Normal()
else
z ~ Normal()
end
end
model_with_random_support (generic function with 2 methods)

julia> model = model_with_random_support();

julia> # Typed varinfo cannot handle this random support model properly
# as using a single execution of the model will not see all random variables.
# Hence, this this model requires untyped varinfo.
vi = determine_suitable_varinfo(model);
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48

julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
true

julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
@model model_with_static_support() = x ~ Normal()
model_with_static_support (generic function with 2 methods)

julia> vi = determine_suitable_varinfo(model_with_static_support());

julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
true
```
"""
function determine_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
only_ddpl::Bool=true,
)
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
_determine_varinfo_jet(model, context; only_ddpl)
else
# Warn the user.
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."

Check warning on line 98 in src/experimental.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental.jl#L98

Added line #L98 was not covered by tests
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
DynamicPPL.typed_varinfo(model, context)

Check warning on line 100 in src/experimental.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental.jl#L100

Added line #L100 was not covered by tests
end
end

end
16 changes: 15 additions & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ function AbstractMCMC.step(
return vi, nothing
end

"""
default_varinfo(rng, model, sampler[, context])

Return a default varinfo object for the given `model` and `sampler`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `model::Model`: Model for which we want to create a varinfo object.
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
- `context::AbstractContext`: Context in which the model is evaluated.

# Returns
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
"""
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
return default_varinfo(rng, model, sampler, DefaultContext())
end
Expand Down Expand Up @@ -126,7 +140,7 @@ By default, `data` is returned.
loadstate(data) = data

"""
default_chaintype(sampler)
default_chain_type(sampler)

Default type of the chain of posterior samples from `sampler`.
"""
Expand Down
34 changes: 20 additions & 14 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,30 +164,36 @@ function has_varnamedvector(vi::VarInfo)
end

"""
untyped_varinfo([rng, ]model[, sampler, context])
untyped_varinfo(model[, context, metadata])

Return an untyped `VarInfo` instance for the model `model`.
Return an untyped varinfo object for the given `model` and `context`.

# Arguments
- `model::Model`: The model for which to create the varinfo object.
- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`.
- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object.
Default: `Metadata()`.
"""
function untyped_varinfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
context::AbstractContext=SamplingContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
varinfo = VarInfo(metadata)
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
end
function untyped_varinfo(
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return untyped_varinfo(Random.default_rng(), model, args...)
return last(
evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context))
)
end

"""
typed_varinfo([rng, ]model[, sampler, context])
typed_varinfo(model[, context, metadata])

Return a typed varinfo object for the given `model`, `sampler` and `context`.

This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting
varinfo object to a typed varinfo object.

Return a typed `VarInfo` instance for the model `model`.
See also: [`DynamicPPL.untyped_varinfo`](@ref)
"""
typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...))

Expand All @@ -198,7 +204,7 @@ function VarInfo(
context::AbstractContext=DefaultContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
return typed_varinfo(rng, model, sampler, context, metadata)
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)

Expand Down
Loading
Loading