Commit
Showing 21 changed files with 551 additions and 78 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ jobs: | |
matrix: | ||
version: | ||
- '1' | ||
- '1.3' | ||
- '1.6' | ||
- 'nightly' | ||
os: | ||
- ubuntu-latest | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,28 @@ | ||
name = "GPLikelihoods" | ||
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40" | ||
authors = ["JuliaGaussianProcesses Team"] | ||
version = "0.3.2" | ||
version = "0.4.3" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" | ||
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[compat] | ||
ChainRulesCore = "1.7" | ||
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" | ||
FastGaussQuadrature = "0.4" | ||
Functors = "0.1, 0.2" | ||
InverseFunctions = "0.1.2" | ||
StatsFuns = "0.9.13" | ||
julia = "1.3" | ||
IrrationalConstants = "0.1" | ||
SpecialFunctions = "1, 2" | ||
StatsFuns = "0.9.13, 1" | ||
julia = "1.6" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
[deps] | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
|
||
[compat] | ||
Distributions = "0.25" | ||
Documenter = "0.25, 0.26, 0.27" | ||
GPLikelihoods = "0.4" | ||
StatsFuns = "0.9, 1" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
using Documenter, GPLikelihoods | ||
|
||
makedocs(; | ||
modules=[GPLikelihoods], | ||
sitename="GPLikelihoods.jl", | ||
format=Documenter.HTML(), | ||
modules=[GPLikelihoods], | ||
pages=["Home" => "index.md", "API" => "api.md"], | ||
repo="https://github.com/JuliaGaussianProcesses/GPLikelihoods.jl/blob/{commit}{path}#L{line}", | ||
sitename="GPLikelihoods.jl", | ||
authors="JuliaGaussianProcesses organization", | ||
assets=String[], | ||
strict=true, | ||
checkdocs=:exports, | ||
#doctestfilters=JuliaGPsDocs.DOCTEST_FILTERS, | ||
) | ||
|
||
deploydocs(; repo="github.com/JuliaGaussianProcesses/GPLikelihoods.jl", push_preview=true) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,68 @@ | ||
```@meta | ||
CurrentModule = GPLikelihoods | ||
``` | ||
```@setup test-repl | ||
using Distributions | ||
using GPLikelihoods | ||
using StatsFuns | ||
``` | ||
|
||
# GPLikelihoods | ||
|
||
[`GPLikelihoods.jl`](https://github.com/JuliaGaussianProcesses/GPLikelihoods.jl) provides a practical interface to connect Gaussian and non-conjugate likelihoods | ||
to Gaussian Processes. | ||
The API is very basic: Every `AbstractLikelihood` object is a [functor](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects-1) | ||
taking a `Real` or an `AbstractVector` as an input and returns a | ||
taking a `Real` or an `AbstractVector` as an input and returning a | ||
`Distribution` from [`Distributions.jl`](https://github.com/JuliaStats/Distributions.jl). | ||
|
||
### Single-latent vs multi-latent likelihoods | ||
|
||
Most likelihoods, like the [`GaussianLikelihood`](@ref), only require one latent Gaussian process. | ||
Passing a `Real` will therefore return a [`UnivariateDistribution`](https://juliastats.org/Distributions.jl/latest/univariate/), | ||
and passing an `AbstractVector{<:Real}` will return a [multivariate product of distributions](https://juliastats.org/Distributions.jl/latest/multivariate/#Product-distributions). | ||
```@repl | ||
```@repl test-repl | ||
f = 2.0; | ||
GaussianLikelihood()(f) == Normal(2.0) | ||
fs = [2.0, 3.0, 1.5] | ||
GaussianLikelihood()(fs) == Product([Normal(2.0), Normal(3.0), Normal(1.5)]) | ||
GaussianLikelihood()(f) == Normal(2.0, 1e-3) | ||
fs = [2.0, 3.0, 1.5]; | ||
GaussianLikelihood()(fs) isa AbstractMvNormal | ||
``` | ||
|
||
Some likelihoods, like the [`CategoricalLikelihood`](@ref), requires multiple latent Gaussian processes, | ||
Some likelihoods, like the [`CategoricalLikelihood`](@ref), require multiple latent Gaussian processes, | ||
and an `AbstractVector{<:Real}` needs to be passed. | ||
To obtain a product of distributions an `AbstractVector{<:AbstractVector{<:Real}}` has to be passed (we recommend | ||
using [`ColVecs` and `RowVecs` from KernelFunctions.jl](https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/api/#Vector-Valued-Inputs) | ||
if you need to transform an `AbstractMatrix`). | ||
```@repl | ||
```@repl test-repl | ||
fs = [2.0, 3.0, 4.5]; | ||
CategoricalLikelihood()(fs) isa Categorical | ||
Fs = [rand(3) for _ in 1:4] | ||
Fs = [rand(3) for _ in 1:4]; | ||
CategoricalLikelihood()(Fs) isa Product{<:Any,<:Categorical} | ||
``` | ||
|
||
### Constrained parameters | ||
|
||
The domain of some distributions parameters can be different from | ||
``\mathbb{R}``, the real domain. | ||
To solve this problem, we also provide the [`Link`](@ref) type, which can be | ||
passed to the [`Likelihood`](@ref) constructors. | ||
Alternatively, `function`s can also directly be passed and will be wrapped in a `Link`). | ||
The function values `f` of the latent Gaussian process live in the real domain | ||
``\mathbb{R}``. For some likelihoods, the domain of the distribution parameter | ||
`p` that is modulated by the latent Gaussian process is constrained to some | ||
subset of ``\mathbb{R}``, e.g. only positive values or values in an interval. | ||
|
||
To connect these two domains, a transformation from `f` to `p` is required. | ||
For this, we provide the [`Link`](@ref) type, which can be passed to the | ||
likelihood constructors. (Alternatively, `function`s can also directly be | ||
passed and will be wrapped in a `Link`.) | ||
|
||
We typically call this passed transformation the `invlink`. This comes from | ||
the statistics literature, where the "link" is defined as `f = link(p)`, | ||
whereas here we need `p = invlink(f)`. | ||
|
||
For more details about which likelihoods require a [`Link`](@ref) check out their docs. | ||
We typically named this passed link as the `invlink`. | ||
This comes from the statistic literature, where the "link" is defined as `f = link(y)`. | ||
|
||
A classical example is the [`BernoulliLikelihood`](@ref) for classification, with the probability parameter ``p \in \[0, 1\]``. | ||
The default it to use a [`logistic`](https://en.wikipedia.org/wiki/Logistic_function) transformation, but one could also use the inverse of the [`probit`](https://en.wikipedia.org/wiki/Probit) link: | ||
A classical example is the [`BernoulliLikelihood`](@ref) for classification, with the probability parameter ``p \in [0, 1]``. | ||
The default is to use a [`logistic`](https://en.wikipedia.org/wiki/Logistic_function) transformation, but one could also use the inverse of the [`probit`](https://en.wikipedia.org/wiki/Probit) link: | ||
|
||
```@repl | ||
```@repl test-repl | ||
f = 2.0; | ||
BernoulliLikelihood()(f) == Bernoulli(logistic(f)) | ||
BernoulliLikelihood(NormalCDFLink()) == Bernoulli(normalcdf(f)) | ||
BernoulliLikelihood(NormalCDFLink())(f) == Bernoulli(normcdf(f)) | ||
``` | ||
Note that we passed the `inverse` of the `probit` function which is the `normalcdf` function. | ||
Note that we passed the _inverse_ of the `probit` function which is the `normcdf` function. |
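As a short worked illustration of the `link` / `invlink` relationship described in the documentation change above (standard definitions from the statistics literature, not taken from the package source), the two transformations mentioned for the Bernoulli case can be written as follows, where ``\Phi`` denotes the standard normal CDF, i.e. what `normcdf` computes:

```math
\begin{aligned}
\text{logit:}\quad  & f = \operatorname{link}(p) = \log\frac{p}{1 - p}, & p = \operatorname{invlink}(f) &= \frac{1}{1 + e^{-f}} \\
\text{probit:}\quad & f = \operatorname{link}(p) = \Phi^{-1}(p),        & p = \operatorname{invlink}(f) &= \Phi(f)
\end{aligned}
```

In both cases the `invlink` maps an unconstrained ``f \in \mathbb{R}`` to a probability ``p \in (0, 1)``, which is why it is the inverse that gets passed to the likelihood constructor.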
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
using FastGaussQuadrature: gausshermite | ||
using SpecialFunctions: loggamma | ||
using ChainRulesCore: ChainRulesCore | ||
using IrrationalConstants: sqrt2, invsqrtπ | ||
|
||
struct DefaultExpectationMethod end | ||
|
||
struct AnalyticExpectation end | ||
|
||
struct GaussHermiteExpectation | ||
xs::Vector{Float64} | ||
ws::Vector{Float64} | ||
end | ||
GaussHermiteExpectation(n::Integer) = GaussHermiteExpectation(gausshermite(n)...) | ||
|
||
ChainRulesCore.@non_differentiable gausshermite(n) | ||
|
||
struct MonteCarloExpectation | ||
n_samples::Int | ||
end | ||
|
||
default_expectation_method(_) = GaussHermiteExpectation(20) | ||
|
||
""" | ||
expected_loglikelihood( | ||
quadrature, | ||
lik, | ||
q_f::AbstractVector{<:Normal}, | ||
y::AbstractVector, | ||
) | ||
This function computes the expected log likelihood: | ||
```math | ||
∫ q(f) log p(y | f) df | ||
``` | ||
where `p(y | f)` is the process likelihood. This is described by `lik`, which should be a | ||
callable that takes `f` as input and returns a Distribution over `y` that supports | ||
`loglikelihood(lik(f), y)`. | ||
`q(f)` is an approximation to the posterior over the latent function values `f`, given by: | ||
```math | ||
q(f) = ∫ p(f | u) q(u) du | ||
``` | ||
where `q(u)` is the variational distribution over inducing points. | ||
The marginal distributions of `q(f)` are given by `q_f`. | ||
`quadrature` determines which method is used to calculate the expected log | ||
likelihood. | ||
# Extended help | ||
`q(f)` is assumed to be an `MvNormal` distribution and `p(y | f)` is assumed to | ||
have independent marginals such that only the marginals of `q(f)` are required. | ||
""" | ||
expected_loglikelihood(quadrature, lik, q_f, y) | ||
|
||
""" | ||
expected_loglikelihood(::DefaultExpectationMethod, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector) | ||
The expected log likelihood, using the default quadrature method for the given likelihood. | ||
(The default quadrature method is defined by `default_expectation_method(lik)`, and should | ||
be the closed form solution if it exists, but otherwise defaults to Gauss-Hermite | ||
quadrature.) | ||
""" | ||
function expected_loglikelihood( | ||
::DefaultExpectationMethod, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector | ||
) | ||
quadrature = default_expectation_method(lik) | ||
return expected_loglikelihood(quadrature, lik, q_f, y) | ||
end | ||
|
||
function expected_loglikelihood( | ||
mc::MonteCarloExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector | ||
) | ||
# take `n_samples` reparameterised samples | ||
f_μ = mean.(q_f) | ||
fs = f_μ .+ std.(q_f) .* randn(eltype(f_μ), length(q_f), mc.n_samples) | ||
lls = loglikelihood.(lik.(fs), y) | ||
return sum(lls) / mc.n_samples | ||
end | ||
|
||
# Compute the expected_loglikelihood over a collection of observations and marginal distributions | ||
function expected_loglikelihood( | ||
gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector | ||
) | ||
# Compute the expectation via Gauss-Hermite quadrature | ||
# using a reparameterisation by change of variable | ||
# (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) | ||
return sum(Broadcast.instantiate( | ||
Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ # Loop over every pair | ||
# of marginal distribution q(fᵢ) and observation yᵢ | ||
expected_loglikelihood(gh, lik, q_fᵢ, yᵢ) | ||
end, | ||
)) | ||
end | ||
|
||
# Compute the expected_loglikelihood for a single observation and its marginal distribution | ||
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y) | ||
μ = mean(q_f) | ||
σ̃ = sqrt2 * std(q_f) | ||
return invsqrtπ * sum(Broadcast.instantiate( | ||
Broadcast.broadcasted(gh.xs, gh.ws) do x, w # Loop over every | ||
# pair of Gauss-Hermite point x with weight w | ||
f = σ̃ * x + μ | ||
loglikelihood(lik(f), y) * w | ||
end, | ||
)) | ||
end | ||
|
||
function expected_loglikelihood( | ||
::AnalyticExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector | ||
) | ||
return error( | ||
"No analytic solution exists for $(typeof(lik)). Use `DefaultExpectationMethod`, `GaussHermiteExpectation` or `MonteCarloExpectation` instead.", | ||
) | ||
end |
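The Gauss-Hermite method in the file above relies on the change of variable ``f = \sqrt{2}\,σ x + μ``, which turns an expectation under ``\mathcal{N}(μ, σ^2)`` into ``\frac{1}{\sqrt{π}} \sum_i w_i\, g(\sqrt{2}\,σ x_i + μ)``. The following is a minimal standalone sketch of that idea, not the package's implementation: it swaps in a hand-written Gaussian log density for `loglikelihood(lik(f), y)`, and the names `gauss_loglik`, `gh_estimate`, `mc_estimate` and `closed_form` are hypothetical helpers introduced only for illustration. It compares the quadrature and a reparameterised Monte Carlo estimate against the closed form that exists for this particular likelihood.

```julia
using FastGaussQuadrature: gausshermite
using Random: MersenneTwister

# Log density of Normal(f, s^2) evaluated at y.
# Hypothetical stand-in for `loglikelihood(lik(f), y)` in the diff above.
gauss_loglik(y, f, s) = -0.5 * log(2π * s^2) - (y - f)^2 / (2 * s^2)

# Gauss-Hermite estimate of E[log p(y | f)] for f ~ Normal(μ, σ^2),
# using the change of variable f = sqrt(2) * σ * x + μ.
function gh_estimate(y, μ, σ, s; n=20)
    xs, ws = gausshermite(n)
    return sum(ws .* gauss_loglik.(y, sqrt(2) * σ .* xs .+ μ, s)) / sqrt(π)
end

# Monte Carlo estimate with reparameterised samples f = μ + σ * ε, ε ~ Normal(0, 1).
function mc_estimate(rng, y, μ, σ, s; n_samples=100_000)
    fs = μ .+ σ .* randn(rng, n_samples)
    return sum(gauss_loglik.(y, fs, s)) / n_samples
end

# Closed form, available here because the log density is quadratic in f.
closed_form(y, μ, σ, s) = -0.5 * log(2π * s^2) - ((y - μ)^2 + σ^2) / (2 * s^2)

y, μ, σ, s = 1.0, 0.5, 2.0, 0.7
exact = closed_form(y, μ, σ, s)
gh = gh_estimate(y, μ, σ, s)
mc = mc_estimate(MersenneTwister(1), y, μ, σ, s)
```

Because the integrand is a quadratic polynomial in `f`, the quadrature agrees with the closed form up to floating-point error, while the Monte Carlo estimate only matches up to sampling noise. This is one reason a closed-form or quadrature method is a sensible default when the likelihood factorises over observations.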