Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove selector stuff from VarInfo tests and link/invlink #780

Merged
merged 33 commits into from
Jan 30, 2025
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4dc2a72
Remove selector stuff from varinfo tests
mhauru Jan 16, 2025
9b492a3
Implement link and invlink for varnames rather than samplers
mhauru Jan 16, 2025
b508f08
Replace set_retained_vns_del_by_spl! with set_retained_vns_del!
mhauru Jan 16, 2025
b8880d1
Make linking tests more extensive
mhauru Jan 16, 2025
99a8490
Remove sampler indexing from link methods (but not invlink)
mhauru Jan 22, 2025
4a79b1f
Remove indexing by samplers from invlink
mhauru Jan 22, 2025
26a1901
Merge remote-tracking branch 'origin/master' into mhauru/remove-selec…
mhauru Jan 22, 2025
090608b
Work towards removing sampler indexing with StaticTransformation
mhauru Jan 22, 2025
4749853
Fix invlink/link for TypedVarInfo and StaticTransformation
mhauru Jan 23, 2025
e960679
Fix a test in models.jl
mhauru Jan 23, 2025
d507a53
Move some functions to utils.jl, add tests and docstrings
mhauru Jan 23, 2025
41150b5
Fix a docstring typo
mhauru Jan 23, 2025
836fb13
Merge branch 'release-0.35' into mhauru/remove-selectors-linking
mhauru Jan 23, 2025
45d1f13
Various simplification to link/invlink
mhauru Jan 23, 2025
98915c2
Improve a docstring
mhauru Jan 23, 2025
f05068d
Style improvements
mhauru Jan 23, 2025
bc4c420
Fix broken link/invlink dispatch cascade for VectorVarInfo
mhauru Jan 23, 2025
71980ba
Fix some more broken dispatch cascades
mhauru Jan 23, 2025
45562a9
Apply suggestions from code review
mhauru Jan 24, 2025
db5b835
Remove comments that messed with docstrings
mhauru Jan 24, 2025
f99effe
Apply suggestions from code review
mhauru Jan 28, 2025
56194cd
Fix issues surfaced in code review
mhauru Jan 28, 2025
c187c49
Simplify link/invlink arguments
mhauru Jan 28, 2025
86b25c5
Fix a bug in unflatten VarNamedVector
mhauru Jan 28, 2025
2a6c1bc
Rename VarNameCollection -> VarNameTuple
mhauru Jan 28, 2025
853f47e
Remove test of a removed varname_namedtuple method
mhauru Jan 28, 2025
ed80328
Apply suggestions from code review
mhauru Jan 29, 2025
d996d0c
Respond to review feedback
mhauru Jan 29, 2025
2083148
Remove _default_sampler and a dead argument of maybe_invlink_before_eval
mhauru Jan 29, 2025
39fa647
Fix a typo in a comment
mhauru Jan 29, 2025
9df364f
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
2c73de5
Add HISTORY entry, fix one set_retained_vns_del! method
mhauru Jan 30, 2025
49604e1
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix broken link/invlink dispatch cascade for VectorVarInfo
  • Loading branch information
mhauru committed Jan 23, 2025
commit bc4c42093dafe01c0d7ed3984232e471a1bdcb65
26 changes: 16 additions & 10 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
@@ -791,8 +791,8 @@
syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::TypedVarInfo) = keys(vi.metadata)

_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata)

Check warning on line 795 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L794-L795

Added lines #L794 - L795 were not covered by tests

# Get all indices of variables belonging to SampleFromPrior:
# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to
@@ -1170,18 +1170,22 @@

# Specialise link!! without varnames provided for TypedVarInfo. The generic version gets
# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which
# helps keep the downstream call to _link! type stable.
function link!!(::DynamicTransformation, vi::TypedVarInfo, ::Model)
_link!(vi, all_varnames_namedtuple(vi))
return vi
# helps keep the downstream call to link!! type stable.
function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model)
return link!!(t, vi, all_varnames_namedtuple(vi), model)

Check warning on line 1175 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1174-L1175

Added lines #L1174 - L1175 were not covered by tests
end

# X -> R for all variables associated with given sampler
function link!!(t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model)
function link!!(

Check warning on line 1179 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1179

Added line #L1179 was not covered by tests
t::DynamicTransformation,
vi::VarInfo,
vns::Union{VarNameCollection,NamedTuple},
model::Model,
)
# If we're working with a `VarNamedVector`, we always use immutable.
has_varnamedvector(vi) && return link(t, vi, vns, model)

Check warning on line 1186 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1186

Added line #L1186 was not covered by tests
# Call `_link!` instead of `link!` to avoid deprecation warning.
_link!(vi, vns)

Check warning on line 1188 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1188

Added line #L1188 was not covered by tests
return vi
end

@@ -1193,10 +1197,10 @@
)
# By default this will simply evaluate the model with `DynamicTransformationContext`, and so
# we need to specialize to avoid this.
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)

Check warning on line 1200 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1200

Added line #L1200 was not covered by tests
end

function _link!(vi::UntypedVarInfo, vns::VarNameCollection)

Check warning on line 1203 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1203

Added line #L1203 was not covered by tests
# TODO: Change to a lazy iterator over `vns`
if ~istrans(vi, vns[1])
for vn in vns
@@ -1209,8 +1213,8 @@
end
end

function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple})
return _link!(vi.metadata, vi, varname_namedtuple(vns))

Check warning on line 1217 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1216-L1217

Added lines #L1216 - L1217 were not covered by tests
end

"""
@@ -1218,25 +1222,25 @@

Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`.
"""
function filter_subsumed(filter_vns, filtered_vns)
return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)

Check warning on line 1226 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1225-L1226

Added lines #L1225 - L1226 were not covered by tests
end

@generated function _link!(
::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
) where {metadata_names,vns_names}
expr = Expr(:block)
for f in metadata_names
if !(f in vns_names)
continue

Check warning on line 1235 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1233-L1235

Added lines #L1233 - L1235 were not covered by tests
end
push!(

Check warning on line 1237 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1237

Added line #L1237 was not covered by tests
expr.args,
quote
f_vns = vi.metadata.$f.vns
f_vns = filter_subsumed(vns.$f, f_vns)
if !isempty(f_vns)
if !istrans(vi, f_vns[1])

Check warning on line 1243 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1240-L1243

Added lines #L1240 - L1243 were not covered by tests
# Iterate over all `f_vns` and transform
for vn in f_vns
f = internal_to_linked_internal_transform(vi, vn)
@@ -1255,20 +1259,22 @@

# Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets
# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which
# helps keep the downstream call to _invlink! type stable.
function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model)
_invlink!(vi, all_varnames_namedtuple(vi))
return vi
# helps keep the downstream call to invlink!! type stable.
function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model)
return invlink!!(t, vi, all_varnames_namedtuple(vi), model)

Check warning on line 1264 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1263-L1264

Added lines #L1263 - L1264 were not covered by tests
end

# R -> X for all variables associated with given sampler
function invlink!!(
t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model
t::DynamicTransformation,
vi::VarInfo,
vns::Union{VarNameCollection,NamedTuple},
model::Model,
)
# If we're working with a `VarNamedVector`, we always use immutable.
has_varnamedvector(vi) && return invlink(t, vi, vns, model)

Check warning on line 1275 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1275

Added line #L1275 was not covered by tests
# Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
_invlink!(vi, vns)

Check warning on line 1277 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1277

Added line #L1277 was not covered by tests
return vi
end

@@ -1280,7 +1286,7 @@
)
# By default this will simply evaluate the model with `DynamicTransformationContext`, and so
# we need to specialize to avoid this.
return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model)

Check warning on line 1289 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1289

Added line #L1289 was not covered by tests
end

function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model)
@@ -1291,7 +1297,7 @@
return maybe_invlink_before_eval!!(t, vi, context, model)
end

function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection)

Check warning on line 1300 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1300

Added line #L1300 was not covered by tests
if istrans(vi, vns[1])
for vn in vns
f = linked_internal_to_internal_transform(vi, vn)
@@ -1303,34 +1309,34 @@
end
end

function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple})
vns_namedtuple = varname_namedtuple(vns)
return _invlink!(vi.metadata, vi, vns_namedtuple)

Check warning on line 1314 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1312-L1314

Added lines #L1312 - L1314 were not covered by tests
end

@generated function _invlink!(
::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
) where {metadata_names,vns_names}
expr = Expr(:block)
for f in metadata_names
if !(f in vns_names)
continue

Check warning on line 1323 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1321-L1323

Added lines #L1321 - L1323 were not covered by tests
end

push!(

Check warning on line 1326 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1326

Added line #L1326 was not covered by tests
expr.args,
quote
f_vns = vi.metadata.$f.vns
f_vns = filter_subsumed(vns.$f, f_vns)
if istrans(vi, f_vns[1])

Check warning on line 1331 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1329-L1331

Added lines #L1329 - L1331 were not covered by tests
# Iterate over all `f_vns` and transform
for vn in f_vns
f = linked_internal_to_internal_transform(vi, vn)
_inner_transform!(vi, vn, f)
settrans!!(vi, false, vn)
end

Check warning on line 1337 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1333-L1337

Added lines #L1333 - L1337 were not covered by tests
else
@warn("[DynamicPPL] attempt to invlink an invlinked vi")

Check warning on line 1339 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1339

Added line #L1339 was not covered by tests
end
end,
)
@@ -1365,10 +1371,10 @@
return map(Returns(nothing), varinfo.metadata)
end

function link(

Check warning on line 1374 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1374

Added line #L1374 was not covered by tests
::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model
)
return _link(model, varinfo, vns)

Check warning on line 1377 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1377

Added line #L1377 was not covered by tests
end

function link(
@@ -1379,7 +1385,7 @@
)
# By default this will simply evaluate the model with `DynamicTransformationContext`, and so
# we need to specialize to avoid this.
return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model)

Check warning on line 1388 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1388

Added line #L1388 was not covered by tests
end

# Specialise link without varnames provided for TypedVarInfo. The generic version gets
@@ -1502,7 +1508,7 @@
function invlink(
::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model
)
return _invlink(model, varinfo, vns)

Check warning on line 1511 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1511

Added line #L1511 was not covered by tests
end

function invlink(
@@ -1513,7 +1519,7 @@
)
# By default this will simply evaluate the model with `DynamicTransformationContext`, and so
# we need to specialize to avoid this.
return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model)

Check warning on line 1522 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1522

Added line #L1522 was not covered by tests
end

# Specialise invlink without varnames provided for TypedVarInfo. The generic version gets
@@ -1523,7 +1529,7 @@
return _invlink(model, vi, all_varnames_namedtuple(vi))
end

function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection)

Check warning on line 1532 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1532

Added line #L1532 was not covered by tests
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!!(model, varinfo, varinfo.metadata, vns),
@@ -2034,32 +2040,32 @@

Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`.
"""
function set_retained_vns_del!(vi::UntypedVarInfo)
idcs = _getidcs(vi)

Check warning on line 2044 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2043-L2044

Added lines #L2043 - L2044 were not covered by tests
if get_num_produce(vi) == 0
for i in length(idcs):-1:1
vi.metadata.flags["del"][idcs[i]] = true

Check warning on line 2047 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2046-L2047

Added lines #L2046 - L2047 were not covered by tests
end
else
for i in 1:length(vi.orders)
if i in idcs && vi.orders[i] > get_num_produce(vi)

Check warning on line 2051 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2051

Added line #L2051 was not covered by tests
vi.metadata.flags["del"][i] = true
end
end
end
return nothing
end
function set_retained_vns_del!(vi::TypedVarInfo)

Check warning on line 2058 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2058

Added line #L2058 was not covered by tests
# Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol
mhauru marked this conversation as resolved.
Show resolved Hide resolved
idcs = _getidcs(vi)
return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi))

Check warning on line 2061 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2060-L2061

Added lines #L2060 - L2061 were not covered by tests
end
@generated function _set_retained_vns_del!(

Check warning on line 2063 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2063

Added line #L2063 was not covered by tests
metadata, idcs::NamedTuple{names}, num_produce
) where {names}
expr = Expr(:block)
for f in names
f_idcs = :(idcs.$f)

Check warning on line 2068 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2068

Added line #L2068 was not covered by tests
f_orders = :(metadata.$f.orders)
f_flags = :(metadata.$f.flags)
push!(
@@ -2067,12 +2073,12 @@
quote
# Set the flag for variables with symbol `f`
if num_produce == 0
for i in length($f_idcs):-1:1
$f_flags["del"][$f_idcs[i]] = true

Check warning on line 2077 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2076-L2077

Added lines #L2076 - L2077 were not covered by tests
end
else
for i in 1:length($f_orders)
if i in $f_idcs && $f_orders[i] > num_produce

Check warning on line 2081 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L2081

Added line #L2081 was not covered by tests
$f_flags["del"][i] = true
end
end
Loading