Skip to content

Commit

Permalink
add constant kernel (#58)
Browse files Browse the repository at this point in the history
* add constant kernel

add a constant kernel as described in the book Applied Stochastic Differential Equations on p.263.
add ConstantantKernel to the tesstset gp/lti_sde.jl

* Update runtests.jl

forgot to revert some temporary changes...

* Update test/gp/lti_sde.jl

Co-authored-by: willtebbutt <[email protected]>

* Update lti_sde.jl

allow type of k.c[1] to be different from T.
k.c[1] enters initial state variance in stationaly_distribution.
drop adjoint of stationary_distribution in favour of an adjount for the constructor of SMatrix

* Update Project.toml

patch bump

* new zygote rule

add new zygote rule for SMatrix{1, 1}

Co-authored-by: Andreas <[email protected]>
Co-authored-by: willtebbutt <[email protected]>
  • Loading branch information
3 people authored Apr 9, 2021
1 parent d56d1cd commit 1271b8c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]>"]
version = "0.5.1"
version = "0.5.2"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
18 changes: 18 additions & 0 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,25 @@ Zygote.@adjoint function stationary_distribution(k::Matern52Kernel, storage_type
return stationary_distribution(k, storage_type), Δ->(nothing, nothing)
end

# Constant

function TemporalGPs.to_sde(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
F = SMatrix{1, 1, T}(0)
q = convert(T, 0)
H = SVector{1, T}(1)
return F, q, H
end

function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
return TemporalGPs.Gaussian(
SVector{1, T}(0),
SMatrix{1, 1, T}( T(only(k.c)) ),
)
end

Zygote.@adjoint function to_sde(k::ConstantKernel, storage_type)
return to_sde(k, storage_type), Δ->(nothing, nothing)
end

# Scaled

Expand Down
6 changes: 3 additions & 3 deletions src/util/zygote_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ end
return SMatrix{D1, D2}(X), SMatrix_pullback
end

@adjoint function SMatrix{1, 1}(a)
SMatrix_pullback::AbstractMatrix) = (first(Δ), )
function Zygote._pullback(::AContext, ::Type{<:SMatrix{1, 1}}, a)
SMatrix_pullback(::Nothing) = nothing
SMatrix_pullback::AbstractMatrix) = (nothing, first(Δ), )
return SMatrix{1, 1}(a), SMatrix_pullback
end

# Implementation of the matrix exponential that assumes one doesn't require access to the
# gradient w.r.t. `A`, only `t`. The former is a bit compute-intensive to get at, while the
# latter is very cheap.
Expand Down
2 changes: 1 addition & 1 deletion test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ println("lti_sde:")
# (name="static storage Float32", val=SArrayStorage(Float32)),
)

kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel()]
kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(c=1.5)]

@testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages
F, q, H = TemporalGPs.to_sde(kernel, storage.val)
Expand Down

2 comments on commit 1271b8c

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33903

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.2 -m "<description of version>" 1271b8cc83174a972c836c3316cc3ec9fd5eeb11
git push origin v0.5.2

Please sign in to comment.