Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension of #269: Use \circ and compose and deprecate transform #276

Merged
merged 20 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.9.1"
version = "0.9.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
CompositionsBase = "0.1"
Distances = "0.10"
Functors = "0.1"
Requires = "1.0.1"
Expand Down
6 changes: 3 additions & 3 deletions docs/create_kernel_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ n_grid = 101
fill(x₀, n_grid, 1)
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)

k = transform(SqExponentialKernel(), 1.0)
k = SqExponentialKernel() ∘ ScaleTransform(1.0)
K1 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K1;
Expand All @@ -35,7 +35,7 @@ p = heatmap(
)
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))

k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
k = PolynomialKernel(; c=0.0, d=2.0) ∘ LinearTransform(randn(3, 1))
K3 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K3;
Expand All @@ -47,7 +47,7 @@ p = heatmap(
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))

k =
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
0.5 * SqExponentialKernel() * (LinearKernel() ∘ ScaleTransform(0.5)) +
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
K4 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
Expand Down
4 changes: 1 addition & 3 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ of kernels together.

```@docs
TransformedKernel
transform(::Kernel, ::Transform)
transform(::Kernel, ::Real)
transform(::Kernel, ::AbstractVector)
∘(::Kernel, ::Transform)
ScaledKernel
devmotion marked this conversation as resolved.
Show resolved Hide resolved
KernelSum
KernelProduct
Expand Down
4 changes: 2 additions & 2 deletions docs/src/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ LowRankTransform(rand(10, 5)) ∘ ScaleTransform(2.0)
A transformation `t` can be applied to a single input `x` with `t(x)` and to multiple inputs
`xs` with `map(t, xs)`.

Kernels can be coupled with input transformations with
[`transform`](@ref). It falls back to creating a [`TransformedKernel`](@ref) but allows more
Kernels can be coupled with input transformations with [`∘`](@ref) or its alias `compose`. It falls
back to creating a [`TransformedKernel`](@ref) but allows more
optimized implementations for specific kernels and transformations.

## List of Input Transforms
Expand Down
5 changes: 4 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor, ⊗
export tensor, ⊗, compose

using Compat
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
using ChainRulesCore: @thunk, InplaceableThunk
using CompositionsBase
using Requires
using Distances, LinearAlgebra
using Functors
Expand Down Expand Up @@ -106,6 +107,8 @@ include("zygoterules.jl")

include("test_utils.jl")

include("deprecations.jl")

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
include(joinpath("matrix", "kernelkroneckermat.jl"))
Expand Down
42 changes: 15 additions & 27 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ k(x, x'; l, p) = \\exp\\bigg(- \\cos\\bigg(\\pi\\sum_{i=1}^d \\frac{x_i - x'_i}{
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
k = _gabor(; ell=ell, p=p)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform)
return new{typeof(k)}(k)
end
end
Expand All @@ -24,38 +27,23 @@ end

(κ::GaborKernel)(x, y) = κ.kernel(x, y)

function _gabor(; ell=nothing, p=nothing)
if ell === nothing
if p === nothing
return SqExponentialKernel() * CosineKernel()
else
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
end
elseif p === nothing
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
else
return transform(SqExponentialKernel(), 1 ./ ell) *
transform(CosineKernel(), 1 ./ p)
end
end
_lengthscale_transform(::Nothing) = IdentityTransform()
theogf marked this conversation as resolved.
Show resolved Hide resolved
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(::IdentityTransform) = 1
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
kernel1 = k.kernel.kernels[1]
if kernel1 isa TransformedKernel
return 1 ./ kernel1.transform.s[1]
else
return 1.0
end
ell_transform = k.kernel.kernels[1].transform
return _lengthscale(ell_transform)
elseif v == :p
kernel2 = k.kernel.kernels[2]
if kernel2 isa TransformedKernel
return 1 ./ kernel2.transform.s[1]
else
return 1.0
end
p_transform = k.kernel.kernels[2].transform
return _lengthscale(p_transform)
else
error("Invalid Property")
end
Expand Down
4 changes: 4 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@deprecate transform(k::Kernel, t::Transform) k ∘ t
@deprecate transform(k::TransformedKernel, t::Transform) k.kernel ∘ t ∘ k.transform
@deprecate transform(k::Kernel, ρ::Real) k ∘ ScaleTransform(ρ)
@deprecate transform(k::Kernel, ρ::AbstractVector) k ∘ ARDTransform(ρ)
65 changes: 33 additions & 32 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,11 @@
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
It is preferred to create kernels with input transformations with [`transform`](@ref)
instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized
implementations for specific kernels and transformations.
The preferred way to create kernels with input transformations is to use the composition
operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since
this allows optimized implementations for specific kernels and transformations.
# Definition
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
input transformation ``t`` is defined as
```math
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
```
See also: [`∘`](@ref)
"""
struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
kernel::Tk
Expand Down Expand Up @@ -42,30 +36,37 @@ end
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))

"""
transform(k::Kernel, t::Transform)
kernel ∘ transform
∘(kernel, transform)
compose(kernel, transform)
Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`.
"""
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)
function transform(k::TransformedKernel, t::Transform)
return TransformedKernel(k.kernel, t ∘ k.transform)
end
Compose a `kernel` with a transformation `transform` of its inputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we repeat the mathematical definition from TransformedKernel? Since we want users to avoid using it directly this will remove the need to navigate the docs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like redundancy 😛 The docstring of TransformedKernel is displayed right above this docstring, so users don't have to navigate to another page. If TransformedKernel would not be exported anymore, one could remove its docstring and move the mathematical definition here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant to access the docs from the REPL :) I personally almost never check API pages from the docs but overuse ?!

"""
transform(k::Kernel, ρ::Real)
The prefix forms support chains of multiple transformations:
`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`.
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`.
"""
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ))
# Definition
"""
transform(k::Kernel, ρ::AbstractVector)
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
input transformation ``t`` is defined as
```math
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
```
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`.
"""
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ))
# Examples
```jldoctest
julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5)
true
kernel(κ) = κ.kernel
theogf marked this conversation as resolved.
Show resolved Hide resolved
julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1)
true
```
See also: [`TransformedKernel`](@ref)
"""
Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t)
Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform ∘ t)

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

Expand All @@ -87,13 +88,13 @@ function kernelmatrix_diag!(
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x))
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x))
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
Expand All @@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(kernel(κ), _map(κ.transform, x))
return kernelmatrix(κ.kernel, _map(κ.transform, x))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(kernel(κ), _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end
6 changes: 3 additions & 3 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) ≈ k_manual atol = 1e-5

lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) ≈ lhs_manual * rhs_manual atol = 1e-5

k = GaborKernel()
@test k.ell ≈ 1.0 atol = 1e-5
@test k.p ≈ 1.0 atol = 1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"

test_interface(k, Vector{Float64})

Expand Down
20 changes: 20 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testset "deprecations.jl" begin
p = rand()
v = rand(3)
M = rand(3, 3)
v1 = rand(3)
v2 = rand(3)
kernel = SqExponentialKernel()

k1 = @test_deprecated transform(kernel, LinearTransform(M))
@test k1(v1, v2) == (kernel ∘ LinearTransform(M))(v1, v2)

k2 = @test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v))
@test k2(v1, v2) == (kernel ∘ ARDTransform(v) ∘ ScaleTransform(p))(v1, v2)

k3 = @test_deprecated transform(kernel, p)
@test k3(v1, v2) == (kernel ∘ ScaleTransform(p))(v1, v2)

k4 = @test_deprecated transform(kernel, v)
@test k4(v1, v2) == (kernel ∘ ARDTransform(v))(v1, v2)
end
35 changes: 19 additions & 16 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,27 @@
v2 = rand(rng, 3)

s = rand(rng)
s2 = rand(rng)
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
@test kt(v1, v2) == transform(k, s)(v1, v2)
@test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2)
@test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) ≈ transform(k, ARDTransform(v))(v1, v2) atol = 1e-5
@test ktard(v1, v2) == transform(k, v)(v1, v2)
@test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)
@test transform(kt, s2)(v1, v2) ≈ kt(s2 * v1, s2 * v2)
@test KernelFunctions.kernel(kt) == k
@test (k ∘ LinearTransform(P') ∘ ScaleTransform(s))(v1, v2) ==
((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) ==
(k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2)

@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

TestUtils.test_interface(k, Float64)
test_ADs(x -> transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
test_ADs(x -> SqExponentialKernel() ∘ ScaleTransform(x[1]), rand(1))

# Test implicit gradients
@testset "Implicit gradients" begin
k = transform(SqExponentialKernel(), 2.0)
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
ps = Flux.params(k)
X = rand(10, 1)
x = vec(X)
Expand All @@ -46,12 +47,14 @@
@test g1[first(ps)] ≈ g3[first(ps)]
end

P = rand(3, 2)
c = Chain(Dense(3, 2))
@testset "Parameters" begin
k = ConstantKernel(; c=rand(rng))
c = Chain(Dense(3, 2))

test_params(transform(k, s), (k, [s]))
test_params(transform(k, v), (k, v))
test_params(transform(k, LinearTransform(P)), (k, P))
test_params(transform(k, LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P))
test_params(transform(k, FunctionTransform(c)), (k, c))
test_params(k ∘ ScaleTransform(s), (k, [s]))
test_params(k ∘ ARDTransform(v), (k, v))
test_params(k ∘ LinearTransform(P), (k, P))
test_params(k ∘ LinearTransform(P) ∘ ScaleTransform(s), (k, [s], P))
test_params(k ∘ FunctionTransform(c), (k, c))
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ include("test_utils.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecations.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
KernelFunctions,
Expand Down
2 changes: 1 addition & 1 deletion test/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3)))

@test repr(t) == "ARD Transform (dims: $D)"
test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3))
test_ADs(x -> SEKernel() ∘ ARDTransform(exp.(x)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Verify printing works as expected.
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
)
end
2 changes: 1 addition & 1 deletion test/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@

@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
f(a, x) = sin.(a .* x)
test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3))
test_ADs(x -> SEKernel() ∘ FunctionTransform(y -> f(x, y)), randn(rng, 3))
end
Loading