From bd4baf154e29168d26ec5b04b1b29a30db255f09 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 13:40:41 +0100 Subject: [PATCH] Fix treatment of gid in merge(::Metadata) --- src/varinfo.jl | 8 +++++++- test/varinfo.jl | 13 +++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8727796bc..4b229d828 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -490,7 +490,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -520,6 +520,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) dist_right = getdist(metadata_right, vn) # Give precedence to `metadata_right`. push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -539,6 +541,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_left = getdist(metadata_left, vn) push!(dists, dist_left) + gid = metadata_left.gids[getidx(metadata_left, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -557,6 +561,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_right = getdist(metadata_right, vn) push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` diff --git a/test/varinfo.jl b/test/varinfo.jl index 65f849dda..88439425a 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -694,6 +694,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end + + # The below used to error, testing to avoid regression. + @testset "merge gids" begin + gidset_left = Set([Selector(1)]) + vi_left = VarInfo() + vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) + gidset_right = Set([Selector(2)]) + vi_right = VarInfo() + vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) + varinfo_merged = merge(vi_left, vi_right) + @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left + @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right + end end @testset "VarInfo with selectors" begin