-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
6,269 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# [Homework 12 - The Runge-Kutta ODE Solver](@id hw12) | ||
|
||
There exist many different ODE solvers. To demonstrate how we can get | ||
significantly better results with a simple update to `Euler`, you will | ||
implement the second order Runge-Kutta method `RK2`: | ||
```math | ||
\begin{align*} | ||
\tilde x_{n+1} &= x_n + hf(x_n, t_n)\\ | ||
x_{n+1} &= x_n + \frac{h}{2}(f(x_n,t_n)+f(\tilde x_{n+1},t_{n+1})) | ||
\end{align*} | ||
``` | ||
`RK2` is a 2nd order method. It uses not only $f$ (the slope at a given point), | ||
but also $f'$ (the derivative of the slope). With some clever manipulations you | ||
can arrive at the equations above with make use of $f'$ without needing an | ||
explicit expression for it (if you want to know how, see | ||
[here](https://web.mit.edu/10.001/Web/Course_Notes/Differential_Equations_Notes/node5.html)). | ||
Essentially, `RK2` computes an initial guess $\tilde x_{n+1}$ to then average | ||
the slopes at the current point $x_n$ and at the guess $\tilde x_{n+1}$ which | ||
is illustarted below. | ||
![rk2](rk2.png) | ||
|
||
The code from the lab that you will need for this homework is given below. | ||
As always, put all your code in a file called `hw.jl`, zip it, and upload it | ||
to BRUTE. | ||
```@example hw | ||
struct ODEProblem{F,T<:Tuple{Number,Number},U<:AbstractVector,P<:AbstractVector} | ||
f::F | ||
tspan::T | ||
u0::U | ||
θ::P | ||
end | ||
abstract type ODESolver end | ||
struct Euler{T} <: ODESolver | ||
dt::T | ||
end | ||
function (solver::Euler)(prob::ODEProblem, u, t) | ||
f, θ, dt = prob.f, prob.θ, solver.dt | ||
(u + dt*f(u,θ), t+dt) | ||
end | ||
function solve(prob::ODEProblem, solver::ODESolver) | ||
t = prob.tspan[1]; u = prob.u0 | ||
us = [u]; ts = [t] | ||
while t < prob.tspan[2] | ||
(u,t) = solver(prob, u, t) | ||
push!(us,u) | ||
push!(ts,t) | ||
end | ||
ts, reduce(hcat,us) | ||
end | ||
# Define & Solve ODE | ||
function lotkavolterra(x,θ) | ||
α, β, γ, δ = θ | ||
x₁, x₂ = x | ||
dx₁ = α*x₁ - β*x₁*x₂ | ||
dx₂ = δ*x₁*x₂ - γ*x₂ | ||
[dx₁, dx₂] | ||
end | ||
``` | ||
```@raw html | ||
<div class="admonition is-category-homework"> | ||
<header class="admonition-header">Homework (2 points)</header> | ||
<div class="admonition-body"> | ||
``` | ||
Implement the 2nd order Runge-Kutta solver according to the equations given above | ||
by overloading the call method of a new type `RK2`. | ||
```julia | ||
(solver::RK2)(prob::ODEProblem, u, t) | ||
``` | ||
```@raw html | ||
</div></div> | ||
``` | ||
|
||
```@setup hw | ||
struct RK2{T} <: ODESolver | ||
dt::T | ||
end | ||
function (solver::RK2)(prob::ODEProblem, u, t) | ||
f, θ, dt = prob.f, prob.θ, solver.dt | ||
du = f(u,θ) | ||
uh = u + du*dt | ||
u + dt/2*(du + f(uh,θ)), t+dt | ||
end | ||
``` | ||
You should be able to use it exactly like our `Euler` solver before: | ||
```@example hw | ||
using Plots | ||
using JLD2 | ||
# Define ODE | ||
function lotkavolterra(x,θ) | ||
α, β, γ, δ = θ | ||
x₁, x₂ = x | ||
dx₁ = α*x₁ - β*x₁*x₂ | ||
dx₂ = δ*x₁*x₂ - γ*x₂ | ||
[dx₁, dx₂] | ||
end | ||
θ = [0.1,0.2,0.3,0.2] | ||
u0 = [1.0,1.0] | ||
tspan = (0.,100.) | ||
prob = ODEProblem(lotkavolterra,tspan,u0,θ) | ||
# load correct data | ||
true_data = load("lotkadata.jld2") | ||
# create plot | ||
p1 = plot(true_data["t"], true_data["u"][1,:], lw=4, ls=:dash, alpha=0.7, | ||
color=:gray, label="x Truth") | ||
plot!(p1, true_data["t"], true_data["u"][2,:], lw=4, ls=:dash, alpha=0.7, | ||
color=:gray, label="y Truth") | ||
# Euler solve | ||
(t,X) = solve(prob, Euler(0.2)) | ||
plot!(p1,t,X[1,:], color=3, lw=3, alpha=0.8, label="x Euler", ls=:dot) | ||
plot!(p1,t,X[2,:], color=4, lw=3, alpha=0.8, label="y Euler", ls=:dot) | ||
# RK2 solve | ||
(t,X) = solve(prob, RK2(0.2)) | ||
plot!(p1,t,X[1,:], color=1, lw=3, alpha=0.8, label="x RK2") | ||
plot!(p1,t,X[2,:], color=2, lw=3, alpha=0.8, label="y RK2") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
struct ODEProblem{F,T<:Tuple{Number,Number},U<:AbstractVector,P<:AbstractVector} | ||
f::F | ||
tspan::T | ||
u0::U | ||
θ::P | ||
end | ||
|
||
|
||
abstract type ODESolver end | ||
|
||
struct Euler{T} <: ODESolver | ||
dt::T | ||
end | ||
|
||
function (solver::Euler)(prob::ODEProblem, u, t) | ||
f, θ, dt = prob.f, prob.θ, solver.dt | ||
(u + dt*f(u,θ), t+dt) | ||
end | ||
|
||
|
||
function solve(prob::ODEProblem, solver::ODESolver) | ||
t = prob.tspan[1]; u = prob.u0 | ||
us = [u]; ts = [t] | ||
while t < prob.tspan[2] | ||
(u,t) = solver(prob, u, t) | ||
push!(us,u) | ||
push!(ts,t) | ||
end | ||
ts, reduce(hcat,us) | ||
end | ||
|
||
|
||
# Define & Solve ODE | ||
|
||
function lotkavolterra(x,θ) | ||
α, β, γ, δ = θ | ||
x₁, x₂ = x | ||
|
||
dx₁ = α*x₁ - β*x₁*x₂ | ||
dx₂ = δ*x₁*x₂ - γ*x₂ | ||
|
||
[dx₁, dx₂] | ||
end | ||
|
||
θ = [0.1,0.2,0.3,0.2] | ||
u0 = [1.0,1.0] | ||
tspan = (0.,100.) | ||
prob = ODEProblem(lotkavolterra,tspan,u0,θ) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
using Zygote | ||
|
||
struct GaussNum{T<:Real} <: Real | ||
μ::T | ||
σ::T | ||
end | ||
mu(x::GaussNum) = x.μ | ||
sig(x::GaussNum) = x.σ | ||
GaussNum(x,y) = GaussNum(promote(x,y)...) | ||
±(x,y) = GaussNum(x,y) | ||
Base.convert(::Type{T}, x::T) where T<:GaussNum = x | ||
Base.convert(::Type{GaussNum{T}}, x::Number) where T = GaussNum(x,zero(T)) | ||
Base.promote_rule(::Type{GaussNum{T}}, ::Type{S}) where {T,S} = GaussNum{T} | ||
Base.promote_rule(::Type{GaussNum{T}}, ::Type{GaussNum{T}}) where T = GaussNum{T} | ||
|
||
# convert(GaussNum{Float64}, 1.0) |> display | ||
# promote(GaussNum(1.0,1.0), 2.0) |> display | ||
# error() | ||
|
||
#+(x::GaussNum{T},a::T) where T =GaussNum(x.μ+a,x.σ) | ||
#+(a::T,x::GaussNum{T}) where T =GaussNum(x.μ+a,x.σ) | ||
#-(x::GaussNum{T},a::T) where T =GaussNum(x.μ-a,x.σ) | ||
#-(a::T,x::GaussNum{T}) where T =GaussNum(x.μ-a,x.σ) | ||
#*(x::GaussNum{T},a::T) where T =GaussNum(x.μ*a,a*x.σ) | ||
#*(a::T,x::GaussNum{T}) where T =GaussNum(x.μ*a,a*x.σ) | ||
|
||
|
||
# function Base.:*(x1::GaussNum, x2::GaussNum) | ||
# f(x1,x2) = x1 * x2 | ||
# s1 = Zygote.gradient(μ -> f(μ,x2.μ), x1.μ)[1]^2 * x1.σ^2 | ||
# s2 = Zygote.gradient(μ -> f(x1.μ,μ), x2.μ)[1]^2 * x2.σ^2 | ||
# GaussNum(f(x1.μ,x2.μ), sqrt(s1+s2)) | ||
# end | ||
|
||
function _uncertain(f, args::GaussNum...) | ||
μs = [x.μ for x in args] | ||
dfs = Zygote.gradient(f,μs...) | ||
σ = map(zip(dfs,args)) do (df,x) | ||
df^2 * x.σ^2 | ||
end |> sum |> sqrt | ||
GaussNum(f(μs...), σ) | ||
end | ||
|
||
function _uncertain(expr::Expr) | ||
if expr.head == :call | ||
:(_uncertain($(expr.args[1]), $(expr.args[2:end]...))) | ||
else | ||
error("Expression has to be a :call") | ||
end | ||
end | ||
|
||
macro uncertain(expr) | ||
_uncertain(expr) | ||
end | ||
|
||
getmodule(f) = first(methods(f)).module | ||
|
||
function _register(func::Symbol) | ||
mod = getmodule(eval(func)) | ||
:($(mod).$(func)(args::GaussNum...) = _uncertain($func, args...)) | ||
end | ||
|
||
function _register(funcs::Expr) | ||
Expr(:block, map(_register, funcs.args)...) | ||
end | ||
|
||
macro register(funcs) | ||
_register(funcs) | ||
end | ||
|
||
@register - + * | ||
|
||
f(x,y) = x+y*x | ||
|
||
|
||
# @register * | ||
# @register + | ||
# @register - | ||
# @register f | ||
|
||
asdf(x1::GaussNum{T},x2::GaussNum{T}) where T =GaussNum(x1.μ*x2.μ, sqrt((x2.μ*x1.σ).^2 + (x1.μ * x2.σ).^2)) | ||
gggg(x1::GaussNum{T},x2::GaussNum{T}) where T =GaussNum(x1.μ+x2.μ, sqrt(x1.σ.^2 + x2.σ.^2)) | ||
|
||
x1 = GaussNum(rand(),rand()) | ||
x2 = GaussNum(rand(),rand()) | ||
|
||
display(x1*x2) | ||
display(asdf(x1,x2)) | ||
display(_uncertain(*,x1,x2)) | ||
display(@uncertain x1*x2) | ||
|
||
display(x1-x2) | ||
display(x1+x2) | ||
display(f(x1,x2)) | ||
#error() | ||
|
||
|
||
using Plots | ||
using JLD2 | ||
|
||
abstract type AbstractODEProblem end | ||
|
||
struct ODEProblem{F,T,U,P} <: AbstractODEProblem | ||
f::F | ||
tspan::T | ||
u0::U | ||
θ::P | ||
end | ||
|
||
abstract type ODESolver end | ||
struct Euler{T} <: ODESolver | ||
dt::T | ||
end | ||
struct RK2{T} <: ODESolver | ||
dt::T | ||
end | ||
|
||
function f(x,θ) | ||
α, β, γ, δ = θ | ||
x₁, x₂ = x | ||
|
||
dx₁ = α*x₁ - β*x₁*x₂ | ||
dx₂ = δ*x₁*x₂ - γ*x₂ | ||
|
||
[dx₁, dx₂] | ||
end | ||
|
||
function solve(prob::AbstractODEProblem, solver::ODESolver) | ||
t = prob.tspan[1]; u = prob.u0 | ||
us = [u]; ts = [t] | ||
while t < prob.tspan[2] | ||
(u,t) = solver(prob, u, t) | ||
push!(us,u) | ||
push!(ts,t) | ||
end | ||
ts, reduce(hcat,us) | ||
end | ||
|
||
function (solver::Euler)(prob::ODEProblem, u, t) | ||
f, θ, dt = prob.f, prob.θ, solver.dt | ||
(u + dt*f(u,θ), t+dt) | ||
end | ||
|
||
function (solver::RK2)(prob::ODEProblem, u, t) | ||
f, θ, dt = prob.f, prob.θ, solver.dt | ||
uh = u + f(u,θ)*dt | ||
u + dt/2*(f(u,θ) + f(uh,θ)), t+dt | ||
end | ||
|
||
|
||
@recipe function plot(ts::AbstractVector, xs::AbstractVector{<:GaussNum}) | ||
# you can set a default value for an attribute with `-->` | ||
# and force an argument with `:=` | ||
μs = [x.μ for x in xs] | ||
σs = [x.σ for x in xs] | ||
@series begin | ||
:seriestype := :path | ||
# ignore series in legend and color cycling | ||
primary := false | ||
linecolor := nothing | ||
fillcolor := :lightgray | ||
fillalpha := 0.5 | ||
fillrange := μs .- σs | ||
# ensure no markers are shown for the error band | ||
markershape := :none | ||
# return series data | ||
ts, μs .+ σs | ||
end | ||
ts, μs | ||
end | ||
|
||
θ = [0.1,0.2,0.3,0.2] | ||
u0 = [GaussNum(1.0,0.1),GaussNum(1.0,0.1)] | ||
tspan = (0.,100.) | ||
dt = 0.1 | ||
prob = ODEProblem(f,tspan,u0,θ) | ||
|
||
t,X=solve(prob, RK2(0.2)) | ||
p1 = plot(t, X[1,:], label="x", lw=3) | ||
plot!(p1, t, X[2,:], label="y", lw=3) | ||
|
||
display(p1) |
Oops, something went wrong.