Skip to content

Commit

Permalink
import lecture 12
Browse files Browse the repository at this point in the history
  • Loading branch information
smidl committed Dec 13, 2023
1 parent 8d31455 commit a3dc9be
Show file tree
Hide file tree
Showing 18 changed files with 6,269 additions and 0 deletions.
466 changes: 466 additions & 0 deletions docs/src/lecture_12/LV_GaussNum.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
454 changes: 454 additions & 0 deletions docs/src/lecture_12/LV_GaussNum2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
454 changes: 454 additions & 0 deletions docs/src/lecture_12/LV_Measurements.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
454 changes: 454 additions & 0 deletions docs/src/lecture_12/LV_Measurements2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
460 changes: 460 additions & 0 deletions docs/src/lecture_12/LV_Quadrics.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,368 changes: 2,368 additions & 0 deletions docs/src/lecture_12/LV_ensemble.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/src/lecture_12/cubature.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/src/lecture_12/euler.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
134 changes: 134 additions & 0 deletions docs/src/lecture_12/hw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# [Homework 12 - The Runge-Kutta ODE Solver](@id hw12)

There exist many different ODE solvers. To demonstrate how we can get
significantly better results with a simple update to `Euler`, you will
implement the second order Runge-Kutta method `RK2`:
```math
\begin{align*}
\tilde x_{n+1} &= x_n + hf(x_n, t_n)\\
x_{n+1} &= x_n + \frac{h}{2}(f(x_n,t_n)+f(\tilde x_{n+1},t_{n+1}))
\end{align*}
```
`RK2` is a 2nd order method. It uses not only $f$ (the slope at a given point),
but also $f'$ (the derivative of the slope). With some clever manipulations you
can arrive at the equations above with make use of $f'$ without needing an
explicit expression for it (if you want to know how, see
[here](https://web.mit.edu/10.001/Web/Course_Notes/Differential_Equations_Notes/node5.html)).
Essentially, `RK2` computes an initial guess $\tilde x_{n+1}$ to then average
the slopes at the current point $x_n$ and at the guess $\tilde x_{n+1}$ which
is illustarted below.
![rk2](rk2.png)

The code from the lab that you will need for this homework is given below.
As always, put all your code in a file called `hw.jl`, zip it, and upload it
to BRUTE.
```@example hw
struct ODEProblem{F,T<:Tuple{Number,Number},U<:AbstractVector,P<:AbstractVector}
f::F
tspan::T
u0::U
θ::P
end
abstract type ODESolver end
struct Euler{T} <: ODESolver
dt::T
end
function (solver::Euler)(prob::ODEProblem, u, t)
f, θ, dt = prob.f, prob.θ, solver.dt
(u + dt*f(u,θ), t+dt)
end
function solve(prob::ODEProblem, solver::ODESolver)
t = prob.tspan[1]; u = prob.u0
us = [u]; ts = [t]
while t < prob.tspan[2]
(u,t) = solver(prob, u, t)
push!(us,u)
push!(ts,t)
end
ts, reduce(hcat,us)
end
# Define & Solve ODE
function lotkavolterra(x,θ)
α, β, γ, δ = θ
x₁, x₂ = x
dx₁ = α*x₁ - β*x₁*x₂
dx₂ = δ*x₁*x₂ - γ*x₂
[dx₁, dx₂]
end
```
```@raw html
<div class="admonition is-category-homework">
<header class="admonition-header">Homework (2 points)</header>
<div class="admonition-body">
```
Implement the 2nd order Runge-Kutta solver according to the equations given above
by overloading the call method of a new type `RK2`.
```julia
(solver::RK2)(prob::ODEProblem, u, t)
```
```@raw html
</div></div>
```

```@setup hw
struct RK2{T} <: ODESolver
dt::T
end
function (solver::RK2)(prob::ODEProblem, u, t)
f, θ, dt = prob.f, prob.θ, solver.dt
du = f(u,θ)
uh = u + du*dt
u + dt/2*(du + f(uh,θ)), t+dt
end
```
You should be able to use it exactly like our `Euler` solver before:
```@example hw
using Plots
using JLD2
# Define ODE
function lotkavolterra(x,θ)
α, β, γ, δ = θ
x₁, x₂ = x
dx₁ = α*x₁ - β*x₁*x₂
dx₂ = δ*x₁*x₂ - γ*x₂
[dx₁, dx₂]
end
θ = [0.1,0.2,0.3,0.2]
u0 = [1.0,1.0]
tspan = (0.,100.)
prob = ODEProblem(lotkavolterra,tspan,u0,θ)
# load correct data
true_data = load("lotkadata.jld2")
# create plot
p1 = plot(true_data["t"], true_data["u"][1,:], lw=4, ls=:dash, alpha=0.7,
color=:gray, label="x Truth")
plot!(p1, true_data["t"], true_data["u"][2,:], lw=4, ls=:dash, alpha=0.7,
color=:gray, label="y Truth")
# Euler solve
(t,X) = solve(prob, Euler(0.2))
plot!(p1,t,X[1,:], color=3, lw=3, alpha=0.8, label="x Euler", ls=:dot)
plot!(p1,t,X[2,:], color=4, lw=3, alpha=0.8, label="y Euler", ls=:dot)
# RK2 solve
(t,X) = solve(prob, RK2(0.2))
plot!(p1,t,X[1,:], color=1, lw=3, alpha=0.8, label="x RK2")
plot!(p1,t,X[2,:], color=2, lw=3, alpha=0.8, label="y RK2")
```
49 changes: 49 additions & 0 deletions docs/src/lecture_12/lab-ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
struct ODEProblem{F,T<:Tuple{Number,Number},U<:AbstractVector,P<:AbstractVector}
f::F
tspan::T
u0::U
θ::P
end


abstract type ODESolver end

struct Euler{T} <: ODESolver
dt::T
end

function (solver::Euler)(prob::ODEProblem, u, t)
f, θ, dt = prob.f, prob.θ, solver.dt
(u + dt*f(u,θ), t+dt)
end


function solve(prob::ODEProblem, solver::ODESolver)
t = prob.tspan[1]; u = prob.u0
us = [u]; ts = [t]
while t < prob.tspan[2]
(u,t) = solver(prob, u, t)
push!(us,u)
push!(ts,t)
end
ts, reduce(hcat,us)
end


# Define & Solve ODE

function lotkavolterra(x,θ)
α, β, γ, δ = θ
x₁, x₂ = x

dx₁ = α*x₁ - β*x₁*x₂
dx₂ = δ*x₁*x₂ - γ*x₂

[dx₁, dx₂]
end

θ = [0.1,0.2,0.3,0.2]
u0 = [1.0,1.0]
tspan = (0.,100.)
prob = ODEProblem(lotkavolterra,tspan,u0,θ)

182 changes: 182 additions & 0 deletions docs/src/lecture_12/lab.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
using Zygote

struct GaussNum{T<:Real} <: Real
μ::T
σ::T
end
mu(x::GaussNum) = x.μ
sig(x::GaussNum) = x.σ
GaussNum(x,y) = GaussNum(promote(x,y)...)
±(x,y) = GaussNum(x,y)
Base.convert(::Type{T}, x::T) where T<:GaussNum = x
Base.convert(::Type{GaussNum{T}}, x::Number) where T = GaussNum(x,zero(T))
Base.promote_rule(::Type{GaussNum{T}}, ::Type{S}) where {T,S} = GaussNum{T}
Base.promote_rule(::Type{GaussNum{T}}, ::Type{GaussNum{T}}) where T = GaussNum{T}

# convert(GaussNum{Float64}, 1.0) |> display
# promote(GaussNum(1.0,1.0), 2.0) |> display
# error()

#+(x::GaussNum{T},a::T) where T =GaussNum(x.μ+a,x.σ)
#+(a::T,x::GaussNum{T}) where T =GaussNum(x.μ+a,x.σ)
#-(x::GaussNum{T},a::T) where T =GaussNum(x.μ-a,x.σ)
#-(a::T,x::GaussNum{T}) where T =GaussNum(x.μ-a,x.σ)
#*(x::GaussNum{T},a::T) where T =GaussNum(x.μ*a,a*x.σ)
#*(a::T,x::GaussNum{T}) where T =GaussNum(x.μ*a,a*x.σ)


# function Base.:*(x1::GaussNum, x2::GaussNum)
# f(x1,x2) = x1 * x2
# s1 = Zygote.gradient(μ -> f(μ,x2.μ), x1.μ)[1]^2 * x1.σ^2
# s2 = Zygote.gradient(μ -> f(x1.μ,μ), x2.μ)[1]^2 * x2.σ^2
# GaussNum(f(x1.μ,x2.μ), sqrt(s1+s2))
# end

function _uncertain(f, args::GaussNum...)
μs = [x.μ for x in args]
dfs = Zygote.gradient(f,μs...)
σ = map(zip(dfs,args)) do (df,x)
df^2 * x.σ^2
end |> sum |> sqrt
GaussNum(f(μs...), σ)
end

function _uncertain(expr::Expr)
if expr.head == :call
:(_uncertain($(expr.args[1]), $(expr.args[2:end]...)))
else
error("Expression has to be a :call")
end
end

macro uncertain(expr)
_uncertain(expr)
end

getmodule(f) = first(methods(f)).module

function _register(func::Symbol)
mod = getmodule(eval(func))
:($(mod).$(func)(args::GaussNum...) = _uncertain($func, args...))
end

function _register(funcs::Expr)
Expr(:block, map(_register, funcs.args)...)
end

macro register(funcs)
_register(funcs)
end

@register - + *

f(x,y) = x+y*x


# @register *
# @register +
# @register -
# @register f

asdf(x1::GaussNum{T},x2::GaussNum{T}) where T =GaussNum(x1.μ*x2.μ, sqrt((x2.μ*x1.σ).^2 + (x1.μ * x2.σ).^2))
gggg(x1::GaussNum{T},x2::GaussNum{T}) where T =GaussNum(x1.μ+x2.μ, sqrt(x1.σ.^2 + x2.σ.^2))

x1 = GaussNum(rand(),rand())
x2 = GaussNum(rand(),rand())

display(x1*x2)
display(asdf(x1,x2))
display(_uncertain(*,x1,x2))
display(@uncertain x1*x2)

display(x1-x2)
display(x1+x2)
display(f(x1,x2))
#error()


using Plots
using JLD2

abstract type AbstractODEProblem end

struct ODEProblem{F,T,U,P} <: AbstractODEProblem
f::F
tspan::T
u0::U
θ::P
end

abstract type ODESolver end
struct Euler{T} <: ODESolver
dt::T
end
struct RK2{T} <: ODESolver
dt::T
end

function f(x,θ)
α, β, γ, δ = θ
x₁, x₂ = x

dx₁ = α*x₁ - β*x₁*x₂
dx₂ = δ*x₁*x₂ - γ*x₂

[dx₁, dx₂]
end

function solve(prob::AbstractODEProblem, solver::ODESolver)
t = prob.tspan[1]; u = prob.u0
us = [u]; ts = [t]
while t < prob.tspan[2]
(u,t) = solver(prob, u, t)
push!(us,u)
push!(ts,t)
end
ts, reduce(hcat,us)
end

function (solver::Euler)(prob::ODEProblem, u, t)
f, θ, dt = prob.f, prob.θ, solver.dt
(u + dt*f(u,θ), t+dt)
end

function (solver::RK2)(prob::ODEProblem, u, t)
f, θ, dt = prob.f, prob.θ, solver.dt
uh = u + f(u,θ)*dt
u + dt/2*(f(u,θ) + f(uh,θ)), t+dt
end


@recipe function plot(ts::AbstractVector, xs::AbstractVector{<:GaussNum})
# you can set a default value for an attribute with `-->`
# and force an argument with `:=`
μs = [x.μ for x in xs]
σs = [x.σ for x in xs]
@series begin
:seriestype := :path
# ignore series in legend and color cycling
primary := false
linecolor := nothing
fillcolor := :lightgray
fillalpha := 0.5
fillrange := μs .- σs
# ensure no markers are shown for the error band
markershape := :none
# return series data
ts, μs .+ σs
end
ts, μs
end

θ = [0.1,0.2,0.3,0.2]
u0 = [GaussNum(1.0,0.1),GaussNum(1.0,0.1)]
tspan = (0.,100.)
dt = 0.1
prob = ODEProblem(f,tspan,u0,θ)

t,X=solve(prob, RK2(0.2))
p1 = plot(t, X[1,:], label="x", lw=3)
plot!(p1, t, X[2,:], label="y", lw=3)

display(p1)
Loading

0 comments on commit a3dc9be

Please sign in to comment.