Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove selector stuff from VarInfo tests and link/invlink #780

Merged
merged 33 commits into from
Jan 30, 2025
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4dc2a72
Remove selector stuff from varinfo tests
mhauru Jan 16, 2025
9b492a3
Implement link and invlink for varnames rather than samplers
mhauru Jan 16, 2025
b508f08
Replace set_retained_vns_del_by_spl! with set_retained_vns_del!
mhauru Jan 16, 2025
b8880d1
Make linking tests more extensive
mhauru Jan 16, 2025
99a8490
Remove sampler indexing from link methods (but not invlink)
mhauru Jan 22, 2025
4a79b1f
Remove indexing by samplers from invlink
mhauru Jan 22, 2025
26a1901
Merge remote-tracking branch 'origin/master' into mhauru/remove-selec…
mhauru Jan 22, 2025
090608b
Work towards removing sampler indexing with StaticTransformation
mhauru Jan 22, 2025
4749853
Fix invlink/link for TypedVarInfo and StaticTransformation
mhauru Jan 23, 2025
e960679
Fix a test in models.jl
mhauru Jan 23, 2025
d507a53
Move some functions to utils.jl, add tests and docstrings
mhauru Jan 23, 2025
41150b5
Fix a docstring typo
mhauru Jan 23, 2025
836fb13
Merge branch 'release-0.35' into mhauru/remove-selectors-linking
mhauru Jan 23, 2025
45d1f13
Various simplification to link/invlink
mhauru Jan 23, 2025
98915c2
Improve a docstring
mhauru Jan 23, 2025
f05068d
Style improvements
mhauru Jan 23, 2025
bc4c420
Fix broken link/invlink dispatch cascade for VectorVarInfo
mhauru Jan 23, 2025
71980ba
Fix some more broken dispatch cascades
mhauru Jan 23, 2025
45562a9
Apply suggestions from code review
mhauru Jan 24, 2025
db5b835
Remove comments that messed with docstrings
mhauru Jan 24, 2025
f99effe
Apply suggestions from code review
mhauru Jan 28, 2025
56194cd
Fix issues surfaced in code review
mhauru Jan 28, 2025
c187c49
Simplify link/invlink arguments
mhauru Jan 28, 2025
86b25c5
Fix a bug in unflatten VarNamedVector
mhauru Jan 28, 2025
2a6c1bc
Rename VarNameCollection -> VarNameTuple
mhauru Jan 28, 2025
853f47e
Remove test of a removed varname_namedtuple method
mhauru Jan 28, 2025
ed80328
Apply suggestions from code review
mhauru Jan 29, 2025
d996d0c
Respond to review feedback
mhauru Jan 29, 2025
2083148
Remove _default_sampler and a dead argument of maybe_invlink_before_eval
mhauru Jan 29, 2025
39fa647
Fix a typo in a comment
mhauru Jan 29, 2025
9df364f
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
2c73de5
Add HISTORY entry, fix one set_retained_vns_del! method
mhauru Jan 30, 2025
49604e1
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Apply suggestions from code review
Co-authored-by: Penelope Yong <[email protected]>
mhauru and penelopeysm authored Jan 28, 2025
commit f99effe14ed5189e8b552984ff29b6cf5e56c6b6
4 changes: 2 additions & 2 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
@@ -593,15 +593,15 @@
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
vns = collect(keys(vi))
# In case e.g. vns = Any[].
# TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]?
# TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]?
if !(eltype(vns) <: VarName)
vns = collect(VarName, vns)
end
return link(t, vi, vns, model)
end
# Wrap a single VarName in a singleton tuple.
mhauru marked this conversation as resolved.
Show resolved Hide resolved
function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model)
return link(t, vi, (vn,), model)

Check warning on line 604 in src/abstract_varinfo.jl

Codecov / codecov/patch

src/abstract_varinfo.jl#L603-L604

Added lines #L603 - L604 were not covered by tests
end

"""
@@ -653,8 +653,8 @@
# loudly.
all_vns = Set(keys(vi))
if Set(vns) != all_vns
msg = "Statically transforming only a subset of variables is not supported."
throw(ArgumentError(msg))

Check warning on line 657 in src/abstract_varinfo.jl

Codecov / codecov/patch

src/abstract_varinfo.jl#L656-L657

Added lines #L656 - L657 were not covered by tests
end
b = inverse(t.bijector)
x = vi[:]
@@ -674,8 +674,8 @@
# TODO(mhauru) See comment in link!! above.
all_vns = Set(keys(vi))
if Set(vns) != all_vns
msg = "StaticTransforming only a subset of variables is not supported."
msg = "Statically transforming only a subset of variables is not supported."
throw(ArgumentError(msg))

Check warning on line 678 in src/abstract_varinfo.jl

Codecov / codecov/patch

src/abstract_varinfo.jl#L677-L678

Added lines #L677 - L678 were not covered by tests
end
b = t.bijector
y = vi[:]
@@ -715,8 +715,8 @@
return invlink(t, vi, vns, model)
end
# Wrap a single VarName in a singleton tuple.
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model)
return invlink(t, vi, (vn,), model)

Check warning on line 719 in src/abstract_varinfo.jl

Codecov / codecov/patch

src/abstract_varinfo.jl#L718-L719

Added lines #L718 - L719 were not covered by tests
end

"""
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1298,7 +1298,7 @@ end

Return a `NamedTuple` of the variables in `vns` grouped by symbol.

`varname_namedtuple` is type table for inputs that are `Tuple`s, and for vectors when all
`varname_namedtuple` is type stable for inputs that are `Tuple`s, and for vectors when all
`VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op.

Example:

Unchanged files with check annotations Beta

function link!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model
)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, vns, model)

Check warning on line 87 in src/threadsafe.jl

Codecov / codecov/patch

src/threadsafe.jl#L87

Added line #L87 was not covered by tests
end
function invlink!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model
)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, vns, model)

Check warning on line 93 in src/threadsafe.jl

Codecov / codecov/patch

src/threadsafe.jl#L93

Added line #L93 was not covered by tests
end
function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model
)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model)

Check warning on line 99 in src/threadsafe.jl

Codecov / codecov/patch

src/threadsafe.jl#L99

Added line #L99 was not covered by tests
end
function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model
)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model)

Check warning on line 105 in src/threadsafe.jl

Codecov / codecov/patch

src/threadsafe.jl#L105

Added line #L105 was not covered by tests
end
# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
)
# By default this will simply evaluate the model with `DynamicTransformationContext`, and so
# we need to specialize to avoid this.
return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model)

Check warning on line 1295 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1295

Added line #L1295 was not covered by tests
end
function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model)
settrans!!(vi, false, vn)
end
else
@warn("[DynamicPPL] attempt to invlink an invlinked vi")

Check warning on line 1350 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1350

Added line #L1350 was not covered by tests
end
end,
)
end
else
for i in 1:length(vi.orders)
if i in idcs && vi.orders[i] > get_num_produce(vi)

Check warning on line 2068 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2068

Added line #L2068 was not covered by tests
vi.metadata.flags["del"][i] = true
end
end
end
else
for i in 1:length($f_orders)
if i in $f_idcs && $f_orders[i] > num_produce

Check warning on line 2098 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2098

Added line #L2098 was not covered by tests
$f_flags["del"][i] = true
end
end