From 727da635d290c22bc978dd09febe229bb8e7c906 Mon Sep 17 00:00:00 2001 From: Markus Hauru <markus@mhauru.org> Date: Wed, 22 Jan 2025 11:58:49 +0000 Subject: [PATCH] Fix `merge_metadata` for cases where the dimension of the variable changes (#781) * Add test merging VarInfos with different dimensions for a variable * Fix merge_metadata for differing dimensions * Bump patch version to 0.34.1. * Fix test * Fix test more * Pin KernelAbstractions to v0.9.31 * Make KernelAbstractions version bound an upper bound Co-authored-by: Penelope Yong <penelopeysm@gmail.com> * Fix syntax --------- Co-authored-by: Penelope Yong <penelopeysm@gmail.com> --- Project.toml | 8 ++++- src/varinfo.jl | 79 +++++++++---------------------------------------- test/varinfo.jl | 11 +++++++ 3 files changed, 32 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index fb9a1c55f..bd553c0cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.34.0" +version = "0.34.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see +# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -55,6 +58,9 @@ Compat = "4" ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" +# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 +# for why KernelAbstractions is pinned like this. +KernelAbstractions = "< 0.9.32" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" JET = "0.9" diff --git a/src/varinfo.jl b/src/varinfo.jl index 3ebb505e0..3f36cc293 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = 0 for (idx, vn) in enumerate(vns_both) - # `idcs` idcs[vn] = idx - # `vns` push!(vns, vn) - if vn in vns_left && vn in vns_right - # `vals`: only valid if they're the length. - vals_left = getindex_internal(metadata_left, vn) - vals_right = getindex_internal(metadata_right, vn) - @assert length(vals_left) == length(vals_right) - append!(vals, vals_right) - # `ranges` - r = (offset + 1):(offset + length(vals_left)) - push!(ranges, r) - offset = r[end] - # `dists`: only valid if they're the same. - dist_right = getdist(metadata_right, vn) - # Give precedence to `metadata_right`. - push!(dists, dist_right) - gid = metadata_right.gids[getidx(metadata_right, vn)] - push!(gids, gid) - # `orders`: giving precedence to `metadata_right` - push!(orders, getorder(metadata_right, vn)) - # `flags` - for k in keys(flags) - # Using `metadata_right`; should we? - push!(flags[k], is_flagged(metadata_right, vn, k)) - end - elseif vn in vns_left - # Just extract the metadata from `metadata_left`. - # `vals` - vals_left = getindex_internal(metadata_left, vn) - append!(vals, vals_left) - # `ranges` - r = (offset + 1):(offset + length(vals_left)) - push!(ranges, r) - offset = r[end] - # `dists` - dist_left = getdist(metadata_left, vn) - push!(dists, dist_left) - gid = metadata_left.gids[getidx(metadata_left, vn)] - push!(gids, gid) - # `orders` - push!(orders, getorder(metadata_left, vn)) - # `flags` - for k in keys(flags) - push!(flags[k], is_flagged(metadata_left, vn, k)) - end - else - # Just extract the metadata from `metadata_right`. - # `vals` - vals_right = getindex_internal(metadata_right, vn) - append!(vals, vals_right) - # `ranges` - r = (offset + 1):(offset + length(vals_right)) - push!(ranges, r) - offset = r[end] - # `dists` - dist_right = getdist(metadata_right, vn) - push!(dists, dist_right) - gid = metadata_right.gids[getidx(metadata_right, vn)] - push!(gids, gid) - # `orders` - push!(orders, getorder(metadata_right, vn)) - # `flags` - for k in keys(flags) - push!(flags[k], is_flagged(metadata_right, vn, k)) - end + metadata_for_vn = vn in vns_right ? metadata_right : metadata_left + + val = getindex_internal(metadata_for_vn, vn) + append!(vals, val) + r = (offset + 1):(offset + length(val)) + push!(ranges, r) + offset = r[end] + dist = getdist(metadata_for_vn, vn) + push!(dists, dist) + gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)] + push!(gids, gid) + push!(orders, getorder(metadata_for_vn, vn)) + for k in keys(flags) + push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 9a55cffb9..fce87b2f3 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -869,6 +869,17 @@ end @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right end + + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] + @test merge(vi_double, vi_single)[vn] == 1.0 + end end @testset "VarInfo with selectors" begin