From 727da635d290c22bc978dd09febe229bb8e7c906 Mon Sep 17 00:00:00 2001
From: Markus Hauru <markus@mhauru.org>
Date: Wed, 22 Jan 2025 11:58:49 +0000
Subject: [PATCH] Fix `merge_metadata` for cases where the dimension of the
 variable changes (#781)

* Add test merging VarInfos with different dimensions for a variable

* Fix merge_metadata for differing dimensions

* Bump patch version to 0.34.1.

* Fix test

* Fix test more

* Pin KernelAbstractions to v0.9.31

* Make KernelAbstractions version bound an upper bound

Co-authored-by: Penelope Yong <penelopeysm@gmail.com>

* Fix syntax

---------

Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
---
 Project.toml    |  8 ++++-
 src/varinfo.jl  | 79 +++++++++----------------------------------------
 test/varinfo.jl | 11 +++++++
 3 files changed, 32 insertions(+), 66 deletions(-)

diff --git a/Project.toml b/Project.toml
index fb9a1c55f..bd553c0cc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.34.0"
+version = "0.34.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
+# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -55,6 +58,9 @@ Compat = "4"
 ConstructionBase = "1.5.4"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
+# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
+# for why KernelAbstractions is pinned like this.
+KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 3ebb505e0..3f36cc293 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
     offset = 0
 
     for (idx, vn) in enumerate(vns_both)
-        # `idcs`
         idcs[vn] = idx
-        # `vns`
         push!(vns, vn)
-        if vn in vns_left && vn in vns_right
-            # `vals`: only valid if they're the length.
-            vals_left = getindex_internal(metadata_left, vn)
-            vals_right = getindex_internal(metadata_right, vn)
-            @assert length(vals_left) == length(vals_right)
-            append!(vals, vals_right)
-            # `ranges`
-            r = (offset + 1):(offset + length(vals_left))
-            push!(ranges, r)
-            offset = r[end]
-            # `dists`: only valid if they're the same.
-            dist_right = getdist(metadata_right, vn)
-            # Give precedence to `metadata_right`.
-            push!(dists, dist_right)
-            gid = metadata_right.gids[getidx(metadata_right, vn)]
-            push!(gids, gid)
-            # `orders`: giving precedence to `metadata_right`
-            push!(orders, getorder(metadata_right, vn))
-            # `flags`
-            for k in keys(flags)
-                # Using `metadata_right`; should we?
-                push!(flags[k], is_flagged(metadata_right, vn, k))
-            end
-        elseif vn in vns_left
-            # Just extract the metadata from `metadata_left`.
-            # `vals`
-            vals_left = getindex_internal(metadata_left, vn)
-            append!(vals, vals_left)
-            # `ranges`
-            r = (offset + 1):(offset + length(vals_left))
-            push!(ranges, r)
-            offset = r[end]
-            # `dists`
-            dist_left = getdist(metadata_left, vn)
-            push!(dists, dist_left)
-            gid = metadata_left.gids[getidx(metadata_left, vn)]
-            push!(gids, gid)
-            # `orders`
-            push!(orders, getorder(metadata_left, vn))
-            # `flags`
-            for k in keys(flags)
-                push!(flags[k], is_flagged(metadata_left, vn, k))
-            end
-        else
-            # Just extract the metadata from `metadata_right`.
-            # `vals`
-            vals_right = getindex_internal(metadata_right, vn)
-            append!(vals, vals_right)
-            # `ranges`
-            r = (offset + 1):(offset + length(vals_right))
-            push!(ranges, r)
-            offset = r[end]
-            # `dists`
-            dist_right = getdist(metadata_right, vn)
-            push!(dists, dist_right)
-            gid = metadata_right.gids[getidx(metadata_right, vn)]
-            push!(gids, gid)
-            # `orders`
-            push!(orders, getorder(metadata_right, vn))
-            # `flags`
-            for k in keys(flags)
-                push!(flags[k], is_flagged(metadata_right, vn, k))
-            end
+        metadata_for_vn = vn in vns_right ? metadata_right : metadata_left
+
+        val = getindex_internal(metadata_for_vn, vn)
+        append!(vals, val)
+        r = (offset + 1):(offset + length(val))
+        push!(ranges, r)
+        offset = r[end]
+        dist = getdist(metadata_for_vn, vn)
+        push!(dists, dist)
+        gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
+        push!(gids, gid)
+        push!(orders, getorder(metadata_for_vn, vn))
+        for k in keys(flags)
+            push!(flags[k], is_flagged(metadata_for_vn, vn, k))
         end
     end
 
diff --git a/test/varinfo.jl b/test/varinfo.jl
index 9a55cffb9..fce87b2f3 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -869,6 +869,17 @@ end
             @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
             @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
         end
+
+        # The below used to error, testing to avoid regression.
+        @testset "merge different dimensions" begin
+            vn = @varname(x)
+            vi_single = VarInfo()
+            vi_single = push!!(vi_single, vn, 1.0, Normal())
+            vi_double = VarInfo()
+            vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
+            @test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
+            @test merge(vi_double, vi_single)[vn] == 1.0
+        end
     end
 
     @testset "VarInfo with selectors" begin