diff --git a/Project.toml b/Project.toml
index 4faf2449..77472fd3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/docs/src/api/pathwisesampling.md b/docs/src/api/pathwisesampling.md
new file mode 100644
index 00000000..93363656
--- /dev/null
+++ b/docs/src/api/pathwisesampling.md
@@ -0,0 +1,6 @@
+# Pathwise Posterior Sampling
+
+```@autodocs
+Modules = [ApproximateGPs.PathwiseSamplingModule]
+Private = false
+```
diff --git a/examples/d-pathwise-sampling/Project.toml b/examples/d-pathwise-sampling/Project.toml
new file mode 100644
index 00000000..97a47f1c
--- /dev/null
+++ b/examples/d-pathwise-sampling/Project.toml
@@ -0,0 +1,9 @@
+[deps]
+AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
+ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+RandomFourierFeatures = "46bd3b77-1c99-43cd-804b-1e14830792b5"
diff --git a/examples/d-pathwise-sampling/script.jl b/examples/d-pathwise-sampling/script.jl
new file mode 100644
index 00000000..bf050494
--- /dev/null
+++ b/examples/d-pathwise-sampling/script.jl
@@ -0,0 +1,62 @@
+# ## Setup
+
+using RandomFourierFeatures
+using ApproximateGPs
+using AbstractGPs
+using Distributions
+using LinearAlgebra
+using Random
+
+rng = MersenneTwister(1234)
+
+# Define a GP and generate some data
+
+k = 3 * (SqExponentialKernel() ∘ ScaleTransform(10))
+gp = GP(k)
+
+input_dims = 1
+
+x = ColVecs(rand(input_dims, 20))
+fx = gp(x, 0.01)
+y = rand(fx)
+
+z = x[1:8]
+fz = gp(z)
+
+# Any of the following constructions of `ap` will work:
+
+# q = ApproximateGPs._optimal_variational_posterior(NonCentered(), fz, fx, y)
+# ap = posterior(SparseVariationalApproximation(NonCentered(), fz, q))
+
+# q = ApproximateGPs._optimal_variational_posterior(Centered(), fz, fx, y)
+# ap = posterior(SparseVariationalApproximation(Centered(), fz, q))
+
+ap = posterior(VFE(fz), fx, y)
+
+x_test = ColVecs(sort(rand(input_dims, 500); dims=2))
+
+num_features = 1000
+rff_wsa = build_rff_weight_space_approx(rng, input_dims, num_features)
+
+num_samples = 100
+function_samples = ApproximateGPs.pathwise_sample(rng, ap, rff_wsa, num_samples)
+
+y_samples = reduce(hcat, map((f) -> f(x_test), function_samples))  # size: (length(x_test), num_samples)
+
+# Plot sampled functions against the exact posterior
+
+using Plots
+
+x_plot = x_test.X'
+
+plot(x_plot, y_samples; label="", color=:red, linealpha=0.2)
+plot!(vec(x_plot), ap; color=:blue, label="True posterior")
+scatter!(x.X', y; label="data")
+vline!(vec(z.X'); label="inducing points", color=:orange)
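+
+# Sanity check: the empirical mean of the pathwise samples should be close to
+# the approximate posterior mean (a rough check; agreement improves with more
+# samples and more random features):
+
+m_empirical = vec(mean(y_samples; dims=2))
+maximum(abs, m_empirical - mean(ap(x_test)))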
+include("PathwiseSamplingModule.jl") +@reexport using .PathwiseSamplingModule: pathwise_sample + include("deprecations.jl") end diff --git a/src/PathwiseSamplingModule.jl b/src/PathwiseSamplingModule.jl new file mode 100644 index 00000000..35562bf9 --- /dev/null +++ b/src/PathwiseSamplingModule.jl @@ -0,0 +1,80 @@ +module PathwiseSamplingModule + +export pathwise_sample + +using Random + +using ..ApproximateGPs: _chol_cov +using ..SparseVariationalApproximationModule: + SparseVariationalApproximation, Centered, NonCentered, _get_q_u + +using AbstractGPs: + ApproxPosteriorGP, VFE, inducing_points, Xt_invA_X, Xt_A_X, inducing_points +using PDMats: chol_lower +using Distributions + +struct PosteriorSample{Tapprox<:ApproxPosteriorGP,Tprior,Tv} + approx_post::Tapprox # The approximate posterior GP from which this sample is taken + prior_sample::Tprior # A function sampled from the prior of `approx_post` + v::Tv # The term needed to compute the pathwise update to the prior sample +end +function (s::PosteriorSample)(x::AbstractVector) + return s.prior_sample(x) + cov(s.approx_post, x, inducing_points(s.approx_post)) * s.v +end + +@doc raw""" + pathwise_sample(rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx[, num_samples::Integer]) + +Efficiently samples a function from a sparse approximate posterior GP `f`. +Returns a function which can be evaluated at any input locations `X`. +`weight_space_approx` must be a function which takes a prior `AbstractGP` as an +argument and returns a `BayesianLinearRegressors.BasisFunctionRegressor`, +representing a weight space approximation to the prior of `f`. An example of +such a function can be constructed with +`RandomFourierFeatures.build_rff_weight_space_approx`. + +If `num_samples` is supplied as an argument, returns a Vector of function +samples. + +Details of the method can be found in [1]. + +[1] - Wilson, James, et al. "Efficiently sampling functions from Gaussian +process posteriors." International Conference on Machine Learning. PMLR, 2020. 
+""" +function pathwise_sample(rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx) + prior_approx = weight_space_approx(f.prior) + prior_sample = rand(rng, prior_approx) + + z = inducing_points(f) + q_u = _get_q_u(f) + + u = rand(rng, q_u) + v = cov(f, z) \ (u - prior_sample(z)) + + return PosteriorSample(f, prior_sample, v) +end +pathwise_sample(f::ApproxPosteriorGP, wsa) = pathwise_sample(Random.GLOBAL_RNG, f, wsa) + +function pathwise_sample( + rng::Random.AbstractRNG, f::ApproxPosteriorGP, weight_space_approx, num_samples::Integer +) + prior_approx = weight_space_approx(f.prior) + prior_samples = rand(rng, prior_approx, num_samples) + + z = inducing_points(f) + q_u = _get_q_u(f) + + us = rand(rng, q_u, num_samples) + + vs = cov(f, z) \ (us - reduce(hcat, map((s) -> s(z), prior_samples))) + + posterior_samples = [ + PosteriorSample(f, s, v) for (s, v) in zip(prior_samples, eachcol(vs)) + ] + return posterior_samples +end +function pathwise_sample(f::ApproxPosteriorGP, wsa, num_samples::Integer) + return pathwise_sample(Random.GLOBAL_RNG, f, wsa, num_samples) +end + +end diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 8e965476..613211df 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -19,10 +19,14 @@ using AbstractGPs: FiniteGP, LatentFiniteGP, ApproxPosteriorGP, + VFE, posterior, marginals, At_A, - diag_At_A + diag_At_A, + Xt_A_X, + Xt_invA_X, + inducing_points using GPLikelihoods: GaussianLikelihood export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo @@ -263,7 +267,9 @@ end # Misc utility. # -inducing_points(f::ApproxPosteriorGP{<:SparseVariationalApproximation}) = f.approx.fz.x +function AbstractGPs.inducing_points(f::ApproxPosteriorGP{<:SparseVariationalApproximation}) + return f.approx.fz.x +end # # elbo @@ -372,4 +378,53 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered}) return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 end +# Methods to get the explicit variational distribution over inducing points q(u) +function _get_q_u(f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}) + # u = Lε + μ where LLᵀ = cov(fz) and μ = mean(fz) + # q(ε) = N(m, S) + # => q(u) = N(Lm + μ, LSLᵀ) + L, μ = chol_lower(_chol_cov(f.approx.fz)), mean(f.approx.fz) + m, S = mean(f.approx.q), _chol_cov(f.approx.q) + return MvNormal(L * m + μ, Xt_A_X(S, L')) +end +_get_q_u(f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}) = f.approx.q + +function _get_q_u(f::ApproxPosteriorGP{<:AbstractGPs.VFE}) + # q(u) = N(m, S) + # q(f_k) = N(μ_k, Σ_k) (the predictive distribution at test inputs k) + # μ_k = mean(k) + K_kz * K_zz⁻¹ * m + # where: K_kz = cov(f.prior, k, z) + # implemented as: μ_k = mean(k) + K_kz * α + # => m = K_zz * α + # Σ_k = K_kk - (K_kz * K_zz⁻¹ * K_zk) + (K_kz * K_zz⁻¹ * S * K_zz⁻¹ * K_zk) + # interested in the last term to get S + # implemented as: Aᵀ * Λ_ε⁻¹ * A + # where: A = U⁻ᵀ * K_zk + # UᵀU = K_zz + # so, Λ_ε⁻¹ = U⁻ᵀ * S * U + # => S = Uᵀ * Λ_ε⁻¹ * U + # see https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/ eqns (8) & (9) + U = f.data.U + m = U'U * f.data.α + S = Xt_invA_X(f.data.Λ_ε, U) + return MvNormal(m, S) +end + +# Get the optimal closed form solution for the centered variational posterior q(u) +function _optimal_variational_posterior(::Centered, fz, fx, y) + fz.f.mean isa AbstractGPs.ZeroMean || + error("The exact posterior requires a GP with ZeroMean.") + post = 
+# Get the optimal closed-form solution for the non-centered variational posterior q(ε)
+function _optimal_variational_posterior(::NonCentered, fz, fx, y)
+    fz.f.mean isa AbstractGPs.ZeroMean ||
+        error("The exact posterior requires a GP with ZeroMean.")
+    q_u = _optimal_variational_posterior(Centered(), fz, fx, y)
+    Cuu = cholesky(Symmetric(cov(fz)))
+    return MvNormal(Cuu.L \ (mean(q_u) - mean(fz)), Symmetric((Cuu.L \ cov(q_u)) / Cuu.U))
+end
+
 
 end
diff --git a/test/PathwiseSamplingModule.jl b/test/PathwiseSamplingModule.jl
new file mode 100644
index 00000000..c751efb9
--- /dev/null
+++ b/test/PathwiseSamplingModule.jl
@@ -0,0 +1,62 @@
+@testset "pathwise_sampling" begin
+    rng = MersenneTwister(1453)
+
+    kernel = 0.1 * (SqExponentialKernel() ∘ ScaleTransform(0.2))
+    Σy = 1e-6
+    f = GP(kernel)
+
+    input_dims = 2
+    X = ColVecs(rand(rng, input_dims, 8))
+
+    fx = f(X, Σy)
+    y = rand(fx)
+
+    Z = X[1:4]
+    fz = f(Z)
+
+    X_test = ColVecs(rand(rng, input_dims, 3))
+
+    num_features = 10000
+    rff_wsa = build_rff_weight_space_approx(rng, input_dims, num_features)
+
+    num_samples = 1000
+
+    function test_single_sample_stats(ap, num_samples)
+        return test_stats(ap, [pathwise_sample(rng, ap, rff_wsa) for _ in 1:num_samples])
+    end
+
+    function test_multi_sample_stats(ap, num_samples)
+        return test_stats(ap, pathwise_sample(rng, ap, rff_wsa, num_samples))
+    end
+
+    function test_stats(ap, function_samples)
+        y_samples = reduce(hcat, map((f) -> f(X_test), function_samples))
+        m_empirical = vec(mean(y_samples; dims=2))
+        Σ_empirical =
+            (y_samples .- m_empirical) * (y_samples .- m_empirical)' ./ num_samples
+
+        @test mean(ap(X_test)) ≈ m_empirical atol = 1e-3 rtol = 1e-3
+        @test cov(ap(X_test)) ≈ Σ_empirical atol = 1e-3 rtol = 1e-3
+    end
+
+    @testset "Centered SVA" begin
+        q = _optimal_variational_posterior(Centered(), fz, fx, y)
+        ap = posterior(SparseVariationalApproximation(Centered(), fz, q))
+
+        test_single_sample_stats(ap, num_samples)
+        test_multi_sample_stats(ap, num_samples)
+    end
+    @testset "NonCentered SVA" begin
+        q = _optimal_variational_posterior(NonCentered(), fz, fx, y)
+        ap = posterior(SparseVariationalApproximation(NonCentered(), fz, q))
+
+        test_single_sample_stats(ap, num_samples)
+        test_multi_sample_stats(ap, num_samples)
+    end
+    @testset "VFE" begin
+        ap = posterior(AbstractGPs.VFE(fz), fx, y)
+
+        test_single_sample_stats(ap, num_samples)
+        test_multi_sample_stats(ap, num_samples)
+    end
+end
diff --git a/test/Project.toml b/test/Project.toml
index 39e43130..4db4ed17 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -12,6 +12,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+RandomFourierFeatures = "46bd3b77-1c99-43cd-804b-1e14830792b5"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/SparseVariationalApproximationModule.jl b/test/SparseVariationalApproximationModule.jl
index 104949f2..f890dcf3 100644
--- a/test/SparseVariationalApproximationModule.jl
+++ b/test/SparseVariationalApproximationModule.jl
@@ -19,7 +19,7 @@
    fz = f(z, 1e-6)

    # Construct approximate posterior.
-    q_Centered = optimal_variational_posterior(fz, fx, y)
+    q_Centered = _optimal_variational_posterior(Centered(), fz, fx, y)
     approx_Centered = SparseVariationalApproximation(Centered(), fz, q_Centered)
 
     f_approx_post_Centered = posterior(approx_Centered)
@@ -32,17 +32,13 @@
     end
 
     @testset "NonCentered" begin
-        # Construct optimal approximate posterior.
-        q = optimal_variational_posterior(fz, fx, y)
-        Cuu = cholesky(Symmetric(cov(fz)))
-        q_ε = MvNormal(
-            Cuu.L \ (mean(q) - mean(fz)), Symmetric((Cuu.L \ cov(q)) / Cuu.U)
-        )
+        q_ε = _optimal_variational_posterior(NonCentered(), fz, fx, y)
 
         @testset "Check that q_ε has been properly constructed" begin
-            @test mean(q) ≈ mean(fz) + Cuu.L * mean(q_ε)
-            @test cov(q) ≈ Cuu.L * cov(q_ε) * Cuu.U
+            Cuu = cholesky(Symmetric(cov(fz)))
+            @test mean(q_Centered) ≈ mean(fz) + Cuu.L * mean(q_ε)
+            @test cov(q_Centered) ≈ Cuu.L * cov(q_ε) * Cuu.U
         end
 
         # Construct equivalent approximate posteriors.
@@ -79,7 +75,7 @@
         f = GP(kernel)
         fx = f(x, 0.1)
         fz = f(z)
-        q_ex = optimal_variational_posterior(fz, fx, y)
+        q_ex = _optimal_variational_posterior(Centered(), fz, fx, y)
         sva = SparseVariationalApproximation(fz, q_ex)
 
         @test elbo(sva, fx, y) isa Real
@@ -115,7 +111,7 @@
         f = GP(kernel)
         fx = f(x, lik_noise)
         fz = f(z)
-        q_ex = optimal_variational_posterior(fz, fx, y)
+        q_ex = _optimal_variational_posterior(Centered(), fz, fx, y)
 
         gpr_post = posterior(fx, y)  # Exact GP regression
         vfe_post = posterior(VFE(fz), fx, y)  # Titsias posterior
diff --git a/test/runtests.jl b/test/runtests.jl
index fb7f5ae8..ecdb7ce1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,6 @@
 using Random
 using Test
-using ApproximateGPs
+using RandomFourierFeatures
 using Flux
 using IterTools
 using AbstractGPs
@@ -14,7 +14,9 @@ using Zygote
 using ChainRulesCore
 using ChainRulesTestUtils
 using FiniteDifferences
-using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
+using ApproximateGPs
+using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
+using ApproximateGPs.SparseVariationalApproximationModule: _optimal_variational_posterior
 
 # Writing tests:
 # 1. The file structure of the test should match precisely the file structure of src.
@@ -56,7 +60,11 @@ include("test_utils.jl")
 
     include("SparseVariationalApproximationModule.jl")
     println(" ")
-    @info "Ran sva tests"
+    @info "Ran sparse variational tests"
+
+    include("PathwiseSamplingModule.jl")
+    println(" ")
+    @info "Ran pathwise sampling tests"
 
     include("LaplaceApproximationModule.jl")
     println(" ")
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 58bd666c..dae04a51 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,17 +1,2 @@
 # Create a default kernel from two parameters k[1] and k[2]
 make_kernel(k) = softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
-
-# Computes the optimal closed form solution for the variational posterior
-# q(u) (e.g. https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/
-# equations (11) & (12)). Assumes a ZeroMean function.
-function optimal_variational_posterior(fu, fx, y)
-    fu.f.mean isa AbstractGPs.ZeroMean ||
-        error("The exact posterior requires a GP with ZeroMean.")
-    σ² = fx.Σy[1]
-    Kuf = cov(fu, fx)
-    Kuu = Symmetric(cov(fu))
-    Σ = (Symmetric(cov(fu) + (1 / σ²) * Kuf * Kuf'))
-    m = ((1 / σ²) * Kuu * (Σ \ Kuf)) * y
-    S = Symmetric(Kuu * (Σ \ Kuu))
-    return MvNormal(m, S)
-end