Skip to content

Commit

Permalink
Fix merge_metadata for cases where the dimension of the variable ch…
Browse files Browse the repository at this point in the history
…anges (#781)

* Add test merging VarInfos with different dimensions for a variable

* Fix merge_metadata for differing dimensions

* Bump patch version to 0.34.1.

* Fix test

* Fix test more

* Pin KernelAbstractions to v0.9.31

* Make KernelAbstractions version bound an upper bound

Co-authored-by: Penelope Yong <[email protected]>

* Fix syntax

---------

Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
mhauru and penelopeysm authored Jan 22, 2025
1 parent 938a69d commit 727da63
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 66 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.34.0"
version = "0.34.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -55,6 +58,9 @@ Compat = "4"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "< 0.9.32"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
Expand Down
79 changes: 14 additions & 65 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
offset = 0

for (idx, vn) in enumerate(vns_both)
# `idcs`
idcs[vn] = idx
# `vns`
push!(vns, vn)
if vn in vns_left && vn in vns_right
# `vals`: only valid if they're the length.
vals_left = getindex_internal(metadata_left, vn)
vals_right = getindex_internal(metadata_right, vn)
@assert length(vals_left) == length(vals_right)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`: only valid if they're the same.
dist_right = getdist(metadata_right, vn)
# Give precedence to `metadata_right`.
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
# Using `metadata_right`; should we?
push!(flags[k], is_flagged(metadata_right, vn, k))
end
elseif vn in vns_left
# Just extract the metadata from `metadata_left`.
# `vals`
vals_left = getindex_internal(metadata_left, vn)
append!(vals, vals_left)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`
dist_left = getdist(metadata_left, vn)
push!(dists, dist_left)
gid = metadata_left.gids[getidx(metadata_left, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_left, vn, k))
end
else
# Just extract the metadata from `metadata_right`.
# `vals`
vals_right = getindex_internal(metadata_right, vn)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_right))
push!(ranges, r)
offset = r[end]
# `dists`
dist_right = getdist(metadata_right, vn)
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_right, vn, k))
end
metadata_for_vn = vn in vns_right ? metadata_right : metadata_left

val = getindex_internal(metadata_for_vn, vn)
append!(vals, val)
r = (offset + 1):(offset + length(val))
push!(ranges, r)
offset = r[end]
dist = getdist(metadata_for_vn, vn)
push!(dists, dist)
gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
push!(gids, gid)
push!(orders, getorder(metadata_for_vn, vn))
for k in keys(flags)
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,17 @@ end
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
end

# The below used to error, testing to avoid regression.
@testset "merge different dimensions" begin
vn = @varname(x)
vi_single = VarInfo()
vi_single = push!!(vi_single, vn, 1.0, Normal())
vi_double = VarInfo()
vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
@test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
@test merge(vi_double, vi_single)[vn] == 1.0
end
end

@testset "VarInfo with selectors" begin
Expand Down

2 comments on commit 727da63

@mhauru
Copy link
Member Author

@mhauru mhauru commented on 727da63 Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:
Fix an issue that prevented merging two VarInfos if they had different dimensions for a variable.

Upper bound the compat version of KernelAbstractions to work around an issue in determining the right VarInfo type to use.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/123483

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.1 -m "<description of version>" 727da635d290c22bc978dd09febe229bb8e7c906
git push origin v0.34.1

Please sign in to comment.