Skip to content

Commit

Permalink
Resolves #14 (#16)
Browse files Browse the repository at this point in the history
* Fix bug

* Adds bounds for ChainRulesCore

* Exports ColVecs and RowVecs

* Update bounds

* Export ColVecs / RowVecs

* Bump patch version
  • Loading branch information
willtebbutt authored Jul 20, 2020
1 parent 022067b commit 16b84cc
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 19 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["willtebbutt <[email protected]>"]
version = "0.2.2"
version = "0.2.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Expand All @@ -14,6 +15,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23"
FillArrays = "0.7, 0.8"
KernelFunctions = "0.4"
Expand Down
5 changes: 4 additions & 1 deletion src/AbstractGPs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module AbstractGPs

using ChainRulesCore
using Distributions
using FillArrays
using LinearAlgebra
Expand All @@ -8,9 +9,11 @@ module AbstractGPs
using Random
using Statistics

using KernelFunctions: ColVecs, RowVecs

export GP, mean, cov, std, cov_diag, mean_and_cov, marginals, rand,
logpdf, elbo, dtc, posterior, approx_posterior, VFE, DTC, AbstractGP, sampleplot,
update_approx_posterior, LatentGP
update_approx_posterior, LatentGP, ColVecs, RowVecs

# Various bits of utility functionality.
include(joinpath("util", "common_covmat_ops.jl"))
Expand Down
7 changes: 6 additions & 1 deletion src/gp/mean_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ Returns `zero(T)` everywhere.
"""
struct ZeroMean{T<:Real} <: MeanFunction end

Base.map(::ZeroMean{T}, x::AbstractVector) where T = zeros(T, length(x))
Base.map(::ZeroMean{T}, x::AbstractVector) where {T} = zeros(T, length(x))

function ChainRulesCore.rrule(::typeof(Base.map), m::ZeroMean, x::AbstractVector)
map_ZeroMean_pullback(Δ) = (NO_FIELDS, NO_FIELDS, Zero())
return map(m, x), map_ZeroMean_pullback
end

ZeroMean() = ZeroMean{Float64}()

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand All @@ -11,10 +12,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "0.9"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23"
Documenter = "0.24, 0.25"
FiniteDifferences = "0.9.6, 0.10"
KernelFunctions = "0.4"
Plots = "1"
Zygote = "0.4.6, 0.5"
Zygote = "0.5"
julia = "1.3"
35 changes: 24 additions & 11 deletions test/gp/mean_functions.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
@testset "mean_functions" begin
@testset "CustomMean" begin
rng, N, D = MersenneTwister(123456), 11, 2
x = randn(rng, N)
foo_mean = x->sum(abs2, x)
f = CustomMean(foo_mean)

@test map(f, x) == map(foo_mean, x)
# differentiable_mean_function_tests(f, randn(rng, N), x)
end
@testset "ZeroMean" begin
rng, P, Q, D = MersenneTwister(123456), 3, 2, 4
P = 3
Q = 2
D = 4
# X = ColVecs(randn(rng, D, P))
x = randn(rng, P)
x = randn(P)
= randn(P)
f = ZeroMean{Float64}()

for x in [x]
@test map(f, x) == zeros(size(x))
# differentiable_mean_function_tests(f, randn(rng, P), x)
end

# Manually verify the ChainRule. Really, this should employ FiniteDifferences, but
# currently ChainRulesTestUtils isn't up to handling this, so this will have to do
# for now.
y, pb = rrule(map, f, x)
@test y == map(f, x)
Δmap, Δf, Δx = pb(randn(P))
@test iszero(Δmap)
@test iszero(Δf)
@test iszero(Δx)
end
@testset "ConstMean" begin
rng, D, N = MersenneTwister(123456), 5, 3
Expand All @@ -31,4 +35,13 @@
# differentiable_mean_function_tests(m, randn(rng, N), x)
end
end
@testset "CustomMean" begin
rng, N, D = MersenneTwister(123456), 11, 2
x = randn(rng, N)
foo_mean = x->sum(abs2, x)
f = CustomMean(foo_mean)

@test map(f, x) == map(foo_mean, x)
# differentiable_mean_function_tests(f, randn(rng, N), x)
end
end
10 changes: 6 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ using AbstractGPs: AbstractGP, MeanFunction, FiniteGP, ConstMean, GP, ZeroMean,
ConstMean, CustomMean, Xt_A_X, Xt_A_Y, Xt_invA_Y, Xt_invA_X, diag_At_A, diag_At_B,
diag_Xt_A_X, diag_Xt_A_Y, diag_Xt_invA_X, diag_Xt_invA_Y, Xtinv_A_Xinv, tr_At_A,
mean_and_cov_diag

using Documenter
using ChainRulesCore
using Distributions: MvNormal, PDMat
using FiniteDifferences
using FiniteDifferences: j′vp, to_vec
using KernelFunctions
using KernelFunctions: Kernel, ColVecs, RowVecs
using LinearAlgebra
using LinearAlgebra: AbstractTriangular
using Random
using Plots
using Test
using FiniteDifferences
using FiniteDifferences: j′vp, to_vec
using Random
using Statistics
using Test
using Zygote


Expand Down

2 comments on commit 16b84cc

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/18197

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.3 -m "<description of version>" 16b84cc52a2457f745839f9c1cae3cc1002106a0
git push origin v0.2.3

Please sign in to comment.