Fix for LogDensityFunction #621

Merged: 6 commits, Jun 25, 2024

torfjelde
Member

@torfjelde torfjelde commented Jun 17, 2024

Issue

When evaluating a Model, there are two "sources" of contexts provided: 1) explicitly passed to evaluate!! as an argument, and 2) through the context attached to the model itself in model.context.

The latter was introduced because in many scenarios it makes sense to "contextualize" a Model, e.g. attach a ConditionContext to a Model to specify which parameters are considered conditioned. The former is present because, back in the day, the samplers were heavily tied to DynamicPPL and we passed a sampler argument in almost every place where we now pass context.

To "bridge" the two approaches, when we call evaluate!!, the context that eventually ends up as __context__ in the model itself is "resolved" here:

DynamicPPL.jl/src/model.jl

Lines 995 to 997 in d384da2

context_new = setleafcontext(
    context, setleafcontext(model.context, leafcontext(context))
)

In short, we do:

  1. Replace the leaf context of model.context with the leaf context of the context argument.
  2. Replace the leaf context of the context argument with the context resulting from (1).

This might be a bit strange, but the result is that context takes precedence over model.context, as it's considered to be "more important" due to the "user" explicitly passing it to evaluate!!.

We did this because some samplers were using contexts to specify certain behaviors that had to be respected, e.g. context could be a PriorContext to indicate that the prior should be evaluated while model.context could be a DefaultContext, in which case we wanted the result to be PriorContext.
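To make the resolution step concrete, here is a minimal sketch using toy stand-in types. Everything prefixed with `Toy` or `toy_` is hypothetical and only mimics the shape of the snippet quoted from src/model.jl; it does not call DynamicPPL.

```julia
# Toy stand-ins for illustration only; these are NOT the DynamicPPL types.
abstract type ToyContext end
struct ToyDefault <: ToyContext end   # leaf context
struct ToyPrior <: ToyContext end     # leaf context
struct ToyCondition <: ToyContext     # parent context wrapping one child
    values::NamedTuple
    child::ToyContext
end

# The leaf is the innermost context of the stack.
toy_leaf(ctx::ToyContext) = ctx
toy_leaf(ctx::ToyCondition) = toy_leaf(ctx.child)

# Replace the innermost context of `ctx` with `leaf`.
toy_setleaf(::ToyContext, leaf::ToyContext) = leaf
toy_setleaf(ctx::ToyCondition, leaf::ToyContext) =
    ToyCondition(ctx.values, toy_setleaf(ctx.child, leaf))

# Same shape as the resolution snippet quoted above.
toy_resolve(context, model_context) =
    toy_setleaf(context, toy_setleaf(model_context, toy_leaf(context)))

# The model carries a conditioning context with a default leaf;
# the caller explicitly asks for prior evaluation.
model_context = ToyCondition((x = 0,), ToyDefault())
resolved = toy_resolve(ToyPrior(), model_context)
```

In this toy model, `resolved` keeps the conditioning from the model's context but ends in `ToyPrior()`, which matches the PriorContext example: the leaf passed explicitly wins over the leaf attached to the model.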

This also means that LogDensityFunction, effectively a convenient wrapper around evaluate!!, also has two sources of contexts: f.model.context and f.context. By default, i.e. if we call LogDensityFunction(model), we specify these to be the same, i.e. f.context === f.model.context. This is clearly redundant, since we're just specifying the same context twice. Moreover, since, as seen above, we effectively concatenate context and model.context, this results in LogDensityFunction evaluating the model with the context "doubled". In most cases, this still results in the intended behavior, but once you start changing certain fields of the LogDensityFunction, e.g. LogDensityFunction(model_new, f.varinfo, f.context), interesting things can happen. For example, in TuringLang/Turing.jl#2231, I ran into an issue where I'd get two ConditionContexts conditioning the same variable: one from f.model.context (what I intended) and one from f.context (what I did not intend).

Solution

This PR addresses this issue for LogDensityFunction by simply allowing nothing in f.context, which is resolved to leafcontext(model.context) if not specified. This addresses the issues I've encountered above, but the proper way of fixing this would, IMO, be to either:

  1. Allow nothing to be passed in place of context "everywhere", i.e. we make them all Optional{AbstractContext} = Union{Nothing,AbstractContext} types, and resolve to model.context whenever it's nothing.
  2. Drop the context argument from evaluate!! completely and instead always just "attach" the context to the Model. This seems "nicer" overall, but will require quite a bit of work both here and on the Turing.jl side, and it's not quite clear to me that this will actually work (context is used in many other places than just evaluate!!, e.g. unflatten, to allow samplers in SamplingContext to change behaviors further).
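The fix described at the start of this section (allowing nothing in f.context and resolving it to the leaf of the model's context) could look roughly like the sketch below. The `Toy` types and `toy_getcontext` are hypothetical stand-ins, not the DynamicPPL definitions.

```julia
# Toy stand-ins for illustration only; these are NOT the DynamicPPL types.
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyCondition <: ToyContext
    values::NamedTuple
    child::ToyContext
end

toy_leaf(ctx::ToyContext) = ctx
toy_leaf(ctx::ToyCondition) = toy_leaf(ctx.child)

struct ToyModel
    context::ToyContext
end

# The wrapper may store `nothing` instead of a second copy of the context.
struct ToyLogDensity
    model::ToyModel
    context::Union{Nothing,ToyContext}
end
ToyLogDensity(model::ToyModel) = ToyLogDensity(model, nothing)

# Resolve lazily: fall back to the leaf of the model's own context.
toy_getcontext(f::ToyLogDensity) =
    f.context === nothing ? toy_leaf(f.model.context) : f.context

conditioned = ToyModel(ToyCondition((x = 0,), ToyDefault()))
f = ToyLogDensity(conditioned)
explicit = ToyLogDensity(conditioned, conditioned.context)
```

Here `toy_getcontext(f)` returns only the leaf, so the conditioning stored on the model is not repeated, while `toy_getcontext(explicit)` still returns whatever was passed explicitly.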

Addendum

This entire PR arose from the following scenario:

model = condition(model, x=0) # `model.context` is now a `ConditionContext(x = 0, DefaultContext())`.
f = LogDensityFunction(model)  # Both `f.model.context` and `f.context` are now `ConditionContext(x = 0, DefaultContext())`
f = Accessors.@set f.model = condition(f.model, x=1)  # `f.model.context` is now `ConditionContext(x = 1, ConditionContext(x = 0, DefaultContext()))`
LogDensityProblems.logdensity(f, params) # Here we get `f.context` wrapping `f.model.context`, i.e. `ConditionContext(x = 0, ConditionContext(x = 1, ...))`

A ConditionContext is such that the "outermost" one takes precedence (since this is the one which was applied last), but in the above scenario this is not respected since we end up using the ConditionContext(x = 0, ...) from f.context instead of the outermost one from f.model.context.
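The scenario can be replayed with toy stand-in types. This is only a sketch: `toy_lookup` assumes a simple outermost-first search for a conditioned value, and none of the `Toy` names exist in DynamicPPL.

```julia
# Toy stand-ins for illustration only; these are NOT the DynamicPPL types.
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyCondition <: ToyContext
    values::NamedTuple
    child::ToyContext
end

toy_leaf(ctx::ToyContext) = ctx
toy_leaf(ctx::ToyCondition) = toy_leaf(ctx.child)

toy_setleaf(::ToyContext, leaf::ToyContext) = leaf
toy_setleaf(ctx::ToyCondition, leaf::ToyContext) =
    ToyCondition(ctx.values, toy_setleaf(ctx.child, leaf))

toy_resolve(context, model_context) =
    toy_setleaf(context, toy_setleaf(model_context, toy_leaf(context)))

# Outermost-first lookup: the first context that knows `name` wins.
toy_lookup(::ToyContext, name::Symbol) = nothing
toy_lookup(ctx::ToyCondition, name::Symbol) =
    haskey(ctx.values, name) ? ctx.values[name] : toy_lookup(ctx.child, name)

old_model_context = ToyCondition((x = 0,), ToyDefault())
f_context = old_model_context                         # wrapper copied the model context
new_model_context = ToyCondition((x = 1,), old_model_context)

# Old behaviour: the stale wrapper context ends up outermost.
doubled = toy_resolve(f_context, new_model_context)
# With only the leaf passed, the model's own stack is left untouched.
fixed = toy_resolve(toy_leaf(new_model_context), new_model_context)

toy_lookup(doubled, :x)  # 0, the stale value
toy_lookup(fixed, :x)    # 1, the value applied last
```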

@yebai
Member

yebai commented Jun 17, 2024

It looks like a good one for @mhauru to review.

@mhauru
Member

mhauru commented Jun 17, 2024

I was just today reading through DynamicPPL and noticed that model.context is a thing, and had to take a while to figure out why, given the explicit passing around of a separate context.

Do I understand correctly that functionally the changes here are equivalent to changing

context::AbstractContext=model.context

to

context::AbstractContext=leafcontext(model.context)

and the rest, introducing the nothing and the getcontext function, are an aesthetic preference for not having the LogDensityFunction struct store redundant data?

@yebai
Member

yebai commented Jun 17, 2024

I think we should unify these contexts eventually, although not necessarily in this PR.

I lean towards contextualising a model before passing it to an evaluate!! function:

# check for invalid context composition; note that `contextualising!!` could be called more than once
model_with_context  = contextualising!!(model, context) 

res = evaluate!!(rng, model_with_context, ...) # remove context argument here

If a model is conditioned, when we contextualise it again, it can throw an error in cases where context composition is invalid.

This is probably the same as @torfjelde's idea above, removing the context argument from evaluate!! completely but introducing an explicit contextualising!! function.

@torfjelde
Member Author

> and the rest, introducing the nothing and the getcontext function, are an aesthetic preference for not having the LogDensityFunction struct store redundant data?

It's not so much about "not storing redundant data", but rather about "lazily" resolving the context when f.context === nothing, in which case passing leafcontext(model.context) is effectively a no-op.
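One way to read the difference between storing the leaf eagerly and resolving it lazily is sketched below, under the assumption that the model is later swapped for one with a different leaf. All `Toy` names are hypothetical stand-ins rather than DynamicPPL code.

```julia
# Toy stand-ins for illustration only; these are NOT the DynamicPPL types.
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyPrior <: ToyContext end
struct ToyCondition <: ToyContext
    values::NamedTuple
    child::ToyContext
end

toy_leaf(ctx::ToyContext) = ctx
toy_leaf(ctx::ToyCondition) = toy_leaf(ctx.child)

struct ToyModel
    context::ToyContext
end

struct ToyLogDensity
    model::ToyModel
    context::Union{Nothing,ToyContext}
end

toy_getcontext(f::ToyLogDensity) =
    f.context === nothing ? toy_leaf(f.model.context) : f.context

model = ToyModel(ToyCondition((x = 0,), ToyDefault()))
eager = ToyLogDensity(model, toy_leaf(model.context))  # leaf fixed at construction
lazy = ToyLogDensity(model, nothing)                   # leaf looked up when needed

# Rebuild both wrappers around a model whose leaf is different.
model_new = ToyModel(ToyCondition((x = 0,), ToyPrior()))
eager_new = ToyLogDensity(model_new, eager.context)
lazy_new = ToyLogDensity(model_new, lazy.context)

toy_getcontext(eager_new)  # ToyDefault(), the leaf of the old model
toy_getcontext(lazy_new)   # ToyPrior(), the leaf of the new model
```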

> This is probably the same as @torfjelde's idea above, removing the context argument from evaluate!! completely but introducing an explicit contextualising!! function.

We already have this: contextualize(model, context). But otherwise, yes we're on the same page: replace all calls to evaluate!!(model, varinfo, context) with

evaluate!!(contextualize(model, context), varinfo)

But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond evaluate!! 😕

@yebai
Member

yebai commented Jun 17, 2024

> But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond evaluate!! 😕

Is there any difficulty other than finding all the places and then updating them?

@torfjelde
Member Author

torfjelde commented Jun 17, 2024

> Is there any difficulty other than finding all the places and then updating them?

It's a question of what you do with methods such as getindex(varinfo, context) (which are there because we need to let samplers pick out slices of varinfo, as we haven't yet replaced the Gibbs sampler). Once we have replaced the Gibbs sampler fully, we'll be able to drop a lot of these methods where we use context / sampler, and then it'll indeed be "just finding all the places and replacing them".

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Member Author

Something weird is happening with Documenter.jl here? Seems like everything is missing some whitespace o.O

@coveralls

coveralls commented Jun 18, 2024

Pull Request Test Coverage Report for Build 9560893923

Details

  • 7 of 8 (87.5%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.01%) to 80.344%

Changes Missing Coverage     Covered Lines   Changed/Added Lines   %
src/logdensityfunction.jl    7               8                     87.5%
Totals Coverage Status
Change from base Build 9495604846: -0.01%
Covered Lines: 2759
Relevant Lines: 3434

💛 - Coveralls

@torfjelde
Member Author

This should be ready

@yebai yebai enabled auto-merge June 18, 2024 21:13
@mhauru
Member

mhauru commented Jun 25, 2024

The integration test fails because Turing.jl has code that tries to access the LogDensityFunction.context field directly and is not prepared to handle it being nothing. This may be a more prevalent issue, i.e. this PR may be breaking for many dependants.

One solution to this would be to go and fix all the dependants to use getcontext(x) rather than x.context. This would be long-term nicer IMO, but would require making edits to all those packages, and might cause breakage in the meanwhile if dependants don't specify version upper bounds for DynamicPPL. It would also mean that those packages then don't work with older versions of DynamicPPL anymore.

Another option would be to override getproperty for LogDensityFunction, so that x.context would actually return getcontext(x). This may be a bit overkill in terms of using a somewhat "deep" Julia construct (getproperty) to solve quite a simple problem, but it would be minimally disruptive.
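For reference, the getproperty option could be sketched as follows on a toy struct. `ToyWrapped` and `toy_getcontext` are hypothetical names; only `Base.getproperty` and `getfield` are real Julia functions.

```julia
# Toy stand-in for illustration only; this is NOT DynamicPPL's LogDensityFunction.
struct ToyWrapped
    model_leaf::Symbol               # stands in for leafcontext(model.context)
    context::Union{Nothing,Symbol}   # may be `nothing`, as in this PR
end

# Use getfield here: writing `f.context` would re-enter the override below
# and recurse forever.
toy_getcontext(f::ToyWrapped) =
    getfield(f, :context) === nothing ? getfield(f, :model_leaf) : getfield(f, :context)

# Make `f.context` behave like the accessor, leaving other fields untouched.
function Base.getproperty(f::ToyWrapped, name::Symbol)
    name === :context && return toy_getcontext(f)
    return getfield(f, name)
end

f = ToyWrapped(:default_leaf, nothing)
f.context               # resolved through toy_getcontext
getfield(f, :context)   # raw stored value, still `nothing`
```

The cost of this approach is visible in the last two lines: field access and the stored value no longer agree, which is part of why it counts as a "deep" construct.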

Relevant Julia style guide page: https://docs.julialang.org/en/v1/manual/style-guide/#Prefer-exported-methods-over-direct-field-access

I lean towards the former solution. Other thoughts?

@yebai
Member

yebai commented Jun 25, 2024

> One solution to this would be to go and fix all the dependants to use getcontext(x) rather than x.context. This would be long-term nicer IMO, but would require making edits to all those packages, and might cause breakage in the meanwhile if dependants don't specify version upper bounds for DynamicPPL.

@mhauru, can you create a PR for Turing that adopts the suggestion you propose above? For packages without DynamicPPL bounds, that's unfortunate; maybe this is an opportunity to add such bounds. However, I am not aware of any package depending on DynamicPPL without an explicit version bound.

Also, does that mean this PR can be merged as a breaking release?

@mhauru
Member

mhauru commented Jun 25, 2024

Yep, we can make this a breaking release and be fine. I'll make the Turing.jl PR tomorrow.

Co-authored-by: Hong Ge <[email protected]>
Contributor

Pull Request Test Coverage Report for Build 9666211295

Details

  • 6 of 8 (75.0%) changed or added relevant lines in 1 file are covered.
  • 103 unchanged lines in 10 files lost coverage.
  • Overall coverage decreased (-3.7%) to 76.624%

Changes Missing Coverage     Covered Lines   Changed/Added Lines   %
src/logdensityfunction.jl    6               8                     75.0%

Files with Coverage Reduction      New Missed Lines   %
src/logdensityfunction.jl          1                  61.9%
src/sampler.jl                     1                  94.12%
ext/DynamicPPLForwardDiffExt.jl    1                  77.78%
src/contexts.jl                    2                  77.27%
src/threadsafe.jl                  4                  50.0%
src/abstract_varinfo.jl            5                  82.68%
src/context_implementations.jl     8                  58.63%
src/model_utils.jl                 10                 19.64%
src/loglikelihoods.jl              16                 54.84%
src/varinfo.jl                     55                 85.36%
Totals Coverage Status
Change from base Build 9495604846: -3.7%
Covered Lines: 2642
Relevant Lines: 3448

💛 - Coveralls

@coveralls

coveralls commented Jun 25, 2024

Pull Request Test Coverage Report for Build 9666211295

Details

  • 6 of 8 (75.0%) changed or added relevant lines in 1 file are covered.
  • 103 unchanged lines in 10 files lost coverage.
  • Overall coverage decreased (-2.8%) to 77.561%

Changes Missing Coverage     Covered Lines   Changed/Added Lines   %
src/logdensityfunction.jl    6               8                     75.0%

Files with Coverage Reduction      New Missed Lines   %
src/logdensityfunction.jl          1                  61.9%
src/sampler.jl                     1                  94.12%
ext/DynamicPPLForwardDiffExt.jl    1                  77.78%
src/contexts.jl                    2                  77.27%
src/threadsafe.jl                  4                  50.0%
src/abstract_varinfo.jl            5                  82.68%
src/context_implementations.jl     8                  58.63%
src/model_utils.jl                 10                 19.64%
src/loglikelihoods.jl              16                 54.84%
src/varinfo.jl                     55                 85.36%
Totals Coverage Status
Change from base Build 9495604846: -2.8%
Covered Lines: 2658
Relevant Lines: 3427

💛 - Coveralls

@yebai yebai disabled auto-merge June 25, 2024 16:53
@yebai yebai merged commit 2b97177 into master Jun 25, 2024
10 of 11 checks passed
@yebai yebai deleted the torfjelde/logdensityfunction-fix-v2 branch June 25, 2024 16:53
@torfjelde
Member Author

Thanks for getting this through!

> Another option would be to override getproperty for LogDensityFunction, so that x.context would actually return getcontext(x). This may be a bit overkill in terms of using a somewhat "deep" Julia construct (getproperty) to solve quite a simple problem, but it would be minimally disruptive.

Agree that this would have been overkill :)
