Skip to content

Commit

Permalink
Merge branch 'master' into py/aqua
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm authored Jan 10, 2025
2 parents 1092558 + e673b69 commit de4c661
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 84 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.0"
version = "0.33.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

```@docs
|(::Model, ::Any)
|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
condition
DynamicPPL.conditioned
```
Expand Down Expand Up @@ -403,6 +403,7 @@ LikelihoodContext
PriorContext
MiniBatchContext
PrefixContext
ConditionContext
```

### Samplers
Expand Down
124 changes: 50 additions & 74 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
"""
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
Model context that contains values that are to be conditioned on. The values
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
@varname(b) => 2)`). The former is more performant, but the latter must be used
when there are varnames that cannot be represented as symbols, e.g.
`@varname(x[1])`.
"""
struct ConditionContext{
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
} <: AbstractContext
values::Values
context::Ctx
end

const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
const DictConditionContext = ConditionContext{<:AbstractDict}

ConditionContext(values) = ConditionContext(values, DefaultContext())

# Try to avoid nested `ConditionContext`.
# Use DefaultContext as the default base context
function ConditionContext(values::Union{NamedTuple,AbstractDict})
return ConditionContext(values, DefaultContext())
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
end
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
return ConditionContext(merge(context.values, values), childcontext(context))
end

Expand Down Expand Up @@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
end
end

"""
condition([context::AbstractContext,] values::NamedTuple)
condition([context::AbstractContext]; values...)
Return `ConditionContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.
See also: [`decondition`](@ref)
"""
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return condition((value, values...))
end
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
return condition(DefaultContext(), values)
end
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end
function AbstractPPL.condition(
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
)
return condition(context, (value, values...))
end
function AbstractPPL.condition(
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
)
return condition(context, Dict(values))
end

"""
decondition(context::AbstractContext, syms...)
Expand All @@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way.
See also: [`condition`](@ref)
"""
AbstractPPL.decondition(::IsLeaf, context, args...) = context
function AbstractPPL.decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
decondition_context(::IsLeaf, context, args...) = context
function decondition_context(::IsParent, context, args...)
return setchildcontext(context, decondition_context(childcontext(context), args...))
end
function AbstractPPL.decondition(context, args...)
return decondition(NodeTrait(context), context, args...)
function decondition_context(context, args...)
return decondition_context(NodeTrait(context), context, args...)
end
function AbstractPPL.decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext)
return decondition_context(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
BangBang.delete!!(context.values, sym),
),
syms...,
)
end

function AbstractPPL.decondition(
context::NamedConditionContext, vn::VarName{sym}
) where {sym}
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext, sym, syms...)
new_values = deepcopy(context.values)
for s in (sym, syms...)
new_values = BangBang.delete!!(new_values, s)
end
return if length(new_values) == 0
# No more values left, can unwrap
decondition_context(childcontext(context), syms...)
else
ConditionContext(
new_values, decondition_context(childcontext(context), sym, syms...)
)
end
end
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
return ConditionContext(
BangBang.delete!!(context.values, sym),
decondition_context(childcontext(context), vn),
)
end

Expand Down
38 changes: 30 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio
See [`condition`](@ref) for more information and examples.
"""
Base.:|(model::Model, values) = condition(model, values)
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
condition(model, values)

"""
condition(model::Model; values...)
Expand Down Expand Up @@ -264,11 +265,32 @@ julia> conditioned_model_dict()
1.0
```
"""
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
function AbstractPPL.condition(model::Model, value, values...)
return contextualize(model, condition(model.context, value, values...))
function AbstractPPL.condition(model::Model, values...)
# Positional arguments - need to handle cases carefully
return contextualize(
model, ConditionContext(_make_conditioning_values(values...), model.context)
)
end
function AbstractPPL.condition(model::Model; values...)
# Keyword arguments -- just convert to a NamedTuple
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
end

"""
_make_conditioning_values(vals...)
Convert different types of input to either a `NamedTuple` or `AbstractDict` of
conditioning values, suitable for storage in a `ConditionContext`.
This handles all the cases where `vals` is either already a NamedTuple or
AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g.
`condition(model, x=1, y=2)`).
"""
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)

"""
decondition(model::Model)
decondition(model::Model, variables...)
Expand Down Expand Up @@ -379,7 +401,7 @@ true
```
"""
function AbstractPPL.decondition(model::Model, syms...)
return contextualize(model, decondition(model.context, syms...))
return contextualize(model, decondition_context(model.context, syms...))
end

"""
Expand Down Expand Up @@ -413,7 +435,7 @@ julia> # Returns all the variables we have conditioned on + their values.
(x = 100.0, m = 1.0)
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
julia> conditioned(cm)
(x = 100.0, m = 1.0)
Expand All @@ -425,15 +447,15 @@ julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed
a.m
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
julia> conditioned(cm).x
100.0
julia> conditioned(cm).var"a.m"
1.0
julia> keys(VarInfo(cm)) # <= no variables are sampled
julia> keys(VarInfo(cm)) # No variables are sampled
VarName[]
```
"""
Expand Down
83 changes: 83 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DynamicPPL:
PointwiseLogdensityContext,
contextual_isassumption,
ConditionContext,
decondition_context,
hasconditioned,
getconditioned,
hasconditioned_nested,
Expand Down Expand Up @@ -196,6 +197,88 @@ end
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "ConditionContext" begin
@testset "Nesting" begin
@testset "NamedTuple" begin
n1 = (x=1, y=2)
n2 = (x=3,)
# Values from outer context should override inner one
ctx1 = ConditionContext(n1, ConditionContext(n2))
@test ctx1.values == (x=1, y=2)
# Check that the two ConditionContexts are collapsed
@test childcontext(ctx1) isa DefaultContext
# Then test the nesting the other way round
ctx2 = ConditionContext(n2, ConditionContext(n1))
@test ctx2.values == (x=3, y=2)
@test childcontext(ctx2) isa DefaultContext
end

@testset "Dict" begin
# Same tests as NamedTuple above
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
d2 = Dict(@varname(x) => 3)
ctx1 = ConditionContext(d1, ConditionContext(d2))
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
@test childcontext(ctx1) isa DefaultContext
ctx2 = ConditionContext(d2, ConditionContext(d1))
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
@test childcontext(ctx2) isa DefaultContext
end
end

@testset "decondition_context" begin
@testset "NamedTuple" begin
ctx = ConditionContext((x=1, y=2, z=3))
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, :x)
@test dctx isa ConditionContext
@test dctx.values == (y=2, z=3)
dctx = decondition_context(ctx, :y, :z)
@test dctx isa ConditionContext
@test dctx.values == (x=1,)
# Decondition all variables manually
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
end

@testset "Dict" begin
ctx = ConditionContext(
Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3)
)
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, @varname(x))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
dctx = decondition_context(ctx, @varname(y), @varname(z))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(x) => 1)
# Decondition all variables manually
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
DefaultContext
end

@testset "Nesting" begin
ctx = ConditionContext(
(x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4))
)
# Decondition an outer variable
dctx = decondition_context(ctx, :x)
@test dctx.values == (y=2,)
@test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4)
# Decondition an inner variable
dctx = decondition_context(ctx, @varname(a))
@test dctx.values == (x=1, y=2)
@test childcontext(dctx).values == Dict(@varname(b) => 4)
# Try deconditioning everything
dctx = decondition_context(ctx)
@test dctx isa DefaultContext
end
end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
retval = model()
Expand Down
Loading

0 comments on commit de4c661

Please sign in to comment.