Skip to content

Commit

Permalink
Merge branch 'master' into tgf/softmax_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf authored Feb 14, 2023
2 parents 749c400 + 4b33f7a commit d09552c
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 49 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GPLikelihoods"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
authors = ["JuliaGaussianProcesses Team"]
version = "0.4.3"
version = "0.4.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -20,7 +20,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "1.7"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FastGaussQuadrature = "0.4"
Functors = "0.1, 0.2"
Functors = "0.1, 0.2, 0.3"
InverseFunctions = "0.1.2"
IrrationalConstants = "0.1"
SpecialFunctions = "1, 2"
Expand Down
26 changes: 7 additions & 19 deletions src/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,13 @@ function expected_loglikelihood(
# Compute the expectation via Gauss-Hermite quadrature
# using a reparameterisation by change of variable
# (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)
return sum(Broadcast.instantiate(
Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ # Loop over every pair
# of marginal distribution q(fᵢ) and observation yᵢ
expected_loglikelihood(gh, lik, q_fᵢ, yᵢ)
end,
))
end

# Compute the expected_loglikelihood for one observation and a marginal distributions
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)
μ = mean(q_f)
σ̃ = sqrt2 * std(q_f)
return invsqrtπ * sum(Broadcast.instantiate(
Broadcast.broadcasted(gh.xs, gh.ws) do x, w # Loop over every
# pair of Gauss-Hermite point x with weight w
f = σ̃ * x + μ
loglikelihood(lik(f), y) * w
end,
))
# PR #90 introduces eager instead of lazy broadcast over observations
# and Gauss-Hermit points and weights in order to make the function
# type stable. Compared to other type stable implementations, e.g.
# using a custom two-argument pairwise sum, this is faster to
# differentiate using Zygote.
A = loglikelihood.(lik.(sqrt2 .* std.(q_f) .* gh.xs' .+ mean.(q_f)), y) .* gh.ws'
return invsqrtπ * sum(A)
end

function expected_loglikelihood(
Expand Down
8 changes: 5 additions & 3 deletions src/likelihoods/negativebinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,11 @@ struct NBParamII{T} <: NBParamMean
end

function (l::NegativeBinomialLikelihood{<:NBParamII})(f::Real)
μ = l.invlink(f)
ev = l.params.α * μ
return NegativeBinomial(_nb_mean_excessvar_to_r_p(μ, ev)...)
# Simplify parameter conversions and avoid splatting
α = l.params.α
r = inv(α)
p = inv(one(α) + α * l.invlink(f))
return NegativeBinomial(r, p)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Functors = "0.1, 0.2"
Functors = "0.1, 0.2, 0.3"
StatsFuns = "0.9, 1"
Zygote = "0.6"
julia = "1.3"
71 changes: 47 additions & 24 deletions test/expectations.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
@testset "expectations" begin
# Test that the various methods of computing expectations return the same
# result.
rng = MersenneTwister(123456)
q_f = Normal.(zeros(10), ones(10))

likelihoods_to_test = [
BernoulliLikelihood(),
ExponentialLikelihood(),
GammaLikelihood(),
PoissonLikelihood(),
GaussianLikelihood(),
NegativeBinomialLikelihood(NBParamSuccess(1.0)),
NegativeBinomialLikelihood(NBParamFailure(1.0)),
NegativeBinomialLikelihood(NBParamI(1.0)),
NegativeBinomialLikelihood(NBParamII(1.0)),
PoissonLikelihood(),
]

@testset "testing all analytic implementations" begin
Expand All @@ -30,30 +33,50 @@
end
end

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))
@testset "testing consistency of different expectation methods" begin
@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
# Test that the various methods of computing expectations return the same
# result.
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))

results = map(m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
results = map(
m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods
)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end
end

@test GPLikelihoods.expected_loglikelihood(
MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.expected_loglikelihood(
GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.default_expectation_method-> Normal(0, θ)) isa
GaussHermiteExpectation
@testset "testing return types and type stability" begin
@test GPLikelihoods.expected_loglikelihood(
MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.expected_loglikelihood(
GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.default_expectation_method-> Normal(0, θ)) isa
GaussHermiteExpectation

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
# Test that `expectec_loglikelihood` is type-stable
y = rand.(rng, lik.(zeros(10)))
for method in [
MonteCarloExpectation(100),
GaussHermiteExpectation(100),
GPLikelihoods.DefaultExpectationMethod(),
]
@test (@inferred expected_loglikelihood(method, lik, q_f, y)) isa Real
end
end
end

# see https://github.com/JuliaGaussianProcesses/ApproximateGPs.jl/issues/82
@testset "testing Zygote compatibility with GaussHermiteExpectation" begin
Expand Down

0 comments on commit d09552c

Please sign in to comment.