From 4dc2a7237954c2185fffe823a15e69ea300d57bb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:24:57 +0000 Subject: [PATCH 01/29] Remove selector stuff from varinfo tests --- test/varinfo.jl | 259 ++++++++---------------------------------------- 1 file changed, 44 insertions(+), 215 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 9a55cffb9..c6fa78658 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,8 +1,3 @@ -# Dummy algorithm for testing -# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) -struct MyAlg{space} end -DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space - function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -19,16 +14,13 @@ function check_varinfo_keys(varinfo, vns) end end -function randr( - vi::DynamicPPL.VarInfo, - vn::VarName, - dist::Distribution, - spl::DynamicPPL.Sampler, - count::Bool=false, -) +""" +Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. +""" +function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) - push!!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") @@ -37,8 +29,6 @@ function randr( DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) r else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) vi[vn] end end @@ -66,7 +56,6 @@ end tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] @test meta.orders[ind] == fmeta.orders[tind] - @test meta.gids[ind] == fmeta.gids[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -89,22 +78,6 @@ end vn2 = @varname x[1][2] @test vn2 == vn1 @test hash(vn2) == hash(vn1) - @test inspace(vn1, (:x,)) - - # Tests for `inspace` - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) function test_base!!(vi_original) vi = empty!!(vi_original) @@ -114,38 +87,31 @@ end vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @test length(vi[vn]) == 1 - @test length(vi[SampleFromPrior()]) == 1 - @test vi[vn] == r - @test vi[SampleFromPrior()][1] == r vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r - @test vi[SampleFromPrior()][1] == 2 * r - vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) - @test vi[vn] == 3 * r - @test vi[SampleFromPrior()][1] == 3 * r # TODO(mhauru) Implement these functions for other VarInfo types too. if vi isa DynamicPPL.VectorVarInfo delete!(vi, vn) @test isempty(vi) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) end vi = empty!!(vi) @test isempty(vi) - return push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) + @test ~isempty(vi) end vi = VarInfo() @@ -182,9 +148,8 @@ end vn_x = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() - push!!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -204,35 +169,13 @@ end vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) typed_vi = TypedVarInfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 end - @testset "setgid!" begin - vi = VarInfo(DynamicPPL.Metadata()) - meta = vi.metadata - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - gid1 = Selector() - gid2 = Selector(2, :HMC) - - push!!(vi, vn, r, dist, gid1) - @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - - vi = empty!!(TypedVarInfo(vi)) - meta = vi.metadata - push!!(vi, vn, r, dist, gid1) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) - end - @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -397,10 +340,9 @@ end """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] + θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] + θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) @@ -448,9 +390,9 @@ end # Check that linking and invlinking set the `trans` flag accordingly v = copy(meta.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 @@ -461,21 +403,20 @@ end @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 # Transform only one variable (`s`) but not the others (`m`) - spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) - link!!(vi, spl, model) + vi = link!!(vi, @varname(s), model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) - invlink!!(vi, spl, model) + vi = invlink!!(vi, @varname(s), model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @@ -856,62 +797,6 @@ end @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end - - # The below used to error, testing to avoid regression. - @testset "merge gids" begin - gidset_left = Set([Selector(1)]) - vi_left = VarInfo() - vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) - gidset_right = Set([Selector(2)]) - vi_right = VarInfo() - vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) - varinfo_merged = merge(vi_left, vi_right) - @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left - @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right - end - end - - @testset "VarInfo with selectors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo( - model, - DynamicPPL.SampleFromPrior(), - DynamicPPL.DefaultContext(), - DynamicPPL.Metadata(), - ) - selector = DynamicPPL.Selector() - spl = Sampler(MyAlg{(:s,)}(), model, selector) - - vns = DynamicPPL.TestUtils.varnames(model) - vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) - vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) - for vn in vns_s - DynamicPPL.updategid!(varinfo, vn, spl) - end - - # Should only get the variables subsumed by `@varname(s)`. - @test varinfo[spl] == - mapreduce(Base.Fix1(DynamicPPL.getindex_internal, varinfo), vcat, vns_s) - - # `link` - varinfo_linked = DynamicPPL.link(varinfo, spl, model) - # `s` variables should be linked - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - # `m` variables should NOT be linked - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - # And `varinfo` should be unchanged - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) - - # `invlink` - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) - # `s` variables should no longer be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) - # `m` variables should still not be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) - # And `varinfo_linked` should be unchanged - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - end end @testset "sampling from linked varinfo" begin @@ -1014,25 +899,22 @@ end vi = DynamicPPL.VarInfo() dists = [Categorical([0.7, 0.3]), Normal()] - spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1040,13 +922,13 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 @@ -1054,21 +936,21 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1076,69 +958,16 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 3] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end - - @testset "varinfo ranges" begin - @model empty_model() = x = 1 - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - - function test_varinfo!(vi) - spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = DynamicPPL.VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) - end end From 9b492a33b7d6b007b446a4ee2e8e83f7c17485cb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:49:03 +0000 Subject: [PATCH 02/29] Implement link and invlink for varnames rather than samplers --- src/abstract_varinfo.jl | 48 ++++++++++++++--- src/threadsafe.jl | 56 ++++++++++++++----- src/transforming.jl | 22 +++++--- src/varinfo.jl | 116 +++++++++++++++++++++++++++++++++++----- 4 files changed, 202 insertions(+), 40 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 3f513d71d..a3a3c9c78 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,8 +537,17 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# TODO(mhauru) The fact that we need to to define this type is a sign that the link/invlink +# API is hard to understand. To be fixed by removing samplers from it. +SamplerOrVarName = Union{ + AbstractSampler,VarName,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, @@ -552,13 +561,19 @@ link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, vi, SampleFromPrior(), model) end -function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function link!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl, model) + return link!!(default_transformation(model, vi), vi, spl_or_vn, model) +end +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return link!!(t, deepcopy(vi), (vn,), model) end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -571,13 +586,19 @@ link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link(t, deepcopy(vi), SampleFromPrior(), model) end -function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function link(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl, model) + return link(default_transformation(model, vi), deepcopy(vi), spl_or_vn, model) +end +function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return link(t, deepcopy(vi), (vn,), model) end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) @@ -591,9 +612,14 @@ invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, vi, SampleFromPrior(), model) end -function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function invlink!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl, model) + return invlink!!(transformation(vi), vi, spl_or_vn, model) +end +function invlink!!( + t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model +) + return invlink!!(t, vi, (vn,), model) end # Vector-based ones. @@ -629,6 +655,9 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) @@ -642,8 +671,11 @@ invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), mode function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink(t, vi, SampleFromPrior(), model) end -function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - return invlink(transformation(vi), vi, spl, model) +function invlink(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) + return invlink(transformation(vi), vi, spl_or_vn, model) +end +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return invlink(t, vi, (vn,), model) end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cedb0efad..bb60f7bcf 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,28 +81,44 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl_or_vn, model) end function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl_or_vn, model) end function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl_or_vn, model) end function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl_or_vn, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -110,13 +126,19 @@ end # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -125,15 +147,21 @@ function invlink!!( end function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return link!!(t, deepcopy(vi), spl, model) + return link!!(t, deepcopy(vi), spl_or_vn, model) end function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return invlink!!(t, deepcopy(vi), spl, model) + return invlink!!(t, deepcopy(vi), spl_or_vn, model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index 1f6c55e24..d0f3774c5 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,14 +91,18 @@ function dot_tilde_assume( return r, lp, vi end +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -107,13 +111,19 @@ function invlink!!( end function link( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::AbstractVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return link!!(t, deepcopy(vi), spl, model) + return link!!(t, deepcopy(vi), spl_or_vn, model) end function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::AbstractVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return invlink!!(t, deepcopy(vi), spl, model) + return invlink!!(t, deepcopy(vi), spl_or_vn, model) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3ebb505e0..d9c1247fc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1201,29 +1201,40 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!( + t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model +) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl, model) + has_varnamedvector(vi) && return link(t, vi, spl_or_vn, model) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl) + _link!(vi, spl_or_vn) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + spl_or_vn::SamplerOrVarNameIterator, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl_or_vn, model) end function _link!(vi::UntypedVarInfo, spl::AbstractSampler) + return _link!(vi, _getvns(vi, spl)) +end + +function _link!( + vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) @@ -1234,6 +1245,7 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to link a linked vi") end end + function _link!(vi::TypedVarInfo, spl::AbstractSampler) return _link!(vi, spl, Val(getspace(spl))) end @@ -1268,26 +1280,70 @@ end return expr end +""" + filter_subsumed(vns1, vns2) + +Return the subset of `vns2` that are subsumed by any variable in `vns1`. +""" +function filter_subsumed(vns1, vns2) + return filter(x -> any(subsumes(y, x) for y in vns1), vns2) +end + +function _link!( + vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) + return _link!(vi.metadata, vi, vns) +end +@generated function _link!( + metadata::NamedTuple{names}, + vi, + vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}}, +) where {names,space} + expr = Expr(:block) + for f in names + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, true, vn) + end + else + @warn("[DynamicPPL] attempt to link a linked vi") + end + end + end, + ) + end + return expr +end + # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl, model) + has_varnamedvector(vi) && return invlink(t, vi, spl_or_vn, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl) + _invlink!(vi, spl_or_vn) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + spl_or_vn::SamplerOrVarNameIterator, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl_or_vn, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1299,7 +1355,11 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode end function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) + return _invlink!(vi, _getvns(vi, spl)) +end +function _invlink!( + vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1310,6 +1370,7 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end + function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) return _invlink!(vi, spl, Val(getspace(spl))) end @@ -1344,6 +1405,37 @@ end return expr end +function _invlink!( + vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) + return _invlink!(vi.metadata, vi, vns) +end +@generated function _invlink!(metadata::NamedTuple{names}, vi, vns) where {names} + expr = Expr(:block) + for f in names + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) + end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") + end + end + end, + ) + end + return expr +end + function _inner_transform!(vi::VarInfo, vn::VarName, f) return _inner_transform!(getmetadata(vi, vn), vi, vn, f) end From b508f08a6faef408d409d52859dd55efb4ce80f2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:49:49 +0000 Subject: [PATCH 03/29] Replace set_retained_vns_del_by_spl! with set_retained_vns_del! --- src/varinfo.jl | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d9c1247fc..8d68f2b86 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -842,6 +842,9 @@ Returns a tuple of the unique symbols of random variables sampled in `vi`. syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) +_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) + # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @@ -2109,37 +2112,36 @@ function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bo end """ - set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) + set_retained_vns_del!(vi::VarInfo) Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. """ -function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) +function set_retained_vns_del!(vi::UntypedVarInfo) + idcs = _getidcs(vi) if get_num_produce(vi) == 0 - for i in length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true + for i in length(idcs):-1:1 + vi.metadata.flags["del"][idcs[i]] = true end else for i in 1:length(vi.orders) - if i in gidcs && vi.orders[i] > get_num_produce(vi) + if i in idcs && vi.orders[i] > get_num_produce(vi) vi.metadata.flags["del"][i] = true end end end return nothing end -function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) +function set_retained_vns_del!(vi::TypedVarInfo) # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) - return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) + idcs = _getidcs(vi) + return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end -@generated function _set_retained_vns_del_by_spl!( - metadata, gidcs::NamedTuple{names}, num_produce +@generated function _set_retained_vns_del!( + metadata, idcs::NamedTuple{names}, num_produce ) where {names} expr = Expr(:block) for f in names - f_gidcs = :(gidcs.$f) + f_idcs = :(idcs.$f) f_orders = :(metadata.$f.orders) f_flags = :(metadata.$f.flags) push!( @@ -2147,12 +2149,12 @@ end quote # Set the flag for variables with symbol `f` if num_produce == 0 - for i in length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true + for i in length($f_idcs):-1:1 + $f_flags["del"][$f_idcs[i]] = true end else for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce + if i in $f_idcs && $f_orders[i] > num_produce $f_flags["del"][i] = true end end From b8880d1d2169ca4eaf4612635d4b26cfa0bb08fc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 17:38:02 +0000 Subject: [PATCH 04/29] Make linking tests more extensive --- test/varinfo.jl | 56 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index c6fa78658..fd1c9a2e9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -374,13 +374,21 @@ end end @testset "link!! and invlink!!" begin - @model gdemo(x, y) = begin + @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin s ~ InverseGamma(2, 3) m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + x = Vector{T}(undef, length(a)) + x .~ Normal(m, sqrt(s)) + y = Vector{T}(undef, length(a)) + for i in eachindex(y) + y[i] ~ Normal(m, sqrt(s)) + end + a .~ Normal(m, sqrt(s)) + for i in eachindex(b) + b[i] ~ Normal(x[i] * y[i], sqrt(s)) + end end - model = gdemo(1.0, 2.0) + model = gdemo([1.0, 1.5], [2.0, 2.5]) # Check that instantiating the model does not perform linking vi = VarInfo() @@ -399,10 +407,13 @@ end # Check that linking and invlinking preserves the values vi = TypedVarInfo(vi) meta = vi.metadata - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) + v_x = copy(meta.x.vals) + v_y = copy(meta.y.vals) + + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) @@ -412,15 +423,30 @@ end @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 - # Transform only one variable (`s`) but not the others (`m`) - vi = link!!(vi, @varname(s), model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - vi = invlink!!(vi, @varname(s), model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 + # Transform only one variable + all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + for vn in [ + @varname(s), + @varname(m), + @varname(x), + @varname(y), + @varname(x[2]), + @varname(y[2]) + ] + target_vns = filter(x -> subsumes(vn, x), all_vns) + other_vns = filter(x -> !subsumes(vn, x), all_vns) + @test !isempty(target_vns) + @test !isempty(other_vns) + vi = link!!(vi, vn, model) + @test all(x -> istrans(vi, x), target_vns) + @test all(x -> !istrans(vi, x), other_vns) + vi = invlink!!(vi, vn, model) + @test all(x -> !istrans(vi, x), all_vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + @test meta.x.vals ≈ v_x atol = 1e-10 + @test meta.y.vals ≈ v_y atol = 1e-10 + end end @testset "istrans" begin From 99a8490631b10e9696f501d0555e38770b18128c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 14:51:18 +0000 Subject: [PATCH 05/29] Remove sampler indexing from link methods (but not invlink) --- src/abstract_varinfo.jl | 48 ++++++---- src/simple_varinfo.jl | 4 +- src/threadsafe.jl | 27 ++---- src/transforming.jl | 10 +- src/varinfo.jl | 198 +++++++++++++++++++++++++--------------- 5 files changed, 170 insertions(+), 117 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index a3a3c9c78..26238c12e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -548,7 +548,6 @@ SamplerOrVarName = Union{ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, mutating `vi` if possible. @@ -557,16 +556,25 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function link!!(vi::AbstractVarInfo, model::Model) + return link!!(default_transformation(model, vi), vi, model) +end +function link!!(vi::AbstractVarInfo, vns, model::Model) + return link!!(default_transformation(model, vi), vi, vns, model) end -function link!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl_or_vn, model) +# If no variable names are provided, link all variables. +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return link!!(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link!!(t, deepcopy(vi), (vn,), model) + return link!!(t, vi, (vn,), model) end """ @@ -574,7 +582,6 @@ end link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -582,16 +589,25 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link(t, deepcopy(vi), SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function link(vi::AbstractVarInfo, model::Model) + return link(default_transformation(model, vi), vi, model) +end +function link(vi::AbstractVarInfo, vns, model::Model) + return link(default_transformation(model, vi), vi, vns, model) end -function link(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl_or_vn, model) +# If no variable names are provided, link all variables. +function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return link(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link(t, deepcopy(vi), (vn,), model) + return link(t, vi, (vn,), model) end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6a84238e..6bb723b29 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,7 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, + ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, model::Model, ) # TODO: Make sure that `spl` is respected. @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, + ::AbstractSampler, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bb60f7bcf..c5a77c3ef 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -84,14 +84,12 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, vns, model) end function invlink!!( @@ -104,12 +102,9 @@ function invlink!!( end function link( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) end function invlink( @@ -126,10 +121,7 @@ end # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -147,12 +139,9 @@ function invlink!!( end function link( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return link!!(t, deepcopy(vi), spl_or_vn, model) + return link!!(t, deepcopy(vi), vns, model) end function invlink( diff --git a/src/transforming.jl b/src/transforming.jl index d0f3774c5..6acaf787c 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -94,9 +94,10 @@ end SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model + t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -111,12 +112,9 @@ function invlink!!( end function link( - t::DynamicTransformation, - vi::AbstractVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model ) - return link!!(t, deepcopy(vi), spl_or_vn, model) + return link!!(t, deepcopy(vi), vns, model) end function invlink( diff --git a/src/varinfo.jl b/src/varinfo.jl index 8d68f2b86..4c4125ad8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1207,35 +1207,37 @@ end SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} + +# Specialise link!! without varnames provided for TypedVarInfo. The usual version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps +# keep the downstread calls to link!! type stable. +function link!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) + return link!!(t, vi, all_varnames_namedtuple(vi), model) +end # X -> R for all variables associated with given sampler -function link!!( - t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model -) +function link!!(t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl_or_vn, model) + has_varnamedvector(vi) && return link(t, vi, vns, model) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl_or_vn) + _link!(vi, vns) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl_or_vn::SamplerOrVarNameIterator, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl_or_vn, model) -end - -function _link!(vi::UntypedVarInfo, spl::AbstractSampler) - return _link!(vi, _getvns(vi, spl)) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end function _link!( - vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} + vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) @@ -1249,24 +1251,30 @@ function _link!( end end -function _link!(vi::TypedVarInfo, spl::AbstractSampler) - return _link!(vi, spl, Val(getspace(spl))) +""" + filter_subsumed(vns1, vns2) + +Return the subset of `vns2` that are subsumed by any variable in `vns1`. +""" +function filter_subsumed(vns1, vns2) + return filter(x -> any(subsumes(y, x) for y in vns1), vns2) end -function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, spaceval) + +function _link!(vi::TypedVarInfo, vns::VarNameCollection) + return _link!(vi.metadata, vi, vns) end @generated function _link!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{names}, vi, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} +) where {names} expr = Expr(:block) for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) @@ -1276,39 +1284,26 @@ end else @warn("[DynamicPPL] attempt to link a linked vi") end - end, - ) - end + end + end, + ) end return expr end -""" - filter_subsumed(vns1, vns2) - -Return the subset of `vns2` that are subsumed by any variable in `vns1`. -""" -function filter_subsumed(vns1, vns2) - return filter(x -> any(subsumes(y, x) for y in vns1), vns2) -end - -function _link!( - vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} -) - return _link!(vi.metadata, vi, vns) -end @generated function _link!( - metadata::NamedTuple{names}, - vi, - vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}}, -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names + for f in metadata_names + if !(f in vns_names) + continue + end push!( expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) + f_vns = filter_subsumed(vns.$f, f_vns) if !isempty(f_vns) if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform @@ -1361,7 +1356,7 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) return _invlink!(vi, _getvns(vi, spl)) end function _invlink!( - vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} + vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) if istrans(vi, vns[1]) for vn in vns @@ -1382,7 +1377,7 @@ function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) return _invlink!(vi.metadata, vi, vns, spaceval) end @generated function _invlink!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} + ::NamedTuple{names}, vi, vns, ::Val{space} ) where {names,space} expr = Expr(:block) for f in names @@ -1408,12 +1403,10 @@ end return expr end -function _invlink!( - vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} -) +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) return _invlink!(vi.metadata, vi, vns) end -@generated function _invlink!(metadata::NamedTuple{names}, vi, vns) where {names} +@generated function _invlink!(::NamedTuple{names}, vi, vns) where {names} expr = Expr(:block) for f in names push!( @@ -1466,59 +1459,116 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end -function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(model, varinfo, spl) +function link( + ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model +) + return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler + model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection ) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _link_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. +""" +@generated function unique_syms(vns::T) where {T<:NTuple{N,VarName}} where {N} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + varname_namedtuple(vns::NTuple{N,VarName}) where {N} + varname_namedtuple(vns::AbstractVector{<:VarName}) + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. +""" +function varname_namedtuple(vns::NTuple{N,VarName} where {N}) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end + +# This method is type unstable, but that can't be helped: The problem is inherently type +# unstable if there are VarNames with multiple symbols in a Vector. +function varname_namedtuple(vns::AbstractVector{<:VarName}) + syms = tuple(unique(map(getsym, vns))...) + elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) + return NamedTuple{syms}(elements) +end + +# A simpler, type stable implementation when all the VarNames in a Vector have the same +# symbol. +function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} + return NamedTuple{(T,)}((vns,)) +end + +varname_namedtuple(vns::NamedTuple) = vns + +""" + all_varnames_namedtuple(vi::AbstractVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) + +@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + vns_namedtuple = varname_namedtuple(vns) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata_namedtuple!( +@generated function _link_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if f in vns_names push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end -function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1691,7 +1741,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ end function _invlink_metadata!!( - model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns + ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns for vn in vns From 4a79b1f66e267a8e7a4951bd81599d979e66b899 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 17:17:07 +0000 Subject: [PATCH 06/29] Remove indexing by samplers from invlink --- src/abstract_varinfo.jl | 49 ++++++++++------- src/simple_varinfo.jl | 2 +- src/threadsafe.jl | 26 +++------ src/transforming.jl | 12 +--- src/varinfo.jl | 118 ++++++++++++++++------------------------ 5 files changed, 87 insertions(+), 120 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26238c12e..f28755c9f 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,12 +537,6 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end -# TODO(mhauru) The fact that we need to to define this type is a sign that the link/invlink -# API is hard to understand. To be fixed by removing samplers from it. -SamplerOrVarName = Union{ - AbstractSampler,VarName,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} - """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) @@ -615,7 +609,6 @@ end invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) transformation `t`, mutating `vi` if possible. @@ -624,14 +617,23 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function invlink!!(vi::AbstractVarInfo, model::Model) + return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl_or_vn, model) +function invlink!!(vi::AbstractVarInfo, vns, model::Model) + return invlink!!(default_transformation(model, vi), vi, vns, model) end +# If no variable names are provided, invlink!! all variables. +function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink!!(t, vi, vns, model) +end +# Wrap a single VarName in a singleton tuple. function invlink!!( t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model ) @@ -674,7 +676,6 @@ end invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) transformation `t`. @@ -683,13 +684,23 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link`](@ref). """ -invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model) -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function invlink(vi::AbstractVarInfo, model::Model) + return invlink(default_transformation(model, vi), vi, model) +end +function invlink(vi::AbstractVarInfo, vns, model::Model) + return invlink(default_transformation(model, vi), vi, vns, model) end -function invlink(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - return invlink(transformation(vi), vi, spl_or_vn, model) +# If no variable names are provided, invlink all variables. +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) return invlink(t, vi, (vn,), model) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6bb723b29..b4e836371 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::AbstractSampler, + ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c5a77c3ef..c75ec2291 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -93,12 +93,9 @@ function link!!( end function invlink!!( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, vns, model) end function link( @@ -108,12 +105,9 @@ function link( end function invlink( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -127,10 +121,7 @@ function link!!( end function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + ::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -145,12 +136,9 @@ function link( end function invlink( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return invlink!!(t, deepcopy(vi), spl_or_vn, model) + return invlink!!(t, deepcopy(vi), vns, model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index 6acaf787c..46b42d8ed 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,9 +91,6 @@ function dot_tilde_assume( return r, lp, vi end -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( @@ -103,7 +100,7 @@ function link!!( end function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -118,10 +115,7 @@ function link( end function invlink( - t::DynamicTransformation, - vi::AbstractVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model ) - return invlink!!(t, deepcopy(vi), spl_or_vn, model) + return invlink!!(t, deepcopy(vi), vns, model) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 4c4125ad8..c05a42aba 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1322,26 +1322,33 @@ end return expr end +# Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps +# keep the downstread calls to link!! type stable. +function invlink!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) + return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +end + # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model + t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl_or_vn, model) + has_varnamedvector(vi) && return invlink(t, vi, vns, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl_or_vn) + _invlink!(vi, vns) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl_or_vn::SamplerOrVarNameIterator, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1352,9 +1359,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, _getvns(vi, spl)) -end function _invlink!( vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) @@ -1369,62 +1373,33 @@ function _invlink!( end end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, spl, Val(getspace(spl))) -end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, spaceval) +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) + vns_namedtuple = varname_namedtuple(vns) + return _invlink!(vi.metadata, vi, vns_namedtuple) end @generated function _invlink!( - ::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end, - ) + for f in metadata_names + if !(f in vns_names) + continue end - end - return expr -end -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) - return _invlink!(vi.metadata, vi, vns) -end -@generated function _invlink!(::NamedTuple{names}, vi, vns) where {names} - expr = Expr(:block) - for f in names push!( expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) - if !isempty(f_vns) - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") + f_vns = filter_subsumed(vns.$f, f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end, ) @@ -1641,56 +1616,55 @@ function _link_metadata!!( end function invlink( - ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) - return _invlink(model, varinfo, spl) + return _invlink(model, varinfo, vns) end function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + vns_namedtuple = varname_namedtuple(vns) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata_namedtuple!( +@generated function _invlink_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if (f in vns_names) push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end + function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns From 090608bc66e4c5d3317f44ccfb4cad582e903a43 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 18:05:38 +0000 Subject: [PATCH 07/29] Work towards removing sampler indexing with StaticTransformation --- src/abstract_varinfo.jl | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f28755c9f..6b7e412a8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -644,30 +644,46 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + ::Model, ) + # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a + # bit mixed. For TypedVarInfo you could transform only a subset of the variables, but + # for UntypedVarInfo and SimpleVarInfo it was silently assumed that all variables were + # being set. Unsure if we should support this or not, but at least it now errors + # loudly. + all_vns = Set(keys(vi)) + if Set(vns) != all_vns + msg = "StaticTransforming only a subset of variables is not supported." + throw(ArgumentError(msg)) + end b = inverse(t.bijector) - x = vi[spl] + x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + vi_new = setlogp!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + ::Model, ) + # TODO(mhauru) See comment in link!! above. + all_vns = Set(keys(vi)) + if Set(vns) != all_vns + msg = "StaticTransforming only a subset of variables is not supported." + throw(ArgumentError(msg)) + end b = t.bijector - y = vi[spl] + y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + vi_new = setlogp!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end @@ -774,9 +790,14 @@ function maybe_invlink_before_eval!!( return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model ) - return invlink!!(t, vi, _default_sampler(context), model) + # TODO(mhauru) Why does this function need the context argument? + vns = collect(keys(vi)) + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink!!(t, vi, vns, model) end function _default_sampler(context::AbstractContext) From 474985376a2c6788b2f36cadabc0e9b4db4cbebb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:09:48 +0000 Subject: [PATCH 08/29] Fix invlink/link for TypedVarInfo and StaticTransformation --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 527ac2dc1..e29152b8c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1160,8 +1160,8 @@ VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},Na # Specialise link!! without varnames provided for TypedVarInfo. The usual version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstread calls to link!! type stable. -function link!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) +# keep the downstream calls to link!! type stable. +function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) return link!!(t, vi, all_varnames_namedtuple(vi), model) end @@ -1273,8 +1273,8 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstread calls to link!! type stable. -function invlink!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) +# keep the downstream calls to link!! type stable. +function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) return invlink!!(t, vi, all_varnames_namedtuple(vi), model) end From e960679a1d7e97dec87dc67fafdfb8ddde242a7f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:09:59 +0000 Subject: [PATCH 09/29] Fix a test in models.jl --- test/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model.jl b/test/model.jl index 45c770cc4..a9d0b160f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -226,7 +226,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = DynamicPPL.TestUtils.demo_dynamic_constraint() spl = SampleFromPrior() vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) - link!!(vi, spl, model) + vi = link!!(vi, model) for i in 1:10 # Sample with large variations. From d507a535521bd43c2c032982ab3049275c3ff119 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:37:55 +0000 Subject: [PATCH 10/29] Move some functions to utils.jl, add tests and docstrings --- src/utils.jl | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/varinfo.jl | 42 ---------------------------------- test/utils.jl | 29 ++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 5fedd3039..b64ae46cc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1268,3 +1268,64 @@ _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, right::NamedTuple{()}) = left _merge(left::NamedTuple{()}, right::AbstractDict) = right + +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. + +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s. For a `Vector` you can +just use `Base.unique`. The point of `unique_syms` is that it supports constant propagating +the result, which is possible with a `Tuple` but `Base.unique` won't allow it. +""" +@generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + varname_namedtuple(vns::NTuple{N,VarName}) where {N} + varname_namedtuple(vns::AbstractVector{<:VarName}) + varname_namedtuple(vns::NamedTuple) + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. + +`varname_namedtuple` is type table for inputs that are `Tuple`s, and for vectors when all +`VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. + +Example: +```julia +julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) +(x, y[1], x.a, z[15], y[2]) + +julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) +(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) + +julia> varname_namedtuple(vns_tuple) == vns_nt +``` +""" +function varname_namedtuple(vns::NTuple{N,VarName} where {N}) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end + +# This method is type unstable, but that can't be helped: The problem is inherently type +# unstable if there are VarNames with multiple symbols in a Vector. +function varname_namedtuple(vns::AbstractVector{<:VarName}) + syms = tuple(unique(map(getsym, vns))...) + elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) + return NamedTuple{syms}(elements) +end + +# A simpler, type stable implementation when all the VarNames in a Vector have the same +# symbol. +function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} + return NamedTuple{(T,)}((vns,)) +end + +varname_namedtuple(vns::NamedTuple) = vns diff --git a/src/varinfo.jl b/src/varinfo.jl index e29152b8c..74a6e3b8d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1411,48 +1411,6 @@ function _link( ) end -""" - unique_syms(vns::T) where {T<:NTuple{N,VarName}} - -Return the unique symbols of the variables in `vns`. -""" -@generated function unique_syms(vns::T) where {T<:NTuple{N,VarName}} where {N} - retval = Expr(:tuple) - syms = [first(vn.parameters) for vn in T.parameters] - for sym in unique(syms) - push!(retval.args, QuoteNode(sym)) - end - return retval -end - -""" - varname_namedtuple(vns::NTuple{N,VarName}) where {N} - varname_namedtuple(vns::AbstractVector{<:VarName}) - -Return a `NamedTuple` of the variables in `vns` grouped by symbol. -""" -function varname_namedtuple(vns::NTuple{N,VarName} where {N}) - syms = unique_syms(vns) - elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) - return NamedTuple{syms}(elements) -end - -# This method is type unstable, but that can't be helped: The problem is inherently type -# unstable if there are VarNames with multiple symbols in a Vector. -function varname_namedtuple(vns::AbstractVector{<:VarName}) - syms = tuple(unique(map(getsym, vns))...) - elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) - return NamedTuple{syms}(elements) -end - -# A simpler, type stable implementation when all the VarNames in a Vector have the same -# symbol. -function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} - return NamedTuple{(T,)}((vns,)) -end - -varname_namedtuple(vns::NamedTuple) = vns - """ all_varnames_namedtuple(vi::AbstractVarInfo) diff --git a/test/utils.jl b/test/utils.jl index 3f435dca4..af7b3ee4d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,33 @@ x = rand(dist) @test DynamicPPL.tovec(x) == vec(x.UL) end + + @testset "unique_syms" begin + vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) + @inferred DynamicPPL.unique_syms(vns) + @inferred DynamicPPL.unique_syms(()) + @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) + @test DynamicPPL.unique_syms(()) == () + end + + @testset "varname_namedtuple" begin + vns_tuple = ( + @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) + ) + vns_vec = collect(vns_tuple) + vns_nt = (; + x=[@varname(x), @varname(x.a)], + y=[@varname(y[1]), @varname(y[2])], + z=[@varname(z[15])], + ) + vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] + @inferred DynamicPPL.varname_namedtuple(vns_tuple) + @inferred DynamicPPL.varname_namedtuple(vns_nt) + @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) + @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_nt) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == + (; x=vns_vec_single_symbol) + end end From 41150b5e5fda23a19fdd2ecff1fb0a6847936256 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:54:42 +0000 Subject: [PATCH 11/29] Fix a docstring typo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 74a6e3b8d..9b104a9a6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1412,7 +1412,7 @@ function _link( end """ - all_varnames_namedtuple(vi::AbstractVarInfo) + all_varnames_namedtuple(vi::TypedVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ From 45d1f137dd76f2030594cad97263435a71fad346 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:28:48 +0000 Subject: [PATCH 12/29] Various simplification to link/invlink --- src/abstract_varinfo.jl | 20 ++++----- src/simple_varinfo.jl | 4 +- src/threadsafe.jl | 5 --- src/utils.jl | 3 ++ src/varinfo.jl | 92 +++++++++++++++++------------------------ 5 files changed, 53 insertions(+), 71 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 6b7e412a8..b4aed0458 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -540,7 +540,7 @@ function settrans!! end """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, @@ -561,6 +561,7 @@ end function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. + # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -574,7 +575,7 @@ end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -594,6 +595,7 @@ end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. + # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -607,7 +609,7 @@ end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) @@ -644,7 +646,7 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + vns::VarNameCollection, ::Model, ) # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a @@ -669,7 +671,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + vns::VarNameCollection, ::Model, ) # TODO(mhauru) See comment in link!! above. @@ -690,7 +692,7 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) @@ -793,11 +795,7 @@ function maybe_invlink_before_eval!!( t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model ) # TODO(mhauru) Why does this function need the context argument? - vns = collect(keys(vi)) - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink!!(t, vi, vns, model) + return invlink!!(t, vi, model) end function _default_sampler(context::AbstractContext) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b4e836371..f60c0b0fb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,7 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, + ::VarNameCollection, model::Model, ) # TODO: Make sure that `spl` is respected. @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, + ::VarNameCollection, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c75ec2291..25aa0d654 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,11 +81,6 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) diff --git a/src/utils.jl b/src/utils.jl index b64ae46cc..854ead3fd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ struct NoDefault end const NO_DEFAULT = NoDefault() +# A short-hand for a type commonly used in type signatures for VarInfo methods. +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} + """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9b104a9a6..97fb733e2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1153,16 +1153,12 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - -# Specialise link!! without varnames provided for TypedVarInfo. The usual version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstream calls to link!! type stable. -function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return link!!(t, vi, all_varnames_namedtuple(vi), model) +# Specialise link!! without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream call to _link! type stable. +function link!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) + _link!(vi, all_varnames_namedtuple(vi)) + return vi end # X -> R for all variables associated with given sampler @@ -1185,9 +1181,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!( - vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) +function _link!(vi::UntypedVarInfo, vns::VarNameCollection) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns @@ -1209,35 +1203,8 @@ function filter_subsumed(vns1, vns2) return filter(x -> any(subsumes(y, x) for y in vns1), vns2) end -function _link!(vi::TypedVarInfo, vns::VarNameCollection) - return _link!(vi.metadata, vi, vns) -end -@generated function _link!( - ::NamedTuple{names}, vi, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) where {names} - expr = Expr(:block) - for f in names - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) - if !isempty(f_vns) - if !istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end - end - end, - ) - end - return expr +function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) + return _link!(vi.metadata, vi, varname_namedtuple(vns)) end @generated function _link!( @@ -1271,11 +1238,12 @@ end return expr end -# Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstream calls to link!! type stable. -function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +# Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) + _invlink!(vi, all_varnames_namedtuple(vi)) + return vi end # R -> X for all variables associated with given sampler @@ -1308,9 +1276,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!( - vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) +function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1322,7 +1288,7 @@ function _invlink!( end end -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) vns_namedtuple = varname_namedtuple(vns) return _invlink!(vi.metadata, vi, vns_namedtuple) end @@ -1400,6 +1366,13 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end +# Specialise link without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_namedtuple(vi)) +end + function _link( model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection ) @@ -1426,7 +1399,9 @@ all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) return expr end -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _link( + model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} +) varinfo = deepcopy(varinfo) vns_namedtuple = varname_namedtuple(vns) md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) @@ -1450,6 +1425,7 @@ end return :(NamedTuple{$metadata_names}($vals)) end + function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns @@ -1527,6 +1503,7 @@ function invlink( ) return _invlink(model, varinfo, vns) end + function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, @@ -1538,6 +1515,13 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end +# Specialise invlink without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_namedtuple(vi)) +end + function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) return VarInfo( @@ -1547,7 +1531,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) ) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _invlink( + model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} +) varinfo = deepcopy(varinfo) vns_namedtuple = varname_namedtuple(vns) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) From 98915c2d5751f45287cfaa0bc1620f3098f6bf78 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:35:00 +0000 Subject: [PATCH 13/29] Improve a docstring --- src/utils.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 854ead3fd..16aa38e4a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1277,9 +1277,10 @@ _merge(left::NamedTuple{()}, right::AbstractDict) = right Return the unique symbols of the variables in `vns`. -Note that `unique_syms` is only defined for `Tuple`s of `VarName`s. For a `Vector` you can -just use `Base.unique`. The point of `unique_syms` is that it supports constant propagating -the result, which is possible with a `Tuple` but `Base.unique` won't allow it. +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike +`Base.unique`, returns a `Tuple`. For an `AbstractVector{<:VarName}` you can use +`Base.unique`. The point of `unique_syms` is that it supports constant propagating +the result, which is possible only when the argument and the return value are `Tuple`s. """ @generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} retval = Expr(:tuple) From f05068daba935ec974fbbe2b1418940b95b0ca20 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:48:24 +0000 Subject: [PATCH 14/29] Style improvements --- src/varinfo.jl | 53 +++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 97fb733e2..3f9d817b7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -900,6 +900,21 @@ end return :($(exprs...),) end +""" + all_varnames_namedtuple(vi::TypedVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) + +@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + # Get the index (in vals) ranges of all the vns of variables belonging to spl @inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} @@ -1194,17 +1209,17 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -""" - filter_subsumed(vns1, vns2) +function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) + return _link!(vi.metadata, vi, varname_namedtuple(vns)) +end -Return the subset of `vns2` that are subsumed by any variable in `vns1`. """ -function filter_subsumed(vns1, vns2) - return filter(x -> any(subsumes(y, x) for y in vns1), vns2) -end + filter_subsumed(filter_vns, filtered_vns) -function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - return _link!(vi.metadata, vi, varname_namedtuple(vns)) +Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`. +""" +function filter_subsumed(filter_vns, filtered_vns) + return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end @generated function _link!( @@ -1240,7 +1255,7 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _invlink! type stable. function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) _invlink!(vi, all_varnames_namedtuple(vi)) return vi @@ -1292,6 +1307,7 @@ function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) vns_namedtuple = varname_namedtuple(vns) return _invlink!(vi.metadata, vi, vns_namedtuple) end + @generated function _invlink!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} @@ -1368,7 +1384,7 @@ end # Specialise link without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _link type stable. function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _link(model, vi, all_varnames_namedtuple(vi)) end @@ -1384,21 +1400,6 @@ function _link( ) end -""" - all_varnames_namedtuple(vi::TypedVarInfo) - -Return a `NamedTuple` of the variables in `vi` grouped by symbol. -""" -all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) - -@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :($f = keys(md.$f))) - end - return expr -end - function _link( model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} ) @@ -1517,7 +1518,7 @@ end # Specialise invlink without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _invlink type stable. function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _invlink(model, vi, all_varnames_namedtuple(vi)) end From bc4c42093dafe01c0d7ed3984232e471a1bdcb65 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 14:52:15 +0000 Subject: [PATCH 15/29] Fix broken link/invlink dispatch cascade for VectorVarInfo --- src/varinfo.jl | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3f9d817b7..f03926051 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1170,14 +1170,18 @@ end # Specialise link!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _link! type stable. -function link!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) - _link!(vi, all_varnames_namedtuple(vi)) - return vi +# helps keep the downstream call to link!! type stable. +function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) + return link!!(t, vi, all_varnames_namedtuple(vi), model) end # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) +function link!!( + t::DynamicTransformation, + vi::VarInfo, + vns::Union{VarNameCollection,NamedTuple}, + model::Model, +) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return link(t, vi, vns, model) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1255,15 +1259,17 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _invlink! type stable. -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) - _invlink!(vi, all_varnames_namedtuple(vi)) - return vi +# helps keep the downstream call to invlink!! type stable. +function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) + return invlink!!(t, vi, all_varnames_namedtuple(vi), model) end # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model + t::DynamicTransformation, + vi::VarInfo, + vns::Union{VarNameCollection,NamedTuple}, + model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return invlink(t, vi, vns, model) From 71980baf556c86a2a335a8376b075e726de30f78 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 15:36:05 +0000 Subject: [PATCH 16/29] Fix some more broken dispatch cascades --- src/varinfo.jl | 51 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f03926051..8b835014d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1183,7 +1183,7 @@ function link!!( model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, vns, model) + has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, vns) return vi @@ -1213,8 +1213,14 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - return _link!(vi.metadata, vi, varname_namedtuple(vns)) +# If we try to _link! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _link!(vi::TypedVarInfo, vns::VarNameCollection) + return _link!(vi, varname_namedtuple(vns)) +end + +function _link!(vi::TypedVarInfo, vns::NamedTuple) + return _link!(vi.metadata, vi, vns) end """ @@ -1272,7 +1278,7 @@ function invlink!!( model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, vns, model) + has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. _invlink!(vi, vns) return vi @@ -1309,9 +1315,14 @@ function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) end end -function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - vns_namedtuple = varname_namedtuple(vns) - return _invlink!(vi.metadata, vi, vns_namedtuple) +# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) + return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) +end + +function _invlink!(vi::TypedVarInfo, vns::NamedTuple) + return _invlink!(vi.metadata, vi, vns) end @generated function _invlink!( @@ -1406,12 +1417,15 @@ function _link( ) end -function _link( - model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} -) +# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) + return _link(model, varinfo, varname_namedtuple(vns)) +end + +function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - vns_namedtuple = varname_namedtuple(vns) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -1538,12 +1552,15 @@ function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) ) end -function _invlink( - model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} -) +# If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) + return _invlink(model, varinfo, varname_namedtuple(vns)) +end + +function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - vns_namedtuple = varname_namedtuple(vns) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end From 45562a9cacca75439cb422b34e8bc7f011d02090 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 14:34:26 +0000 Subject: [PATCH 17/29] Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- src/abstract_varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index b4aed0458..a215bbd14 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -561,7 +561,7 @@ end function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? + # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -656,7 +656,7 @@ function link!!( # loudly. all_vns = Set(keys(vi)) if Set(vns) != all_vns - msg = "StaticTransforming only a subset of variables is not supported." + msg = "Statically transforming only a subset of variables is not supported." throw(ArgumentError(msg)) end b = inverse(t.bijector) From db5b8357316f6004e2e10a67ee11d2638e4cdfec Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 14:36:30 +0000 Subject: [PATCH 18/29] Remove comments that messed with docstrings --- src/abstract_varinfo.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index a215bbd14..891218fb6 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -550,7 +550,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end @@ -584,7 +583,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end @@ -619,7 +617,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end @@ -702,7 +699,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end From f99effe14ed5189e8b552984ff29b6cf5e56c6b6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 12:32:51 +0000 Subject: [PATCH 19/29] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/abstract_varinfo.jl | 4 ++-- src/utils.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 891218fb6..c8a2ff17b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -593,7 +593,7 @@ end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? + # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -674,7 +674,7 @@ function invlink!!( # TODO(mhauru) See comment in link!! above. all_vns = Set(keys(vi)) if Set(vns) != all_vns - msg = "StaticTransforming only a subset of variables is not supported." + msg = "Statically transforming only a subset of variables is not supported." throw(ArgumentError(msg)) end b = t.bijector diff --git a/src/utils.jl b/src/utils.jl index 16aa38e4a..307bf1f85 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1298,7 +1298,7 @@ end Return a `NamedTuple` of the variables in `vns` grouped by symbol. -`varname_namedtuple` is type table for inputs that are `Tuple`s, and for vectors when all +`varname_namedtuple` is type stable for inputs that are `Tuple`s, and for vectors when all `VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. Example: From 56194cd000636bdde0710011f250795430971667 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 13:09:01 +0000 Subject: [PATCH 20/29] Fix issues surfaced in code review --- docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 2 ++ src/threadsafe.jl | 4 ++-- src/transforming.jl | 2 -- src/utils.jl | 2 -- src/varinfo.jl | 1 - test/varinfo.jl | 11 +++++++++++ 8 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 093cb06a6..36dd24250 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -304,7 +304,7 @@ set_num_produce! increment_num_produce! reset_num_produce! setorder! -set_retained_vns_del_by_spl! +set_retained_vns_del! ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1cdbd94e..55e1f7e88 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -59,7 +59,7 @@ export AbstractVarInfo, set_num_produce!, reset_num_produce!, increment_num_produce!, - set_retained_vns_del_by_spl!, + set_retained_vns_del!, is_flagged, set_flag!, unset_flag!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c8a2ff17b..c59e2990c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -561,6 +561,7 @@ function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? + # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -594,6 +595,7 @@ function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? + # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 25aa0d654..fae0c1613 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -182,8 +182,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del_by_spl!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler) + return set_retained_vns_del!(vi.varinfo, spl) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) diff --git a/src/transforming.jl b/src/transforming.jl index 46b42d8ed..f3f4fbba0 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,8 +91,6 @@ function dot_tilde_assume( return r, lp, vi end -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - function link!!( t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) diff --git a/src/utils.jl b/src/utils.jl index 307bf1f85..0bf9d6d3d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1331,5 +1331,3 @@ end function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} return NamedTuple{(T,)}((vns,)) end - -varname_namedtuple(vns::NamedTuple) = vns diff --git a/src/varinfo.jl b/src/varinfo.jl index 8b835014d..c49a6ffc3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -2073,7 +2073,6 @@ function set_retained_vns_del!(vi::UntypedVarInfo) return nothing end function set_retained_vns_del!(vi::TypedVarInfo) - # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end diff --git a/test/varinfo.jl b/test/varinfo.jl index fd1c9a2e9..99d319425 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -825,6 +825,17 @@ end end end + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] + @test merge(vi_double, vi_single)[vn] == 1.0 + end + @testset "sampling from linked varinfo" begin # `~` @model function demo(n=1) From c187c49152a619ec663961c90102509fbd8482ae Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:16:11 +0000 Subject: [PATCH 21/29] Simplify link/invlink arguments --- src/abstract_varinfo.jl | 146 ++++++++++------------------------------ src/simple_varinfo.jl | 6 +- src/threadsafe.jl | 44 ++++-------- src/transforming.jl | 20 ++---- src/utils.jl | 2 +- src/varinfo.jl | 132 ++++++++++++++++++++++-------------- test/varinfo.jl | 4 +- 7 files changed, 145 insertions(+), 209 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c59e2990c..c7afc67a5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,127 +537,77 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback +# method for the case when no `vns` is provided, that would get all the keys from the +# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case +# where `vns` is provided and the one where it is not. This is because having separate +# implementations is typically much more performant, and because not all AbstractVarInfo +# types support partial linking. + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space, mutating `vi` if possible. -Transform the variables in `vi` to their linked space, using the transformation `t`, -mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns, model::Model) +function link!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, link all variables. -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? - # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return link!!(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link!!(t, vi, (vn,), model) -end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space without mutating `vi`. -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns, model::Model) +function link(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return link(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, link all variables. -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? - # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return link(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link(t, vi, (vn,), model) -end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space, mutating `vi` if possible. -Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`, mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns, model::Model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, invlink!! all variables. -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink!!(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function invlink!!( - t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model -) - return invlink!!(t, vi, (vn,), model) -end # Vector-based ones. function link!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - vns::VarNameCollection, - ::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) - # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a - # bit mixed. For TypedVarInfo you could transform only a subset of the variables, but - # for UntypedVarInfo and SimpleVarInfo it was silently assumed that all variables were - # being set. Unsure if we should support this or not, but at least it now errors - # loudly. - all_vns = Set(keys(vi)) - if Set(vns) != all_vns - msg = "Statically transforming only a subset of variables is not supported." - throw(ArgumentError(msg)) - end b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) @@ -668,17 +618,8 @@ function link!!( end function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - vns::VarNameCollection, - ::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) - # TODO(mhauru) See comment in link!! above. - all_vns = Set(keys(vi)) - if Set(vns) != all_vns - msg = "Statically transforming only a subset of variables is not supported." - throw(ArgumentError(msg)) - end b = t.bijector y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) @@ -690,36 +631,23 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space without mutating `vi`. -Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) -transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link`](@ref). """ function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns, model::Model) +function invlink(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, invlink all variables. -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return invlink(t, vi, (vn,), model) -end """ maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f60c0b0fb..57b167077 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::VarNameCollection, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) @@ -695,8 +694,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::VarNameCollection, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = t.bijector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index fae0c1613..bf4817fbd 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,59 +81,43 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, vns, model) +function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) end -function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, vns, model) +function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) end -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. -function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model -) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return link!!(t, deepcopy(vi), vns, model) +function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return invlink!!(t, deepcopy(vi), vns, model) +function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index f3f4fbba0..1a26d212f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,29 +91,21 @@ function dot_tilde_assume( return r, lp, vi end -function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model -) +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model -) - return link!!(t, deepcopy(vi), vns, model) +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model -) - return invlink!!(t, deepcopy(vi), vns, model) +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end diff --git a/src/utils.jl b/src/utils.jl index 0bf9d6d3d..265fa773b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ struct NoDefault end const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} +VarNameCollection = NTuple{N,VarName} where {N} """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index c49a6ffc3..cdf67b019 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1168,20 +1168,30 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -# Specialise link!! without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to link!! type stable. -function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return link!!(t, vi, all_varnames_namedtuple(vi), model) +function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_namedtuple(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) end # X -> R for all variables associated with given sampler -function link!!( - t::DynamicTransformation, - vi::VarInfo, - vns::Union{VarNameCollection,NamedTuple}, - model::Model, -) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1195,12 +1205,12 @@ function link!!( vns::VarNameCollection, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, vns::VarNameCollection) +function _link!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns @@ -1213,7 +1223,7 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -# If we try to _link! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameCollection) return _link!(vi, varname_namedtuple(vns)) @@ -1263,19 +1273,32 @@ end return expr end -# Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to invlink!! type stable. -function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_namedtuple(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, vns) + return vi +end + +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + _invlink!(vi, vns) + return vi +end + +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) end # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, - vi::VarInfo, - vns::Union{VarNameCollection,NamedTuple}, - model::Model, + ::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1303,7 +1326,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1315,7 +1338,7 @@ function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) end end -# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) @@ -1382,6 +1405,20 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_namedtuple(vi)) +end + +function link(::DynamicTransformation, varinfo::VarInfo, model::Model) + return _link(model, varinfo, keys(varinfo)) +end + +function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) +end + function link( ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) @@ -1399,22 +1436,10 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end -# Specialise link without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _link type stable. -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _link(model, vi, all_varnames_namedtuple(vi)) -end - -function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection -) +function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert @@ -1519,6 +1544,22 @@ function _link_metadata!!( return metadata end +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_namedtuple(vi)) +end + +function invlink(::DynamicTransformation, vi::VarInfo, model::Model) + return _invlink(model, vi, keys(vi)) +end + +function invlink( + ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) +end + function invlink( ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) @@ -1536,14 +1577,7 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -# Specialise invlink without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _invlink type stable. -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _invlink(model, vi, all_varnames_namedtuple(vi)) -end - -function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) +function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) return VarInfo( _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), diff --git a/test/varinfo.jl b/test/varinfo.jl index 99d319425..d689a1bf4 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -437,10 +437,10 @@ end other_vns = filter(x -> !subsumes(vn, x), all_vns) @test !isempty(target_vns) @test !isempty(other_vns) - vi = link!!(vi, vn, model) + vi = link!!(vi, (vn,), model) @test all(x -> istrans(vi, x), target_vns) @test all(x -> !istrans(vi, x), other_vns) - vi = invlink!!(vi, vn, model) + vi = invlink!!(vi, (vn,), model) @test all(x -> !istrans(vi, x), all_vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 From 86b25c5ae3ab2d7e4f15816b76d64c9debad9c7f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:16:27 +0000 Subject: [PATCH 22/29] Fix a bug in unflatten VarNamedVector --- src/varnamedvector.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index b324e9134..14ef6ce6a 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1066,7 +1066,12 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) new_ranges = deepcopy(vnv.ranges) recontiguify_ranges!(new_ranges) return VarNamedVector( - vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + vnv.varname_to_index, + vnv.varnames, + new_ranges, + vals, + vnv.transforms, + vnv.is_unconstrained, ) end From 2a6c1bcef4d14c38bb1ce5c07c868e325ba92aae Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:20:54 +0000 Subject: [PATCH 23/29] Rename VarNameCollection -> VarNameTuple --- src/abstract_varinfo.jl | 8 ++++---- src/utils.jl | 2 +- src/varinfo.jl | 30 ++++++++++++------------------ 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c7afc67a5..087affd90 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -560,7 +560,7 @@ See also: [`default_transformation`](@ref), [`invlink!!`](@ref). function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end @@ -580,7 +580,7 @@ See also: [`default_transformation`](@ref), [`invlink`](@ref). function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end @@ -600,7 +600,7 @@ See also: [`default_transformation`](@ref), [`link!!`](@ref). function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end @@ -645,7 +645,7 @@ See also: [`default_transformation`](@ref), [`link`](@ref). function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end diff --git a/src/utils.jl b/src/utils.jl index 265fa773b..de7fe5925 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ struct NoDefault end const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameCollection = NTuple{N,VarName} where {N} +VarNameTuple = NTuple{N,VarName} where {N} """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index cdf67b019..887f132c6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1191,7 +1191,7 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, mode end # X -> R for all variables associated with given sampler -function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1202,7 +1202,7 @@ end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, @@ -1225,7 +1225,7 @@ end # If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameCollection) +function _link!(vi::TypedVarInfo, vns::VarNameTuple) return _link!(vi, varname_namedtuple(vns)) end @@ -1297,9 +1297,7 @@ function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, m end # R -> X for all variables associated with given sampler -function invlink!!( - ::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. @@ -1310,7 +1308,7 @@ end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1340,7 +1338,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) end @@ -1419,16 +1417,14 @@ function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, mo return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) end -function link( - ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model -) +function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1444,7 +1440,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _link(model, varinfo, varname_namedtuple(vns)) end @@ -1560,16 +1556,14 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) end -function invlink( - ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model -) +function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _invlink(model, varinfo, vns) end function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1588,7 +1582,7 @@ end # If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, varname_namedtuple(vns)) end From 853f47e683428d02d1d87641a5bf2687d611f94c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:35:13 +0000 Subject: [PATCH 24/29] Remove test of a removed varname_namedtuple method --- test/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index af7b3ee4d..cdb2af4f7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -69,11 +69,9 @@ ) vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] @inferred DynamicPPL.varname_namedtuple(vns_tuple) - @inferred DynamicPPL.varname_namedtuple(vns_nt) @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_nt) == vns_nt @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == (; x=vns_vec_single_symbol) end From ed803281da7c4c851be2b86f5a1708ac183a714b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 16:50:00 +0000 Subject: [PATCH 25/29] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/abstract_varinfo.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 087affd90..26785a387 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -546,8 +546,7 @@ function settrans!! end """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their linked space, mutating `vi` if possible. @@ -566,8 +565,7 @@ end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their linked space without mutating `vi`. @@ -586,7 +584,7 @@ end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their constrained space, mutating `vi` if possible. @@ -631,7 +629,7 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their constrained space without mutating `vi`. From d996d0cb45e4d959ca4ac17a329f7bd11de28954 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:15:55 +0000 Subject: [PATCH 26/29] Respond to review feedback --- src/utils.jl | 34 +++++++++------------------------- src/varinfo.jl | 23 ++++++++++++----------- test/utils.jl | 10 +++------- 3 files changed, 24 insertions(+), 43 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index de7fe5925..2539b7179 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1278,11 +1278,11 @@ _merge(left::NamedTuple{()}, right::AbstractDict) = right Return the unique symbols of the variables in `vns`. Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike -`Base.unique`, returns a `Tuple`. For an `AbstractVector{<:VarName}` you can use -`Base.unique`. The point of `unique_syms` is that it supports constant propagating -the result, which is possible only when the argument and the return value are `Tuple`s. +`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant +propagating the result, which is possible only when the argument and the return value are +`Tuple`s. """ -@generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} +@generated function unique_syms(::T) where {T<:VarNameTuple} retval = Expr(:tuple) syms = [first(vn.parameters) for vn in T.parameters] for sym in unique(syms) @@ -1292,14 +1292,12 @@ the result, which is possible only when the argument and the return value are `T end """ - varname_namedtuple(vns::NTuple{N,VarName}) where {N} - varname_namedtuple(vns::AbstractVector{<:VarName}) - varname_namedtuple(vns::NamedTuple) + group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} Return a `NamedTuple` of the variables in `vns` grouped by symbol. -`varname_namedtuple` is type stable for inputs that are `Tuple`s, and for vectors when all -`VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. +Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to +be type stable. Example: ```julia @@ -1309,25 +1307,11 @@ julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) (x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) -julia> varname_namedtuple(vns_tuple) == vns_nt +julia> group_varnames_by_symbol(vns_tuple) == vns_nt ``` """ -function varname_namedtuple(vns::NTuple{N,VarName} where {N}) +function group_varnames_by_symbol(vns::VarNameTuple) syms = unique_syms(vns) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end - -# This method is type unstable, but that can't be helped: The problem is inherently type -# unstable if there are VarNames with multiple symbols in a Vector. -function varname_namedtuple(vns::AbstractVector{<:VarName}) - syms = tuple(unique(map(getsym, vns))...) - elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) - return NamedTuple{syms}(elements) -end - -# A simpler, type stable implementation when all the VarNames in a Vector have the same -# symbol. -function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} - return NamedTuple{(T,)}((vns,)) -end diff --git a/src/varinfo.jl b/src/varinfo.jl index 887f132c6..8836962e5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -901,13 +901,14 @@ end end """ - all_varnames_namedtuple(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::TypedVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) +all_varnames_grouped_by_symbol(vi::TypedVarInfo) = + all_varnames_grouped_by_symbol(vi.metadata) -@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} +@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names push!(expr.args, :($f = keys(md.$f))) @@ -1169,7 +1170,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) end function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) - vns = all_varnames_namedtuple(vi) + vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) _link!(vi, vns) @@ -1226,7 +1227,7 @@ end # If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameTuple) - return _link!(vi, varname_namedtuple(vns)) + return _link!(vi, group_varnames_by_symbol(vns)) end function _link!(vi::TypedVarInfo, vns::NamedTuple) @@ -1274,7 +1275,7 @@ end end function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) - vns = all_varnames_namedtuple(vi) + vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. @@ -1339,7 +1340,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) + return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end function _invlink!(vi::TypedVarInfo, vns::NamedTuple) @@ -1404,7 +1405,7 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) end function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _link(model, vi, all_varnames_namedtuple(vi)) + return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end function link(::DynamicTransformation, varinfo::VarInfo, model::Model) @@ -1441,7 +1442,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) - return _link(model, varinfo, varname_namedtuple(vns)) + return _link(model, varinfo, group_varnames_by_symbol(vns)) end function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) @@ -1541,7 +1542,7 @@ function _link_metadata!!( end function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _invlink(model, vi, all_varnames_namedtuple(vi)) + return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end function invlink(::DynamicTransformation, vi::VarInfo, model::Model) @@ -1583,7 +1584,7 @@ end # If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) - return _invlink(model, varinfo, varname_namedtuple(vns)) + return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) diff --git a/test/utils.jl b/test/utils.jl index cdb2af4f7..d683f132d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -57,7 +57,7 @@ @test DynamicPPL.unique_syms(()) == () end - @testset "varname_namedtuple" begin + @testset "group_varnames_by_symbol" begin vns_tuple = ( @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) ) @@ -68,11 +68,7 @@ z=[@varname(z[15])], ) vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] - @inferred DynamicPPL.varname_namedtuple(vns_tuple) - @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) - @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == - (; x=vns_vec_single_symbol) + @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) + @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt end end From 20831485880a7503792f78794b3d4f9751f42c84 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:19:05 +0000 Subject: [PATCH 27/29] Remove _default_sampler and a dead argument of maybe_invlink_before_eval --- src/abstract_varinfo.jl | 28 ++++++++-------------------- src/model.jl | 8 ++++---- src/threadsafe.jl | 13 +++++-------- src/varinfo.jl | 4 ++-- test/simple_varinfo.jl | 4 +--- 5 files changed, 20 insertions(+), 37 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26785a387..26c4268d8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -648,7 +648,7 @@ function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) end """ - maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + maybe_invlink_before_eval!!([t::Transformation,] vi, model) Return a possibly invlinked version of `vi`. @@ -699,37 +699,25 @@ julia> # Now performs a single `invlink!!` before model evaluation. -1001.4189385332047 ``` """ -function maybe_invlink_before_eval!!( - vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model) + return maybe_invlink_before_eval!!(transformation(vi), vi, model) end -function maybe_invlink_before_eval!!( - ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model) return vi end function maybe_invlink_before_eval!!( - ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, model::Model ) - # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we + # do nothing. return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, model::Model ) - # TODO(mhauru) Why does this function need the context argument? return invlink!!(t, vi, model) end -function _default_sampler(context::AbstractContext) - return _default_sampler(NodeTrait(_default_sampler, context), context) -end -_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() -function _default_sampler(::IsParent, context::AbstractContext) - return _default_sampler(childcontext(context)) -end - # Utilities """ unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector) diff --git a/src/model.jl b/src/model.jl index 6fb0b40b0..462db7397 100644 --- a/src/model.jl +++ b/src/model.jl @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # lazy `invlink`-ing of the parameters. This can be useful for # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. - maybe_invlink_before_eval!!(varinfo, context_new, model), + maybe_invlink_before_eval!!(varinfo, model), context_new, $(unwrap_args...), ) @@ -1169,10 +1169,10 @@ end """ predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches +Generate samples from the posterior predictive distribution by evaluating `model` at each set +of parameter values provided in `chain`. The number of posterior predictive samples matches the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. +and the predicted values. """ function predict( rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bf4817fbd..b4403c46f 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -120,15 +120,12 @@ function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end -function maybe_invlink_before_eval!!( - vi::ThreadSafeVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` - # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( - vi.varinfo, context, model - ) + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the + # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogp(vi)`. + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end # `getindex` diff --git a/src/varinfo.jl b/src/varinfo.jl index 8836962e5..9516745f2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1317,12 +1317,12 @@ function invlink!!( return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end -function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) +function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do # is just assume that `default_transformation` is the correct one if `istrans(vi)`. t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() - return maybe_invlink_before_eval!!(t, vi, context, model) + return maybe_invlink_before_eval!!(t, vi, model) end function _invlink!(vi::UntypedVarInfo, vns) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4343563eb..137c791c2 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -275,9 +275,7 @@ # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. @test !DynamicPPL.istrans( - DynamicPPL.maybe_invlink_before_eval!!( - deepcopy(vi), SamplingContext(), model - ), + DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. From 39fa6476ad2a3c76b65efecf0c20d04d2cfd4401 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:22:33 +0000 Subject: [PATCH 28/29] Fix a typo in a comment --- src/varinfo.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9516745f2..09f5960c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1224,8 +1224,8 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end @@ -1337,8 +1337,8 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end @@ -1428,8 +1428,8 @@ function link( vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end @@ -1439,8 +1439,8 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end @@ -1581,8 +1581,8 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end From 2c73de570b8a9f6d8d9dcf9f8a5f5454b3f12e73 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 11:58:39 +0000 Subject: [PATCH 29/29] Add HISTORY entry, fix one set_retained_vns_del! method --- HISTORY.md | 9 +++++++++ src/threadsafe.jl | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f77d3fa74..eea7435c9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,15 @@ **Breaking** +### Remove indexing by samplers + +This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, + + - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. + - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + +### Reverse prefixing order + - For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. Previously, the order was that outer prefixes were applied first, then inner ones. This version reverses that. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index b4403c46f..69be5dcb1 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -163,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo) + return set_retained_vns_del!(vi.varinfo) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)