From 1271b8cc83174a972c836c3316cc3ec9fd5eeb11 Mon Sep 17 00:00:00 2001 From: andreaskoher Date: Fri, 9 Apr 2021 09:59:09 +0200 Subject: [PATCH] add constant kernel (#58) * add constant kernel add a constant kernel as described in the book Applied Stochastic Differential Equations on p.263. add ConstantantKernel to the tesstset gp/lti_sde.jl * Update runtests.jl forgot to revert some temporary changes... * Update test/gp/lti_sde.jl Co-authored-by: willtebbutt * Update lti_sde.jl allow type of k.c[1] to be different from T. k.c[1] enters initial state variance in stationaly_distribution. drop adjoint of stationary_distribution in favour of an adjount for the constructor of SMatrix * Update Project.toml patch bump * new zygote rule add new zygote rule for SMatrix{1, 1} Co-authored-by: Andreas Co-authored-by: willtebbutt --- Project.toml | 2 +- src/gp/lti_sde.jl | 18 ++++++++++++++++++ src/util/zygote_rules.jl | 6 +++--- test/gp/lti_sde.jl | 2 +- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 5a5429aa..341f3699 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt "] -version = "0.5.1" +version = "0.5.2" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 2d950c75..c26ea0e4 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -246,7 +246,25 @@ Zygote.@adjoint function stationary_distribution(k::Matern52Kernel, storage_type return stationary_distribution(k, storage_type), Δ->(nothing, nothing) end +# Constant +function TemporalGPs.to_sde(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} + F = SMatrix{1, 1, T}(0) + q = convert(T, 0) + H = SVector{1, T}(1) + return F, q, H +end + +function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} + return TemporalGPs.Gaussian( + SVector{1, T}(0), + SMatrix{1, 1, T}( T(only(k.c)) ), + ) +end + +Zygote.@adjoint function to_sde(k::ConstantKernel, storage_type) + return to_sde(k, storage_type), Δ->(nothing, nothing) +end # Scaled diff --git a/src/util/zygote_rules.jl b/src/util/zygote_rules.jl index a0a7b9a6..0ad8797b 100644 --- a/src/util/zygote_rules.jl +++ b/src/util/zygote_rules.jl @@ -38,11 +38,11 @@ end return SMatrix{D1, D2}(X), SMatrix_pullback end -@adjoint function SMatrix{1, 1}(a) - SMatrix_pullback(Δ::AbstractMatrix) = (first(Δ), ) +function Zygote._pullback(::AContext, ::Type{<:SMatrix{1, 1}}, a) + SMatrix_pullback(::Nothing) = nothing + SMatrix_pullback(Δ::AbstractMatrix) = (nothing, first(Δ), ) return SMatrix{1, 1}(a), SMatrix_pullback end - # Implementation of the matrix exponential that assumes one doesn't require access to the # gradient w.r.t. `A`, only `t`. The former is a bit compute-intensive to get at, while the # latter is very cheap. diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 44479e0e..3e22dd67 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -26,7 +26,7 @@ println("lti_sde:") # (name="static storage Float32", val=SArrayStorage(Float32)), ) - kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel()] + kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(c=1.5)] @testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages F, q, H = TemporalGPs.to_sde(kernel, storage.val)