Skip to content

Commit

Permalink
lv fit: add dynamic #knots selection
Browse files Browse the repository at this point in the history
  • Loading branch information
buntec committed Mar 5, 2024
1 parent a7afd2d commit 5c76775
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
38 changes: 32 additions & 6 deletions derifree/src/main/scala/derifree/fd/LocalVolFitter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import scala.math.exp
import scala.math.log
import scala.math.max
import scala.math.min
import scala.math.round
import scala.math.pow
import scala.math.sqrt

Expand Down Expand Up @@ -104,9 +105,9 @@ sealed trait LocalVolFitter:
object LocalVolFitter:

case class Settings(
nKnots: Int,
minLv: Double,
maxLv: Double,
minLv: Double = 0.01,
maxLv: Double = 5.0,
nKnots: Settings.Knots = Settings.Knots.Fixed(5),
spatialGrid: Settings.SpatialGrid = Settings.SpatialGrid(),
timeGrid: Settings.TimeGrid = Settings.TimeGrid(),
nRannacherSteps: Int = 2
Expand All @@ -132,6 +133,10 @@ object LocalVolFitter:
expansionMaxIters: Int = 10
)

enum Knots:
case Fixed(n: Int)
case Dynamic(min: Int, max: Int, alpha: Double, minStrikes: Int, minSpread: Double)

case class Result(
/** The fitted expiries in year fraction terms. By convention the first element is always
* zero.
Expand Down Expand Up @@ -408,18 +413,39 @@ object LocalVolFitter:

val timegrid = timeGridFactory(expiries.toSet)

val nKnotsByExpiry = settings.nKnots match
case Settings.Knots.Fixed(n) => obsByExpiry.map(_ => n)
case Settings.Knots.Dynamic(minKnots, maxKnots, alpha, minStrikes, minSpread) =>
obsByExpiry.map: (_, obs) =>
val nStrikes = obs.map(_.strike).toSet.size
val nObs = obs.length
val medianSpread = obs.map(_.spread).sorted.apply(nObs / 2)
round(
min(
maxKnots,
max(
minKnots,
alpha * log(1.0 + max(nStrikes - minStrikes, 0)) / max(
medianSpread / minSpread,
1.0
)
)
)
).toInt

val state = expiries
.zip(obsByExpiry)
.zip(nKnotsByExpiry)
.foldLeft(State()):
case (state, (t1, (t2, obs))) =>
case (state, ((t1, (t2, obs)), nKnots)) =>
val timeGridSlice = timegrid.slice(t1, t2).get.yearFractions

val minStrike = obs.map(_.strike).min
val maxStrike = obs.map(_.strike).max

val lvKnots = List
.tabulate(settings.nKnots)(i =>
minStrike * pow(maxStrike / minStrike, i.toDouble / (settings.nKnots - 1))
.tabulate(nKnots)(i =>
minStrike * pow(maxStrike / minStrike, i.toDouble / (nKnots - 1))
)
.toIndexedSeq

Expand Down
4 changes: 2 additions & 2 deletions derifree/src/test/scala/derifree/fd/lvfit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class LVFitSuite extends munit.FunSuite:
PureObservation(1.10, t, vol, 0.01)
)
)
val settings = Settings(3, 0.01, 1.0)
val settings = Settings(0.01, 1.0, Settings.Knots.Fixed(3))
val result = fitter.fitPureObservations(obs, settings).toTry.get

val clue =
Expand Down Expand Up @@ -95,7 +95,7 @@ class LVFitSuite extends munit.FunSuite:
)
)

val settings = Settings(3, 0.01, 1.0)
val settings = Settings(0.01, 1.0, Settings.Knots.Fixed(3))

val result = fitter.fitPureObservations(obs, settings).toTry.get

Expand Down
9 changes: 7 additions & 2 deletions derifree/src/test/scala/derifree/fd/lvfit2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import cats.effect.IO
import cats.syntax.all.*
import derifree.dtos.etd.options.OptionQuote
import derifree.dtos.etd.options.Snapshot
import derifree.fd.LocalVolFitter.Settings
import derifree.testutils.*

import scala.concurrent.duration.*
Expand Down Expand Up @@ -51,7 +52,11 @@ class LVFitSuite2 extends munit.CatsEffectSuite:
)

val lvFitter = LocalVolFitter.apply
val lvSettings = LocalVolFitter.Settings(5, 0.01, 3.0)
val lvSettings = LocalVolFitter.Settings(
0.01,
3.0,
Settings.Knots.Dynamic(3, 12, 5.0, 3, 0.001)
)

val initialBorrow = YieldCurve.zero[YearFraction]
val divs = Nil
Expand Down Expand Up @@ -109,6 +114,6 @@ class LVFitSuite2 extends munit.CatsEffectSuite:
YearFraction.oneDay * 7
)
)
_ <- IO.println(lv)
_ <- IO.println(s"#knots = ${lv.lvKnots.map(_.length).mkString(", ")}")
yield ()
)

0 comments on commit 5c76775

Please sign in to comment.