Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move predict from Turing #716

Merged
merged 19 commits into from
Dec 20, 2024
Merged

Move predict from Turing #716

merged 19 commits into from
Dec 20, 2024

Conversation

sunxd3
Copy link
Member

@sunxd3 sunxd3 commented Nov 12, 2024

This PR migrates the predict function from Turing.jl to DynamicPPL while maintaining its existing interface and core implementation. Since predict returns a MCMCChains.Chain, the implementation is placed in MCMCChainsExt, similar to generated_quantities.

The purpose of the PR is not to add a "proper" predict implementation for DynamicPPL just yet, but as a first step towards that. Some improvements we should make in the future:

  1. merge Add StatsBase.predict to the interface AbstractPPL.jl#81 and export predict through StatsBase.predict
  2. add support for more input types: NamedTuple, OrderedDict, VarInfo, etc.
  3. values_as_in_model is probably wrong (ref Better support for := Turing.jl#2409)

@sunxd3 sunxd3 marked this pull request as draft November 13, 2024 08:52
sunxd3 and others added 2 commits November 13, 2024 09:21
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3
Copy link
Member Author

sunxd3 commented Nov 14, 2024

Some tests still fail: the mean of the predictions looks correct, but it seems the variance is high. Not certain where goes wrong, so need further investigation.

The reason is some tests implicitly rely on the variance of the posterior samples. Discarding some initial samples fixes this. Turing do this by default, but via LogDensityFunction we need do the discarding explicitly.

sunxd3 and others added 4 commits November 18, 2024 11:06
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@coveralls
Copy link

coveralls commented Nov 18, 2024

Pull Request Test Coverage Report for Build 12435577043

Details

  • 44 of 45 (97.78%) changed or added relevant lines in 2 files are covered.
  • 26 unchanged lines in 2 files lost coverage.
  • Overall coverage increased (+0.1%) to 86.137%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLMCMCChainsExt.jl 37 38 97.37%
Files with Coverage Reduction New Missed Lines %
src/model.jl 5 77.39%
src/threadsafe.jl 21 46.61%
Totals Coverage Status
Change from base Build 12384706841: 0.1%
Covered Lines: 3722
Relevant Lines: 4321

💛 - Coveralls

Copy link

codecov bot commented Nov 18, 2024

Codecov Report

Attention: Patch coverage is 97.77778% with 1 line in your changes missing coverage. Please review.

Project coverage is 86.05%. Comparing base (d0cfaaf) to head (da7fa1c).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
ext/DynamicPPLMCMCChainsExt.jl 97.36% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #716      +/-   ##
==========================================
+ Coverage   85.93%   86.05%   +0.12%     
==========================================
  Files          36       36              
  Lines        4280     4325      +45     
==========================================
+ Hits         3678     3722      +44     
- Misses        602      603       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@sunxd3
Copy link
Member Author

sunxd3 commented Nov 18, 2024

We had a fast discussion on this today at the meeting. Tor raised that we should probably implement predict that take generic Vector as the second argument (instead of just Chain), this is because predict works with sample, and sample can produce non-Chain type returns.

Also although we don't use fix for this PR yet, it is worthwhile to have some nice and better thought-out implementations.

@torfjelde
Copy link
Member

Vector as the second argument

Specifically, I was thinking Vector{<:VarInfo}:) But otherwise, this sounds very good 👍

sunxd3 and others added 3 commits November 21, 2024 12:42
src/model.jl Outdated
varinfos::AbstractArray{<:AbstractVarInfo};
include_all=false,
)
predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need the PredictiveSample here?

My original suggestion was just to use Vector{<:OrderedDict} for the return-value (an abstractly typed PredictiveSample doesn't really offer anything beyond this, does it?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't think too deep about this. A new type certainly is easier to dispatch on, but may not be necessary. Let me look into it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't need to dispatch on this, do we?

Also, maybe it makes more sense to follow the convetion of return the same type as the input type, i.e. in this case we should return a AbstractArray{<:AbstractVarInfo} and in the Chains case we return Chains

@torfjelde
Copy link
Member

Otherwise stuff is starting to look nice though:)

src/model.jl Outdated
varinfos::AbstractArray{<:AbstractVarInfo};
include_all=false,
)
predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a resaon why you're using Symbol instead of VarName here? Seems better to use VarName, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is confusing. The OrderedDict here is actually

OrderedDict{Symbol, Any}(
    values => ..., # a vector of Tuples (varname, value)
    logp =>
)

using NamedTuple now, and use better field names

Copy link
Member

@torfjelde torfjelde Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if your keeping other information than just the realizations (which, tbh, is IMO all we need here), why aren't we just returning the varinfos themselves (I suggested this is in the other comment here)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah gotcha 👍 , I totally misunderstood: I was reading the AbstractVector part but somehow ignored {<:AbstractVarInfo} part.

I can get behind the idea of using a vector of VarInfo for predict and return a vector of VarInfos. But I think the interface need to be spec-ed more. For instance, ideally we want be more clear on questions like: in the returned VarInfos, should the VarName be varname leaves or as appeared in the model; should the values in the returned VarInfo in transformed or constrained space; how exactly should model and input VarInfos conform to each other.

I am a bit short for time now, so after some thoughts, I think it's probably a good idea now to just keep all the logic in MCMCChainsExt and maintain exactly the same interface Turing.jl has now. Then in the future, we can work on to improve predict interface.

Copy link
Member

@torfjelde torfjelde Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hould the VarName be varname leaves or as appeared in the model

As appeared in the model:)

should the values in the returned VarInfo in transformed or constrained spac

Constrained space.

how exactly should model and input VarInfos conform to each other.

Confused; what do you mean?

Overall, I'm still a bit confused by this discussion: Turing.jl's predict literally does: iterate over chain, create varinfo, evaluate model on varinfo, and extract variables from varinfo.

So, why do we not just do

# In DynamicPPL.jl proper:
function predict(rng::Random.AbstractRNG, model::Model, chain::AbstractVector{<:AbstractVarInfo})
    varinfo = DynamicPPL.VarInfo(model)
    return map(chain) do varinfo_params
        DynamicPPL.setval_and_resample!(varinfo, varinfo_params)
        model(rng, varinfo)
        return deepcopy(varinfo)
    end
end

which is effectively what Turing.jl's predict does before converting into a Chains?

EDIT: This is ignoring the values_as_in_model which apparently is used in Turing.jl's predict, though, as mentioned in the other comment, it's very unclear if that's what we want here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sorry it was a bit confusing.

I am thinking that it'll be more intuitive for predict to hold that

predicted_vis = predict(rng, model, varinfos)

then

_varinfos = predict(rng, model, predicted_vis)

returns varinfos that looks like varinfos.

But if the values are in constrained space, can this break?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also does the above code return varinfos with values in constainted space?

@sunxd3 sunxd3 marked this pull request as ready for review December 1, 2024 12:19
Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few comments:+1:

The current impl has different behavior from Turing.predict in a few different ways, so we should address these issues before merging.

@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this doesn't quite seem worth it to test predict, no? What's the reasoning here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add anything or change the implementation in this PR.

Agree AHMC is heavy dep, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45
rely on quality of samples

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter if these statistics are from the prior or posterior 🤷

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, would it be really bad to make AdvancedHMC be a test dependency of DynamicPPL? (again, I don't like this either, but it's not too bad, I would be for adding an issue for removing this dependency later than tempering more with this PR anymore)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called make_chain_from_prior if the link doesn't bring you to the right place)
Feel free to take it if you think it's useful :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that

  1. add this BLR model to DynamicPPL test models
  2. implement its analytical posterior
  3. sample from the analytical posterior directly and drop the AHMC deps.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples:) These were just some stats that were picked to have something to compare to; prior chain is the way to go I think 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prior chain make sense: should we generate samples from prior, take out samples of a particular variable, and try to predict it?

src/model.jl Outdated
Comment on lines 1216 to 1220
function predict(model::Model, chain; include_all=false)
# this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
# TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
return predict(Random.default_rng(), model, chain; include_all)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, we should definitively inform the user of this, no? Otherwise they'll just be like "oh why is this not defined?"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to export predict right now, so predict is only available through Turing.jl, give or take.

would function not defined be meaningful enough if user give other types of input?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Turing exports it, it's better for DynamicPPL to export it, too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I was proposing delaying this until a good predict spec is reached

m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg),
AdvancedHMC.NUTS(0.65),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really doesn't seem necessary to use NUTS here. Just construct a Chains by hand or something, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same reason as above: some tests relies on the quality of the samples


# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: no need to use AdvancedHMC (or any of the other packages), just construct the Chains by hand.
This also doesn't actually show that you need to import MCMCChains for this to work, which might be a good idea

)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in := statements, which is not currently done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw your issue on :=, totally understand the concern here. But if we are not exporting predict, we can change this in near future, also we might want to use fix in the future, so the behavior will be right then.

We would need to make a minor release of Turing if we change this now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?

also we might want to use fix in the future

Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?

Ideally, I would want this PR to do a proper implementation of predict in DynamicPPL. But now, I am okay with the PR being only a first step towards that.

Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?

what I was trying to say is that, with fix it should have the right behavior (with regards to :=). Of course not the only way to reach the desired behavior.

Copy link
Member

@yebai yebai Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.

src/model.jl Outdated
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Turing.jl we're currently overloading StatsBase.predict, so we should probably do the same here, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree with this, but probably not time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But is this PR then held up until that PR is merged then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, that PR doesn't really matter; overloading StatsBase.predict here and now just means that we'll immediately be compliant with the AbstractPPL.jl interface when that PR merges?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Grey area: for me it is okay, because this PR is just about introduce a Turing-faced predict, not a user faced one yet. At the moment predict is not a public API yet

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload AbstractPPL.predict here.

@sunxd3 sunxd3 requested a review from penelopeysm December 2, 2024 11:32
@yebai
Copy link
Member

yebai commented Dec 16, 2024

@sunxd3, let's get this merged in the next few days.

@sunxd3
Copy link
Member Author

sunxd3 commented Dec 16, 2024

will do, on top of my priority list

@sunxd3 sunxd3 mentioned this pull request Dec 20, 2024
@sunxd3
Copy link
Member Author

sunxd3 commented Dec 20, 2024

Some regression test for TuringLang/Turing.jl#1352 are removed, as far as I can tell, it should be covered by tests of values_as_in_model.

@sunxd3
Copy link
Member Author

sunxd3 commented Dec 20, 2024

The tests are failing because fix, condition are exported by AbstractPPL, while DynamicPPL currently doesn't actually import these from AbstractPPL.

edit: this is inaccurate, condition is actually imported, but fix and unfix are not -- I think this is because they were first introduced into DynamicPPL, then AbstractPPL?

@sunxd3
Copy link
Member Author

sunxd3 commented Dec 20, 2024

@yebai @torfjelde @penelopeysm I think this should be ready for another look

Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are failing because fix, condition are exported by AbstractPPL, while DynamicPPL currently doesn't actually import these from AbstractPPL.

Let's fix these in this PR if possible.

@sunxd3
Copy link
Member Author

sunxd3 commented Dec 20, 2024

I think the tests are run, but the codecov thinks the code in MCMCChainsExt is not covered. Is this known issue, or I did something wrong?

@yebai yebai added this pull request to the merge queue Dec 20, 2024
Merged via the queue into master with commit 6657441 Dec 20, 2024
19 of 20 checks passed
@yebai yebai deleted the sunxd/move_predict branch December 20, 2024 18:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants