Pathwise Sampling (Take 2) #112

Closed · wants to merge 17 commits
1 change: 1 addition & 0 deletions Project.toml
@@ -14,6 +14,7 @@ GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
6 changes: 6 additions & 0 deletions docs/src/api/pathwisesampling.md
@@ -0,0 +1,6 @@
# Pathwise Posterior Sampling

```@autodocs
Modules = [ApproximateGPs.PathwiseSamplingModule]
Private = false
```
8 changes: 8 additions & 0 deletions examples/d-pathwise-sampling/Project.toml
@@ -0,0 +1,8 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFourierFeatures = "46bd3b77-1c99-43cd-804b-1e14830792b5"
54 changes: 54 additions & 0 deletions examples/d-pathwise-sampling/script.jl
@@ -0,0 +1,54 @@
# ## Setup

using RandomFourierFeatures
using ApproximateGPs
using AbstractGPs
using Distributions
using LinearAlgebra
using Random

rng = MersenneTwister(1234)

# Define a GP and generate some data

k = 3 * (SqExponentialKernel() ∘ ScaleTransform(10))
gp = GP(k)

input_dims = 1

x = ColVecs(rand(input_dims, 20))
fx = gp(x, 0.01)
y = rand(fx)

z = x[1:8]
fz = gp(z)

# Any of the following will work:

# q = ApproximateGPs._optimal_variational_posterior(NonCentered(), fz, fx, y)
# ap = posterior(SparseVariationalApproximation(NonCentered(), fz, q))

# q = ApproximateGPs._optimal_variational_posterior(Centered(), fz, fx, y)
# ap = posterior(SparseVariationalApproximation(Centered(), fz, q))

ap = posterior(VFE(fz), fx, y)
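# (For a Gaussian likelihood, the VFE posterior is the closed-form optimal sparse approximation.)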

x_test = ColVecs(sort(rand(input_dims, 500); dims=2))

num_features = 1000
rff_wsa = build_rff_weight_space_approx(rng, input_dims, num_features)
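# (`rff_wsa` maps a prior GP to a Bayesian linear regression over `num_features` random Fourier features, giving cheap approximate prior function samples.)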

function_samples = ApproximateGPs.pathwise_sample(rng, ap, rff_wsa, 100)

y_samples = reduce(hcat, map((f) -> f(x_test), function_samples)) # size(y_samples) == (length(x_test), 100)

# Plot sampled functions against the exact posterior

using Plots

x_plot = x_test.X'

plot(x_plot, y_samples; label="", color=:red, linealpha=0.2)
plot!(vec(x_plot), ap; color=:blue, label="True posterior")
scatter!(x.X', y; label="data")
vline!(vec(z.X'); label="inducing points", color=:orange)
3 changes: 3 additions & 0 deletions src/ApproximateGPs.jl
@@ -21,6 +21,9 @@ include("LaplaceApproximationModule.jl")
@reexport using .LaplaceApproximationModule:
build_laplace_objective, build_laplace_objective!

include("PathwiseSamplingModule.jl")
@reexport using .PathwiseSamplingModule: pathwise_sample

include("deprecations.jl")

end
80 changes: 80 additions & 0 deletions src/PathwiseSamplingModule.jl
@@ -0,0 +1,80 @@
module PathwiseSamplingModule

export pathwise_sample

using Random

using ..ApproximateGPs: _chol_cov
using ..SparseVariationalApproximationModule:
SparseVariationalApproximation, Centered, NonCentered, _get_q_u

using AbstractGPs:
ApproxPosteriorGP, VFE, inducing_points, Xt_invA_X, Xt_A_X
using PDMats: chol_lower
using Distributions

struct PosteriorSample{Tapprox<:ApproxPosteriorGP,Tprior,Tv}
approx_post::Tapprox # The approximate posterior GP from which this sample is taken
prior_sample::Tprior # A function sampled from the prior of `approx_post`
v::Tv # The term needed to compute the pathwise update to the prior sample
end
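
# Evaluating a sample applies the pathwise (Matheron) update: the prior sample is
# shifted by cov(x, z) * v so that it agrees with the draw u ~ q(u) at the
# inducing inputs z (see `pathwise_sample` below, where `v` is computed).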
function (s::PosteriorSample)(x::AbstractVector)
return s.prior_sample(x) + cov(s.approx_post, x, inducing_points(s.approx_post)) * s.v
end

@doc raw"""
pathwise_sample(rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx[, num_samples::Integer])

Efficiently samples a function from a sparse approximate posterior GP `f`.
Returns a function which can be evaluated at any input locations `X`.
`weight_space_approx` must be a function which takes a prior `AbstractGP` as an
argument and returns a `BayesianLinearRegressors.BasisFunctionRegressor`,
representing a weight space approximation to the prior of `f`. An example of
such a function can be constructed with
`RandomFourierFeatures.build_rff_weight_space_approx`.

If `num_samples` is supplied as an argument, returns a Vector of function
samples.

Details of the method can be found in [1].

[1] - Wilson, James, et al. "Efficiently sampling functions from Gaussian
process posteriors." International Conference on Machine Learning. PMLR, 2020.
"""
function pathwise_sample(rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx)
prior_approx = weight_space_approx(f.prior)
prior_sample = rand(rng, prior_approx)

z = inducing_points(f)
q_u = _get_q_u(f)

u = rand(rng, q_u)
v = cov(f, z) \ (u - prior_sample(z))

return PosteriorSample(f, prior_sample, v)
end
pathwise_sample(f::ApproxPosteriorGP, wsa) = pathwise_sample(Random.GLOBAL_RNG, f, wsa)

function pathwise_sample(
rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx, num_samples::Integer
)
prior_approx = weight_space_approx(f.prior)
prior_samples = rand(rng, prior_approx, num_samples)

z = inducing_points(f)
q_u = _get_q_u(f)

us = rand(rng, q_u, num_samples)

vs = cov(f, z) \ (us - reduce(hcat, map((s) -> s(z), prior_samples)))

posterior_samples = [
PosteriorSample(f, s, v) for (s, v) in zip(prior_samples, eachcol(vs))
]
return posterior_samples
end
function pathwise_sample(f::ApproxPosteriorGP, wsa, num_samples::Integer)
return pathwise_sample(Random.GLOBAL_RNG, f, wsa, num_samples)
end

end
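
For reference, this module implements the decoupled ("pathwise") sampling rule of Wilson et al. (2020): a weight-space (RFF) sample $\hat{f}$ from the prior is corrected so that it agrees with a draw $u \sim q(u)$ at the inducing inputs $z$,

```math
(f \mid u)(\cdot) \approx \hat{f}(\cdot) + k(\cdot, z)\, k(z, z)^{-1} \bigl(u - \hat{f}(z)\bigr).
```

The code above expresses the same update through *posterior* covariances; this is equivalent because, for these sparse posteriors, `cov(f, x, z) / cov(f, z)` reduces to `k(x, z) / k(z, z)` (the `S` factors cancel, leaving `K_xz * inv(K_zz)` in both cases).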
59 changes: 57 additions & 2 deletions src/SparseVariationalApproximationModule.jl
@@ -19,10 +19,14 @@ using AbstractGPs:
FiniteGP,
LatentFiniteGP,
ApproxPosteriorGP,
VFE,
posterior,
marginals,
At_A,
diag_At_A,
Xt_A_X,
Xt_invA_X,
inducing_points
using GPLikelihoods: GaussianLikelihood

export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
@@ -263,7 +267,9 @@ end
# Misc utility.
#

function AbstractGPs.inducing_points(f::ApproxPosteriorGP{<:SparseVariationalApproximation})
return f.approx.fz.x
end

#
# elbo
@@ -372,4 +378,53 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered})
return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2
end

# Methods to get the explicit variational distribution over inducing points q(u)
function _get_q_u(f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}})
# u = Lε + μ where LLᵀ = cov(fz) and μ = mean(fz)
# q(ε) = N(m, S)
# => q(u) = N(Lm + μ, LSLᵀ)
L, μ = chol_lower(_chol_cov(f.approx.fz)), mean(f.approx.fz)
m, S = mean(f.approx.q), _chol_cov(f.approx.q)
return MvNormal(L * m + μ, Xt_A_X(S, L'))
end
_get_q_u(f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}) = f.approx.q

function _get_q_u(f::ApproxPosteriorGP{<:AbstractGPs.VFE})
# q(u) = N(m, S)
# q(f_k) = N(μ_k, Σ_k) (the predictive distribution at test inputs k)
# μ_k = mean(k) + K_kz * K_zz⁻¹ * m
# where: K_kz = cov(f.prior, k, z)
# implemented as: μ_k = mean(k) + K_kz * α
# => m = K_zz * α
# Σ_k = K_kk - (K_kz * K_zz⁻¹ * K_zk) + (K_kz * K_zz⁻¹ * S * K_zz⁻¹ * K_zk)
# interested in the last term to get S
# implemented as: Aᵀ * Λ_ε⁻¹ * A
# where: A = U⁻ᵀ * K_zk
# UᵀU = K_zz
# so, Λ_ε⁻¹ = U⁻ᵀ * S * U⁻¹
# => S = Uᵀ * Λ_ε⁻¹ * U
# see https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/ eqns (8) & (9)
U = f.data.U
m = U'U * f.data.α
S = Xt_invA_X(f.data.Λ_ε, U)
return MvNormal(m, S)
end
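
# (Consistency check one could add: for a VFE posterior `post`, the predictive at
# the inducing inputs collapses to q(u), so `post(inducing_points(post))` should
# have mean ≈ m and cov ≈ S, up to jitter.)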

# Get the optimal closed-form solution for the centered variational posterior q(u)
function _optimal_variational_posterior(::Centered, fz, fx, y)
fz.f.mean isa AbstractGPs.ZeroMean ||
error("The exact posterior requires a GP with ZeroMean.")
post = posterior(VFE(fz), fx, y)
return _get_q_u(post)
end

# Get the optimal closed-form solution for the non-centered variational posterior q(ε)
function _optimal_variational_posterior(::NonCentered, fz, fx, y)
fz.f.mean isa AbstractGPs.ZeroMean ||
error("The exact posterior requires a GP with ZeroMean.")
q_u = _optimal_variational_posterior(Centered(), fz, fx, y)
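# Whiten: u = L * ε + mean(fz) with L = cholesky(cov(fz)).L, hence
# q(ε) = N(L⁻¹ (mean(q_u) - mean(fz)), L⁻¹ cov(q_u) L⁻ᵀ)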
Cuu = cholesky(Symmetric(cov(fz)))
return MvNormal(Cuu.L \ (mean(q_u) - mean(fz)), Symmetric((Cuu.L \ cov(q_u)) / Cuu.U))
end

end
62 changes: 62 additions & 0 deletions test/PathwiseSamplingModule.jl
@@ -0,0 +1,62 @@
@testset "pathwise_sampling" begin
rng = MersenneTwister(1453)

kernel = 0.1 * (SqExponentialKernel() ∘ ScaleTransform(0.2))
Σy = 1e-6
f = GP(kernel)

input_dims = 2
X = ColVecs(rand(rng, input_dims, 8))

fx = f(X, Σy)
y = rand(fx)

Z = X[1:4]
fz = f(Z)

X_test = ColVecs(rand(rng, input_dims, 3))

num_features = 10000
rff_wsa = build_rff_weight_space_approx(rng, input_dims, num_features)

num_samples = 1000

function test_single_sample_stats(ap, num_samples)
return test_stats(ap, [pathwise_sample(rng, ap, rff_wsa) for _ in 1:num_samples])
end

function test_multi_sample_stats(ap, num_samples)
return test_stats(ap, pathwise_sample(rng, ap, rff_wsa, num_samples))
end

function test_stats(ap, function_samples)
y_samples = reduce(hcat, map((f) -> f(X_test), function_samples))
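# Empirical moments across the sampled functions; these should converge to the
# analytic posterior moments as num_samples grows.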
m_empirical = mean(y_samples; dims=2)
Σ_empirical =
(y_samples .- m_empirical) * (y_samples .- m_empirical)' ./ num_samples

@test mean(ap(X_test)) ≈ vec(m_empirical) atol = 1e-3 rtol = 1e-3
@test cov(ap(X_test)) ≈ Σ_empirical atol = 1e-3 rtol = 1e-3
end

@testset "Centered SVA" begin
q = _optimal_variational_posterior(Centered(), fz, fx, y)
ap = posterior(SparseVariationalApproximation(Centered(), fz, q))

test_single_sample_stats(ap, num_samples)
test_multi_sample_stats(ap, num_samples)
end
@testset "NonCentered SVA" begin
q = _optimal_variational_posterior(NonCentered(), fz, fx, y)
ap = posterior(SparseVariationalApproximation(NonCentered(), fz, q))

test_single_sample_stats(ap, num_samples)
test_multi_sample_stats(ap, num_samples)
end
@testset "VFE" begin
ap = posterior(AbstractGPs.VFE(fz), fx, y)

test_single_sample_stats(ap, num_samples)
test_multi_sample_stats(ap, num_samples)
end
end
1 change: 1 addition & 0 deletions test/Project.toml
@@ -12,6 +12,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFourierFeatures = "46bd3b77-1c99-43cd-804b-1e14830792b5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

18 changes: 7 additions & 11 deletions test/SparseVariationalApproximationModule.jl
@@ -19,7 +19,7 @@
fz = f(z, 1e-6)

# Construct approximate posterior.
q_Centered = _optimal_variational_posterior(Centered(), fz, fx, y)
approx_Centered = SparseVariationalApproximation(Centered(), fz, q_Centered)
f_approx_post_Centered = posterior(approx_Centered)

@@ -32,17 +32,13 @@
end

@testset "NonCentered" begin

# Construct optimal approximate posterior.
q_ε = _optimal_variational_posterior(NonCentered(), fz, fx, y)

@testset "Check that q_ε has been properly constructed" begin
Cuu = cholesky(Symmetric(cov(fz)))
@test mean(q_Centered) ≈ mean(fz) + Cuu.L * mean(q_ε)
@test cov(q_Centered) ≈ Cuu.L * cov(q_ε) * Cuu.U
end

# Construct equivalent approximate posteriors.
@@ -79,7 +75,7 @@
f = GP(kernel)
fx = f(x, 0.1)
fz = f(z)
q_ex = _optimal_variational_posterior(Centered(), fz, fx, y)

sva = SparseVariationalApproximation(fz, q_ex)
@test elbo(sva, fx, y) isa Real
@@ -115,7 +111,7 @@
f = GP(kernel)
fx = f(x, lik_noise)
fz = f(z)
q_ex = _optimal_variational_posterior(Centered(), fz, fx, y)

gpr_post = posterior(fx, y) # Exact GP regression
vfe_post = posterior(VFE(fz), fx, y) # Titsias posterior