diff --git a/Project.toml b/Project.toml
index 43e7eb9f0..ce1d6569b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "KernelFunctions"
 uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
-version = "0.8.18"
+version = "0.8.19"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -13,6 +13,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
@@ -24,5 +25,6 @@ Requires = "1.0.1"
 SpecialFunctions = "0.8, 0.9, 0.10, 1"
 StatsBase = "0.32, 0.33"
 StatsFuns = "0.8, 0.9"
+TensorCore = "0.1"
 ZygoteRules = "0.2"
 julia = "1.3"
diff --git a/docs/src/kernels.md b/docs/src/kernels.md
index 8e2adc52d..ee769eb3c 100644
--- a/docs/src/kernels.md
+++ b/docs/src/kernels.md
@@ -124,7 +124,7 @@ transform(::Kernel, ::AbstractVector)
 ScaledKernel
 KernelSum
 KernelProduct
-TensorProduct
+KernelTensorProduct
 ```
 
 ## Multi-output Kernels
diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl
index 8a938e103..98fc0cb0e 100644
--- a/src/KernelFunctions.jl
+++ b/src/KernelFunctions.jl
@@ -29,9 +29,8 @@ export LinearKernel, PolynomialKernel
 export RationalQuadraticKernel, GammaRationalQuadraticKernel
 export GaborKernel, PiecewisePolynomialKernel
 export PeriodicKernel, NeuralNetworkKernel
-export KernelSum, KernelProduct
+export KernelSum, KernelProduct, KernelTensorProduct
 export TransformedKernel, ScaledKernel
-export TensorProduct
 
 export Transform,
     SelectTransform,
@@ -52,6 +51,9 @@ export ColVecs, RowVecs
 export MOInput
 export IndependentMOKernel, LatentFactorMOKernel
 
+# Reexports
+export tensor, ⊗
+
 using Compat
 using Requires
 using Distances, LinearAlgebra
@@ -61,6 +63,7 @@ using ZygoteRules: @adjoint, pullback
 using StatsFuns: logtwo
 using InteractiveUtils: subtypes
 using StatsBase
+using TensorCore
 
 abstract type Kernel end
 abstract type SimpleKernel <: Kernel end
@@ -100,7 +103,8 @@ include(joinpath("kernels", "scaledkernel.jl"))
 include(joinpath("matrix", "kernelmatrix.jl"))
 include(joinpath("kernels", "kernelsum.jl"))
 include(joinpath("kernels", "kernelproduct.jl"))
-include(joinpath("kernels", "tensorproduct.jl"))
+include(joinpath("kernels", "kerneltensorproduct.jl"))
+include(joinpath("kernels", "overloads.jl"))
 include(joinpath("approximations", "nystrom.jl"))
 
 include("generic.jl")
diff --git a/src/basekernels/sm.jl b/src/basekernels/sm.jl
index b7d51f41a..833945295 100644
--- a/src/basekernels/sm.jl
+++ b/src/basekernels/sm.jl
@@ -92,7 +92,7 @@ function spectral_mixture_product_kernel(
     if !(size(αs) == size(γs) == size(ωs))
         throw(DimensionMismatch("The dimensions of αs, γs, and ωs do not match"))
     end
-    return TensorProduct(
+    return KernelTensorProduct(
         spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
         (α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
     )
diff --git a/src/deprecated.jl b/src/deprecated.jl
index 5c3ca3131..eeae1eb4c 100644
--- a/src/deprecated.jl
+++ b/src/deprecated.jl
@@ -7,3 +7,8 @@
 @deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
     PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
 )
+
+@deprecate TensorProduct(kernels) KernelTensorProduct(kernels)
+@deprecate TensorProduct(kernel::Kernel, kernels::Kernel...) KernelTensorProduct(
+    kernel, kernels...
+)
diff --git a/src/kernels/kernelproduct.jl b/src/kernels/kernelproduct.jl
index 8201815e0..7f2800077 100644
--- a/src/kernels/kernelproduct.jl
+++ b/src/kernels/kernelproduct.jl
@@ -41,31 +41,6 @@ end
 
 @functor KernelProduct
 
-Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)
-
-function Base.:*(
-    k1::KernelProduct{<:AbstractVector{<:Kernel}},
-    k2::KernelProduct{<:AbstractVector{<:Kernel}},
-)
-    return KernelProduct(vcat(k1.kernels, k2.kernels))
-end
-
-function Base.:*(k1::KernelProduct, k2::KernelProduct)
-    return KernelProduct(k1.kernels..., k2.kernels...)
-end
-
-function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
-    return KernelProduct(vcat(k, ks.kernels))
-end
-
-Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)
-
-function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
-    return KernelProduct(vcat(ks.kernels, k))
-end
-
-Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)
-
 Base.length(k::KernelProduct) = length(k.kernels)
 
 (κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
diff --git a/src/kernels/kernelsum.jl b/src/kernels/kernelsum.jl
index 176e870ea..017642d95 100644
--- a/src/kernels/kernelsum.jl
+++ b/src/kernels/kernelsum.jl
@@ -41,28 +41,6 @@ end
 
 @functor KernelSum
 
-Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)
-
-function Base.:+(
-    k1::KernelSum{<:AbstractVector{<:Kernel}}, k2::KernelSum{<:AbstractVector{<:Kernel}}
-)
-    return KernelSum(vcat(k1.kernels, k2.kernels))
-end
-
-Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)
-
-function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
-    return KernelSum(vcat(k, ks.kernels))
-end
-
-Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)
-
-function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
-    return KernelSum(vcat(ks.kernels, k))
-end
-
-Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)
-
 Base.length(k::KernelSum) = length(k.kernels)
 
 (κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
diff --git a/src/kernels/kerneltensorproduct.jl b/src/kernels/kerneltensorproduct.jl
new file mode 100644
index 000000000..16bed14f3
--- /dev/null
+++ b/src/kernels/kerneltensorproduct.jl
@@ -0,0 +1,144 @@
+"""
+    KernelTensorProduct
+
+Tensor product of kernels.
+
+# Definition
+
+For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
+product of kernels ``k_1, \\ldots, k_n`` is defined as
+```math
+k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
+```
+
+# Construction
+
+The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
+operator or its alias `⊗` (can be typed by `\\otimes`).
+```jldoctest tensorproduct
+julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);
+
+julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
+true
+```
+
+You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
+or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
+individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
+lead to large compilation times if the number of kernels is large.
+```jldoctest tensorproduct
+julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
+true
+
+julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
+true
+
+julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
+true
+```
+"""
+struct KernelTensorProduct{K} <: Kernel
+    kernels::K
+end
+
+function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
+    return KernelTensorProduct((kernel, kernels...))
+end
+
+@functor KernelTensorProduct
+
+Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)
+
+function (kernel::KernelTensorProduct)(x, y)
+    if !(length(x) == length(y) == length(kernel))
+        throw(DimensionMismatch("number of kernels and number of features
+are not consistent"))
+    end
+    return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
+end
+
+function validate_domain(k::KernelTensorProduct, x::AbstractVector)
+    return dim(x) == length(k) ||
+           error("number of kernels and groups of features are not consistent")
+end
+
+# Utility for slicing up inputs.
+slices(x::AbstractVector{<:Real}) = (x,)
+slices(x::ColVecs) = eachrow(x.X)
+slices(x::RowVecs) = eachcol(x.X)
+
+function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
+    validate_inplace_dims(K, x)
+    validate_domain(k, x)
+
+    kernels_and_inputs = zip(k.kernels, slices(x))
+    kernelmatrix!(K, first(kernels_and_inputs)...)
+    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
+        K .*= kernelmatrix(k, xi)
+    end
+
+    return K
+end
+
+function kernelmatrix!(
+    K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
+)
+    validate_inplace_dims(K, x, y)
+    validate_domain(k, x)
+
+    kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
+    kernelmatrix!(K, first(kernels_and_inputs)...)
+    for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
+        K .*= kernelmatrix(k, xi, yi)
+    end
+
+    return K
+end
+
+function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
+    validate_inplace_dims(K, x)
+    validate_domain(k, x)
+
+    kernels_and_inputs = zip(k.kernels, slices(x))
+    kerneldiagmatrix!(K, first(kernels_and_inputs)...)
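+    # The remaining kernels contribute factor-wise: fold each one's diagonal
+    # into `K` with an in-place elementwise product.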
+    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
+        K .*= kerneldiagmatrix(k, xi)
+    end
+
+    return K
+end
+
+function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
+    validate_domain(k, x)
+    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
+end
+
+function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
+    validate_domain(k, x)
+    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
+end
+
+function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
+    validate_domain(k, x)
+    return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
+end
+
+Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)
+
+function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
+function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
+    print(io, "Tensor product of ", length(kernel), " kernels:")
+    for k in kernel.kernels
+        print(io, "\n")
+        for _ in 1:(shift + 1)
+            print(io, "\t")
+        end
+        printshifted(io, k, shift + 2)
+    end
+end
diff --git a/src/kernels/overloads.jl b/src/kernels/overloads.jl
new file mode 100644
index 000000000..b3ba76c7f
--- /dev/null
+++ b/src/kernels/overloads.jl
@@ -0,0 +1,22 @@
+for (M, op, T) in (
+    (:Base, :+, :KernelSum),
+    (:Base, :*, :KernelProduct),
+    (:TensorCore, :tensor, :KernelTensorProduct),
+)
+    @eval begin
+        $M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)
+
+        $M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)
+        function $M.$op(
+            k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}}
+        )
+            return $T(vcat(k1.kernels, k2.kernels))
+        end
+
+        $M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...)
+        $M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels))
+
+        $M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
+        $M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))
+    end
+end
diff --git a/src/kernels/tensorproduct.jl b/src/kernels/tensorproduct.jl
deleted file mode 100644
index c1b18212b..000000000
--- a/src/kernels/tensorproduct.jl
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-    TensorProduct(kernels...)
-
-Create a tensor product kernel from kernels ``k_1, \\ldots, k_n``, i.e.,
-a kernel ``k`` that is given by
-```math
-k(x, y) = \\prod_{i=1}^n k_i(x_i, y_i).
-```
-
-The `kernels` can be specified as individual arguments, a tuple, or an iterable data
-structure such as an array. Using a tuple or individual arguments guarantees that
-`TensorProduct` is concretely typed but might lead to large compilation times if the
-number of kernels is large.
-"""
-struct TensorProduct{K} <: Kernel
-    kernels::K
-end
-
-function TensorProduct(kernel::Kernel, kernels::Kernel...)
-    return TensorProduct((kernel, kernels...))
-end
-
-@functor TensorProduct
-
-Base.length(kernel::TensorProduct) = length(kernel.kernels)
-
-function (kernel::TensorProduct)(x, y)
-    if !(length(x) == length(y) == length(kernel))
-        throw(DimensionMismatch("number of kernels and number of features
-are not consistent"))
-    end
-    return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
-end
-
-function validate_domain(k::TensorProduct, x::AbstractVector)
-    return dim(x) == length(k) ||
-           error("number of kernels and groups of features are not consistent")
-end
-
-# Utility for slicing up inputs.
-slices(x::AbstractVector{<:Real}) = (x,)
-slices(x::ColVecs) = eachrow(x.X)
-slices(x::RowVecs) = eachcol(x.X)
-
-function kernelmatrix!(K::AbstractMatrix, k::TensorProduct, x::AbstractVector)
-    validate_inplace_dims(K, x)
-    validate_domain(k, x)
-
-    kernels_and_inputs = zip(k.kernels, slices(x))
-    kernelmatrix!(K, first(kernels_and_inputs)...)
-    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
-        K .*= kernelmatrix(k, xi)
-    end
-
-    return K
-end
-
-function kernelmatrix!(
-    K::AbstractMatrix, k::TensorProduct, x::AbstractVector, y::AbstractVector
-)
-    validate_inplace_dims(K, x, y)
-    validate_domain(k, x)
-
-    kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
-    kernelmatrix!(K, first(kernels_and_inputs)...)
-    for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
-        K .*= kernelmatrix(k, xi, yi)
-    end
-
-    return K
-end
-
-function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVector)
-    validate_inplace_dims(K, x)
-    validate_domain(k, x)
-
-    kernels_and_inputs = zip(k.kernels, slices(x))
-    kerneldiagmatrix!(K, first(kernels_and_inputs)...)
-    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
-        K .*= kerneldiagmatrix(k, xi)
-    end
-
-    return K
-end
-
-function kernelmatrix(k::TensorProduct, x::AbstractVector)
-    validate_domain(k, x)
-    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
-end
-
-function kernelmatrix(k::TensorProduct, x::AbstractVector, y::AbstractVector)
-    validate_domain(k, x)
-    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
-end
-
-function kerneldiagmatrix(k::TensorProduct, x::AbstractVector)
-    validate_domain(k, x)
-    return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
-end
-
-Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0)
-
-function Base.:(==)(x::TensorProduct, y::TensorProduct)
-    return (
-        length(x.kernels) == length(y.kernels) &&
-        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
-    )
-end
-
-function printshifted(io::IO, kernel::TensorProduct, shift::Int)
-    print(io, "Tensor product of ", length(kernel), " kernels:")
-    for k in kernel.kernels
-        print(io, "\n")
-        for _ in 1:(shift + 1)
-            print(io, "\t")
-        end
-        print(io, "- ")
-        printshifted(io, k, shift + 2)
-    end
-end
diff --git a/src/matrix/kernelkroneckermat.jl b/src/matrix/kernelkroneckermat.jl
index 384bd1de7..26f7d65a9 100644
--- a/src/matrix/kernelkroneckermat.jl
+++ b/src/matrix/kernelkroneckermat.jl
@@ -1,11 +1,14 @@
-using .Kronecker
+# Since Kronecker does not implement `TensorCore.:⊗` but instead exports its own function
+# `Kronecker.:⊗`, only the module is imported and `Kronecker.:⊗` and `Kronecker.kronecker`
+# are called explicitly.
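+# (`TensorCore.:⊗`, which KernelFunctions reexports, builds a `KernelTensorProduct`,
+# whereas `Kronecker.:⊗` builds a lazy Kronecker-product matrix.)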
+using .Kronecker: Kronecker
 
 export kernelkronmat
 
 function kernelkronmat(κ::Kernel, X::AbstractVector, dims::Int)
     @assert iskroncompatible(κ) "The chosen kernel is not compatible with Kronecker matrices (see [`iskroncompatible`](@ref))"
     k = kernelmatrix(κ, X)
-    return kronecker(k, dims)
+    return Kronecker.kronecker(k, dims)
 end
 
 function kernelkronmat(
@@ -13,7 +16,7 @@ function kernelkronmat(
 )
     @assert iskroncompatible(κ) "The chosen kernel is not compatible with Kronecker matrices"
     Ks = kernelmatrix.(κ, X)
-    return K = reduce(⊗, Ks)
+    return K = reduce(Kronecker.:⊗, Ks)
 end
 
 """
diff --git a/src/utils.jl b/src/utils.jl
index 487405613..7e28c82ef 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,5 +1,3 @@
-hadamard(x, y) = x .* y
-
 # Macro for checking arguments
 macro check_args(K, param, cond, desc=string(cond))
     quote
diff --git a/test/kernels/kernelproduct.jl b/test/kernels/kernelproduct.jl
index d35ed3f56..670a5bdf0 100644
--- a/test/kernels/kernelproduct.jl
+++ b/test/kernels/kernelproduct.jl
@@ -1,43 +1,18 @@
 @testset "kernelproduct" begin
-    rng = MersenneTwister(123456)
-    x = rand(rng) * 2
-    v1 = rand(rng, 3)
-    v2 = rand(rng, 3)
-
     k1 = LinearKernel()
     k2 = SqExponentialKernel()
-    k3 = RationalQuadraticKernel()
-    X = rand(rng, 2, 2)
-
     k = KernelProduct(k1, k2)
-    ks1 = 2.0 * k1
-    ks2 = 0.5 * k2
+
+    @test k == KernelProduct([k1, k2]) == KernelProduct((k1, k2))
     @test length(k) == 2
     @test string(k) == (
         "Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *
         "Exponential Kernel"
     )
-    @test k(v1, v2) == (k1 * k2)(v1, v2)
-    @test (k * k3)(v1, v2) ≈ (k3 * k)(v1, v2)
-    @test (k1 * k2)(v1, v2) == KernelProduct(k1, k2)(v1, v2)
-    @test (k * ks1)(v1, v2) ≈ (ks1 * k)(v1, v2)
-    @test (k * k)(v1, v2) == KernelProduct([k1, k2, k1, k2])(v1, v2)
-    @test KernelProduct([k1, k2]) == KernelProduct((k1, k2)) == k1 * k2
-
-    @test (KernelProduct([k1, k2]) * KernelProduct([k2, k1])).kernels == [k1, k2, k2, k1]
-    @test (KernelProduct([k1, k2]) * k3).kernels == [k1, k2, k3]
-    @test (k3 * KernelProduct([k1, k2])).kernels == [k3, k1, k2]
-
-    @test (KernelProduct((k1, k2)) * KernelProduct((k2, k1))).kernels == (k1, k2, k2, k1)
-    @test (KernelProduct((k1, k2)) * k3).kernels == (k1, k2, k3)
-    @test (k3 * KernelProduct((k1, k2))).kernels == (k3, k1, k2)
 
     # Standardised tests.
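+    # `exp` maps the unconstrained test parameter to a strictly positive `c`
+    # for `LinearKernel` in the AD checks below.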
     TestUtils.test_interface(k, Float64)
     test_ADs(
-        x -> SqExponentialKernel() * LinearKernel(; c=x[1]),
-        rand(1);
-        ADs=[:ForwardDiff, :ReverseDiff, :Zygote],
+        x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1)
     )
     test_params(k1 * k2, (k1, k2))
 end
diff --git a/test/kernels/kernelsum.jl b/test/kernels/kernelsum.jl
index a76f8270b..da69fdcc7 100644
--- a/test/kernels/kernelsum.jl
+++ b/test/kernels/kernelsum.jl
@@ -1,43 +1,16 @@
 @testset "kernelsum" begin
-    rng = MersenneTwister(123456)
-    x = rand(rng) * 2
-    v1 = rand(rng, 3)
-    v2 = rand(rng, 3)
-
     k1 = LinearKernel()
     k2 = SqExponentialKernel()
-    k3 = RationalQuadraticKernel()
-    X = rand(rng, 2, 2)
-
     k = KernelSum(k1, k2)
-    ks1 = 2.0 * k1
-    ks2 = 0.5 * k2
+
+    @test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
     @test length(k) == 2
    @test string(k) == (
         "Sum of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *
         "Exponential Kernel"
     )
-    @test k(v1, v2) == (k1 + k2)(v1, v2)
-    @test (k + k3)(v1, v2) ≈ (k3 + k)(v1, v2)
-    @test (k1 + k2)(v1, v2) == KernelSum(k1, k2)(v1, v2)
-    @test (k + ks1)(v1, v2) ≈ (ks1 + k)(v1, v2)
-    @test (k + k)(v1, v2) == KernelSum([k1, k2, k1, k2])(v1, v2)
-    @test KernelSum([k1, k2]) == KernelSum((k1, k2)) == k1 + k2
-
-    @test (KernelSum([k1, k2]) + KernelSum([k2, k1])).kernels == [k1, k2, k2, k1]
-    @test (KernelSum([k1, k2]) + k3).kernels == [k1, k2, k3]
-    @test (k3 + KernelSum([k1, k2])).kernels == [k3, k1, k2]
-
-    @test (KernelSum((k1, k2)) + KernelSum((k2, k1))).kernels == (k1, k2, k2, k1)
-    @test (KernelSum((k1, k2)) + k3).kernels == (k1, k2, k3)
-    @test (k3 + KernelSum((k1, k2))).kernels == (k3, k1, k2)
 
     # Standardised tests.
     TestUtils.test_interface(k, Float64)
-    test_ADs(
-        x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=x[1])),
-        rand(1);
-        ADs=[:ForwardDiff, :ReverseDiff, :Zygote],
-    )
+    test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1))
     test_params(k1 + k2, (k1, k2))
 end
diff --git a/test/kernels/tensorproduct.jl b/test/kernels/kerneltensorproduct.jl
similarity index 55%
rename from test/kernels/tensorproduct.jl
rename to test/kernels/kerneltensorproduct.jl
index d762112c0..e01fa3871 100644
--- a/test/kernels/tensorproduct.jl
+++ b/test/kernels/kerneltensorproduct.jl
@@ -1,4 +1,4 @@
-@testset "tensorproduct" begin
+@testset "kerneltensorproduct" begin
     rng = MersenneTwister(123456)
     u1 = rand(rng, 10)
     u2 = rand(rng, 10)
@@ -8,12 +8,15 @@
     # kernels
     k1 = SqExponentialKernel()
     k2 = ExponentialKernel()
-    kernel1 = TensorProduct(k1, k2)
-    kernel2 = TensorProduct([k1, k2])
+    kernel1 = KernelTensorProduct(k1, k2)
+    kernel2 = KernelTensorProduct([k1, k2])
 
     @test kernel1 == kernel2
-    @test kernel1.kernels === (k1, k2) === TensorProduct((k1, k2)).kernels
+    @test kernel1.kernels === (k1, k2) === KernelTensorProduct((k1, k2)).kernels
     @test length(kernel1) == length(kernel2) == 2
+    @test string(kernel1) == (
+        "Tensor product of 2 kernels:\n\tSquared Exponential Kernel\n\tExponential Kernel"
+    )
     @test_throws DimensionMismatch kernel1(rand(3), rand(3))
 
     @testset "val" begin
@@ -24,14 +27,23 @@
         end
     end
 
+    # Deprecations
+    @test (@test_deprecated TensorProduct(k1, k2)) == k1 ⊗ k2
+    @test (@test_deprecated TensorProduct((k1, k2))) == k1 ⊗ k2
+    @test (@test_deprecated TensorProduct([k1, k2])) == k1 ⊗ k2
+
     # Standardised tests.
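+    # One feature per kernel: `test_interface` exercises both the column-major
+    # (`ColVecs`) and row-major (`RowVecs`) input layouts below.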
     TestUtils.test_interface(kernel1, ColVecs{Float64})
     TestUtils.test_interface(kernel1, RowVecs{Float64})
-    test_ADs(() -> TensorProduct(SqExponentialKernel(), LinearKernel()); dims=[2, 2]) # ADs = [:ForwardDiff, :ReverseDiff])
-    test_params(TensorProduct(k1, k2), (k1, k2))
+    test_ADs(
+        x -> KernelTensorProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
+        rand(1);
+        dims=[2, 2],
+    )
+    test_params(KernelTensorProduct(k1, k2), (k1, k2))
 
     @testset "single kernel" begin
-        kernel = TensorProduct(k1)
+        kernel = KernelTensorProduct(k1)
         @test length(kernel) == 1
 
         @testset "eval" begin
diff --git a/test/kernels/overloads.jl b/test/kernels/overloads.jl
new file mode 100644
index 000000000..eb79d41f8
--- /dev/null
+++ b/test/kernels/overloads.jl
@@ -0,0 +1,36 @@
+@testset "overloads" begin
+    rng = MersenneTwister(123456)
+
+    k1 = LinearKernel()
+    k2 = SqExponentialKernel()
+    k3 = RationalQuadraticKernel()
+
+    for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct))
+        if T === KernelTensorProduct
+            v2_1 = rand(rng, 2)
+            v2_2 = rand(rng, 2)
+            v3_1 = rand(rng, 3)
+            v3_2 = rand(rng, 3)
+            v4_1 = rand(rng, 4)
+            v4_2 = rand(rng, 4)
+        else
+            v2_1 = v3_1 = v4_1 = rand(rng, 3)
+            v2_2 = v3_2 = v4_2 = rand(rng, 3)
+        end
+        k = T(k1, k2)
+
+        @test op(k1, k2)(v2_1, v2_2) == k(v2_1, v2_2)
+        @test op(k, k3)(v3_1, v3_2) == T((k1, k2, k3))(v3_1, v3_2)
+        @test op(k3, k)(v3_1, v3_2) == T((k3, k1, k2))(v3_1, v3_2)
+        @test op(k, k)(v4_1, v4_2) == T((k1, k2, k1, k2))(v4_1, v4_2)
+        @test op(k1, k2) == T([k1, k2]) == T((k1, k2))
+
+        @test op(T([k1, k2]), T([k2, k1])).kernels == [k1, k2, k2, k1]
+        @test op(T([k1, k2]), k3).kernels == [k1, k2, k3]
+        @test op(k3, T([k1, k2])).kernels == [k3, k1, k2]
+
+        @test op(T((k1, k2)), T((k2, k1))).kernels == (k1, k2, k2, k1)
+        @test op(T((k1, k2)), k3).kernels == (k1, k2, k3)
+        @test op(k3, T((k1, k2))).kernels == (k3, k1, k2)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 733a4c4c9..7ad679905 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,7 @@ using KernelFunctions
 using AxisArrays
 using Distances
 using Documenter
-using Kronecker
+using Kronecker: Kronecker
 using LinearAlgebra
 using PDMats
 using Random
@@ -120,8 +120,9 @@ include("test_utils.jl")
     @testset "kernels" begin
         include(joinpath("kernels", "kernelproduct.jl"))
         include(joinpath("kernels", "kernelsum.jl"))
+        include(joinpath("kernels", "kerneltensorproduct.jl"))
+        include(joinpath("kernels", "overloads.jl"))
         include(joinpath("kernels", "scaledkernel.jl"))
-        include(joinpath("kernels", "tensorproduct.jl"))
         include(joinpath("kernels", "transformedkernel.jl"))
     end
     @info "Ran tests on Kernel"
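
Taken together, `src/kernels/overloads.jl` now gives `+`, `*`, and the reexported `tensor`/`⊗` a single, uniform set of methods, so combining composite kernels flattens the kernel list instead of nesting wrappers. A minimal REPL-style sketch of the resulting behaviour (not part of the patch; the kernel choices are arbitrary):

```julia
using KernelFunctions

k1 = LinearKernel()
k2 = SqExponentialKernel()
k3 = RationalQuadraticKernel()

# Each binary operator wraps its arguments in the matching composite kernel.
ksum = k1 + k2     # KernelSum
kprod = k1 * k2    # KernelProduct
ktens = k1 ⊗ k2    # KernelTensorProduct (⊗ is reexported from TensorCore)

# Combining composites flattens the kernel list rather than nesting wrappers,
# mirroring the `$T(k1.kernels..., k2.kernels...)` methods generated above.
length(ktens ⊗ k3)  # 3, not 2

# A tensor product consumes one feature (group) per kernel:
x, y = rand(3), rand(3)
(ktens ⊗ k3)(x, y) ≈ k1(x[1], y[1]) * k2(x[2], y[2]) * k3(x[3], y[3])  # true
```

The deprecated `TensorProduct` constructors keep working but emit a warning and forward to `KernelTensorProduct`, as exercised by the `@test_deprecated` checks above.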