Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension of #269: Use \circ and compose and deprecate transform #276

Merged
merged 20 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.9.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
CompositionsBase = "0.1"
Distances = "0.10"
Functors = "0.1"
Requires = "1.0.1"
Expand Down
6 changes: 3 additions & 3 deletions docs/create_kernel_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ n_grid = 101
fill(x₀, n_grid, 1)
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)

k = transform(SqExponentialKernel(), 1.0)
k = SqExponentialKernel() ∘ ScaleTransform(1.0)
K1 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K1;
Expand All @@ -35,7 +35,7 @@ p = heatmap(
)
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))

k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
k = PolynomialKernel(; c=0.0, d=2.0)LinearTransform(randn(3, 1))
K3 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K3;
Expand All @@ -47,7 +47,7 @@ p = heatmap(
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))

k =
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
0.5 * SqExponentialKernel() * (LinearKernel() ∘ ScaleTransform(0.5)) +
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
K4 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
Expand Down
3 changes: 0 additions & 3 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ of kernels together.

```@docs
TransformedKernel
transform(::Kernel, ::Transform)
transform(::Kernel, ::Real)
transform(::Kernel, ::AbstractVector)
ScaledKernel
devmotion marked this conversation as resolved.
Show resolved Hide resolved
KernelSum
KernelProduct
Expand Down
4 changes: 2 additions & 2 deletions docs/src/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ LowRankTransform(rand(10, 5)) ∘ ScaleTransform(2.0)
A transformation `t` can be applied to a single input `x` with `t(x)` and to multiple inputs
`xs` with `map(t, xs)`.

Kernels can be coupled with input transformations with
[`transform`](@ref). It falls back to creating a [`TransformedKernel`](@ref) but allows more
Kernels can be coupled with input transformations with `∘` or its alias `compose`. It falls
devmotion marked this conversation as resolved.
Show resolved Hide resolved
back to creating a [`TransformedKernel`](@ref) but allows more
optimized implementations for specific kernels and transformations.

## List of Input Transforms
Expand Down
5 changes: 4 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor, ⊗
export tensor, ⊗, compose

using Compat
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
using ChainRulesCore: @thunk, InplaceableThunk
using CompositionsBase
using Requires
using Distances, LinearAlgebra
using Functors
Expand Down Expand Up @@ -106,6 +107,8 @@ include("zygoterules.jl")

include("test_utils.jl")

include("deprecations.jl")

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
include(joinpath("matrix", "kernelkroneckermat.jl"))
Expand Down
42 changes: 15 additions & 27 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ k(x, x'; l, p) = \\exp\\bigg(- \\cos\\bigg(\\pi\\sum_{i=1}^d \\frac{x_i - x'_i}{
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
k = _gabor(; ell=ell, p=p)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform)
return new{typeof(k)}(k)
end
end
Expand All @@ -24,38 +27,23 @@ end

(κ::GaborKernel)(x, y) = κ.kernel(x, y)

function _gabor(; ell=nothing, p=nothing)
if ell === nothing
if p === nothing
return SqExponentialKernel() * CosineKernel()
else
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
end
elseif p === nothing
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
else
return transform(SqExponentialKernel(), 1 ./ ell) *
transform(CosineKernel(), 1 ./ p)
end
end
_lengthscale_transform(::Nothing) = IdentityTransform()
theogf marked this conversation as resolved.
Show resolved Hide resolved
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(::IdentityTransform) = 1
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
kernel1 = k.kernel.kernels[1]
if kernel1 isa TransformedKernel
return 1 ./ kernel1.transform.s[1]
else
return 1.0
end
ell_transform = k.kernel.kernels[1].transform
return _lengthscale(ell_transform)
elseif v == :p
kernel2 = k.kernel.kernels[2]
if kernel2 isa TransformedKernel
return 1 ./ kernel2.transform.s[1]
else
return 1.0
end
p_transform = k.kernel.kernels[2].transform
return _lengthscale(p_transform)
else
error("Invalid Property")
end
Expand Down
4 changes: 4 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@deprecate transform(k::Kernel, t::Transform) k ∘ t
@deprecate transform(k::TransformedKernel, t::Transform) k.kernel ∘ t ∘ k.transform
@deprecate transform(k::Kernel, ρ::Real) k ∘ ScaleTransform(ρ)
@deprecate transform(k::Kernel, ρ::AbstractVector) k ∘ ARDTransform(ρ)
39 changes: 8 additions & 31 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.

It is preferred to create kernels with input transformations with [`transform`](@ref)
instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized
It is preferred to create kernels with input transformations with `∘` or its alias
`compose` instead of `TransformedKernel` directly since this allows optimized
implementations for specific kernels and transformations.

# Definition
Expand Down Expand Up @@ -41,31 +41,8 @@ function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
end
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))

"""
transform(k::Kernel, t::Transform)

Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`.
"""
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)
function transform(k::TransformedKernel, t::Transform)
return TransformedKernel(k.kernel, t ∘ k.transform)
end

"""
transform(k::Kernel, ρ::Real)

Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`.
"""
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ))

"""
transform(k::Kernel, ρ::AbstractVector)

Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`.
"""
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ))

kernel(κ) = κ.kernel
theogf marked this conversation as resolved.
Show resolved Hide resolved
Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t)
Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform ∘ t)

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

Expand All @@ -87,13 +64,13 @@ function kernelmatrix_diag!(
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x))
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x))
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
Expand All @@ -105,9 +82,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(kernel(κ), _map(κ.transform, x))
return kernelmatrix(κ.kernel, _map(κ.transform, x))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(kernel(κ), _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end
4 changes: 2 additions & 2 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) ≈ k_manual atol = 1e-5

lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) ≈ lhs_manual * rhs_manual atol = 1e-5

k = GaborKernel()
Expand Down
20 changes: 20 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testset "deprecations.jl" begin
p = rand()
v = rand(3)
M = rand(3, 3)
v1 = rand(3)
v2 = rand(3)
kernel = SqExponentialKernel()

k1 = @test_deprecated transform(kernel, LinearTransform(M))
@test k1(v1, v2) == (kernel ∘ LinearTransform(M))(v1, v2)

k2 = @test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v))
@test k2(v1, v2) == (kernel ∘ ARDTransform(v) ∘ ScaleTransform(p))(v1, v2)

k3 = @test_deprecated transform(kernel, p)
@test k3(v1, v2) == (kernel ∘ ScaleTransform(p))(v1, v2)

k4 = @test_deprecated transform(kernel, v)
@test k4(v1, v2) == (kernel ∘ ARDTransform(v))(v1, v2)
end
29 changes: 14 additions & 15 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
v2 = rand(rng, 3)

s = rand(rng)
s2 = rand(rng)
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
@test kt(v1, v2) == transform(k, s)(v1, v2)
@test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2)
@test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) ≈ transform(k, ARDTransform(v))(v1, v2) atol = 1e-5
@test ktard(v1, v2) == transform(k, v)(v1, v2)
@test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)
@test transform(kt, s2)(v1, v2) ≈ kt(s2 * v1, s2 * v2)
@test KernelFunctions.kernel(kt) == k
@test (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) ==
((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2)

@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

TestUtils.test_interface(k, Float64)
test_ADs(x -> transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
test_ADs(x -> SqExponentialKernel() ∘ ScaleTransform(x[1]), rand(1))

# Test implicit gradients
@testset "Implicit gradients" begin
k = transform(SqExponentialKernel(), 2.0)
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
ps = Flux.params(k)
X = rand(10, 1)
x = vec(X)
Expand All @@ -46,12 +46,11 @@
@test g1[first(ps)] ≈ g3[first(ps)]
end

P = rand(3, 2)
c = Chain(Dense(3, 2))

test_params(transform(k, s), (k, [s]))
test_params(transform(k, v), (k, v))
test_params(transform(k, LinearTransform(P)), (k, P))
test_params(transform(k, LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P))
test_params(transform(k, FunctionTransform(c)), (k, c))
test_params((k ∘ ScaleTransform(s)), (k, [s]))
test_params((k ∘ ARDTransform(v)), (k, v))
test_params((k ∘ LinearTransform(P)), (k, P))
test_params((k ∘ (LinearTransform(P) ∘ ScaleTransform(s))), (k, [s], P))
test_params((k ∘ FunctionTransform(c)), (k, c))
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ include("test_utils.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecations.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
KernelFunctions,
Expand Down
2 changes: 1 addition & 1 deletion test/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3)))

@test repr(t) == "ARD Transform (dims: $D)"
test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3))
test_ADs(x -> SEKernel() ∘ ARDTransform(exp.(x)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Verify printing works as expected.
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
)
end
2 changes: 1 addition & 1 deletion test/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@

@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
f(a, x) = sin.(a .* x)
test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3))
test_ADs(x -> SEKernel()FunctionTransform(y -> f(x, y)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout)))

@test repr(t) == "Linear transform (size(A) = ($Dout, $Din))"
test_ADs(x -> transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3))
test_ADs(x -> SEKernel()LinearTransform(x), randn(rng, 3, 3))
end
4 changes: 2 additions & 2 deletions test/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
x = collect(range(0.0, 3.0 / f; length=1_000))

# Construct in the usual way.
k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f)
k_eq_periodic = PeriodicKernel(; r=[sqrt(0.25)]) ∘ ScaleTransform(f)

# Construct using the peridic transform.
k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f))
k_eq_transform = SqExponentialKernel()PeriodicTransform(f)

@test kernelmatrix(k_eq_periodic, x) ≈ kernelmatrix(k_eq_transform, x)
# TODO - add interface_tests once #159 is merged.
Expand Down
2 changes: 1 addition & 1 deletion test/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
@test t.s == [s2]
@test isequal(ScaleTransform(s), ScaleTransform(s))
@test repr(t) == "Scale Transform (s = $(s2))"
test_ADs(x -> transform(SEKernel(), exp(x[1])), randn(rng, 1))
test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1))
end
10 changes: 5 additions & 5 deletions test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
@test repr(t) == "Select Transform (dims: $(select2))"
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"

test_ADs(() -> transform(SEKernel(), SelectTransform([1, 2])))
test_ADs(() -> SEKernel()SelectTransform([1, 2]))

X = randn(rng, (4, 3))
A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z])
Expand All @@ -53,10 +53,10 @@
Z = randn(rng, (2, 3))
C = AxisArray(Z; row=[:e, :f], col=[:x, :y, :z])

tx_row = transform(SEKernel(), SelectTransform([1, 2, 4]))
ta_row = transform(SEKernel(), SelectTransform([:a, :b, :d]))
tx_col = transform(SEKernel(), SelectTransform([1, 3]))
ta_col = transform(SEKernel(), SelectTransform([:x, :z]))
tx_row = SEKernel()SelectTransform([1, 2, 4])
ta_row = SEKernel()SelectTransform([:a, :b, :d])
tx_col = SEKernel()SelectTransform([1, 3])
ta_col = SEKernel()SelectTransform([:x, :z])

@test kernelmatrix(tx_row, X; obsdim=2) ≈ kernelmatrix(ta_row, A; obsdim=2)
@test kernelmatrix(tx_col, X; obsdim=1) ≈ kernelmatrix(ta_col, A; obsdim=1)
Expand Down
Loading