From 5c767751754222462078605ebb20b56422fe542f Mon Sep 17 00:00:00 2001 From: Christoph Bunte Date: Tue, 5 Mar 2024 15:55:21 +0100 Subject: [PATCH] lv fit: add dynamic #knots selection --- .../scala/derifree/fd/LocalVolFitter.scala | 38 ++++++++++++++++--- .../src/test/scala/derifree/fd/lvfit.scala | 4 +- .../src/test/scala/derifree/fd/lvfit2.scala | 9 ++++- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/derifree/src/main/scala/derifree/fd/LocalVolFitter.scala b/derifree/src/main/scala/derifree/fd/LocalVolFitter.scala index 0a58b64..c81abd1 100644 --- a/derifree/src/main/scala/derifree/fd/LocalVolFitter.scala +++ b/derifree/src/main/scala/derifree/fd/LocalVolFitter.scala @@ -32,6 +32,7 @@ import scala.math.exp import scala.math.log import scala.math.max import scala.math.min +import scala.math.round import scala.math.pow import scala.math.sqrt @@ -104,9 +105,9 @@ sealed trait LocalVolFitter: object LocalVolFitter: case class Settings( - nKnots: Int, - minLv: Double, - maxLv: Double, + minLv: Double = 0.01, + maxLv: Double = 5.0, + nKnots: Settings.Knots = Settings.Knots.Fixed(5), spatialGrid: Settings.SpatialGrid = Settings.SpatialGrid(), timeGrid: Settings.TimeGrid = Settings.TimeGrid(), nRannacherSteps: Int = 2 @@ -132,6 +133,10 @@ object LocalVolFitter: expansionMaxIters: Int = 10 ) + enum Knots: + case Fixed(n: Int) + case Dynamic(min: Int, max: Int, alpha: Double, minStrikes: Int, minSpread: Double) + case class Result( /** The fitted expiries in year fraction terms. By convention the first element is always * zero. @@ -408,18 +413,39 @@ object LocalVolFitter: val timegrid = timeGridFactory(expiries.toSet) + val nKnotsByExpiry = settings.nKnots match + case Settings.Knots.Fixed(n) => obsByExpiry.map(_ => n) + case Settings.Knots.Dynamic(minKnots, maxKnots, alpha, minStrikes, minSpread) => + obsByExpiry.map: (_, obs) => + val nStrikes = obs.map(_.strike).toSet.size + val nObs = obs.length + val medianSpread = obs.map(_.spread).sorted.apply(nObs / 2) + round( + min( + maxKnots, + max( + minKnots, + alpha * log(1.0 + max(nStrikes - minStrikes, 0)) / max( + medianSpread / minSpread, + 1.0 + ) + ) + ) + ).toInt + val state = expiries .zip(obsByExpiry) + .zip(nKnotsByExpiry) .foldLeft(State()): - case (state, (t1, (t2, obs))) => + case (state, ((t1, (t2, obs)), nKnots)) => val timeGridSlice = timegrid.slice(t1, t2).get.yearFractions val minStrike = obs.map(_.strike).min val maxStrike = obs.map(_.strike).max val lvKnots = List - .tabulate(settings.nKnots)(i => - minStrike * pow(maxStrike / minStrike, i.toDouble / (settings.nKnots - 1)) + .tabulate(nKnots)(i => + minStrike * pow(maxStrike / minStrike, i.toDouble / (nKnots - 1)) ) .toIndexedSeq diff --git a/derifree/src/test/scala/derifree/fd/lvfit.scala b/derifree/src/test/scala/derifree/fd/lvfit.scala index 8d54680..1963806 100644 --- a/derifree/src/test/scala/derifree/fd/lvfit.scala +++ b/derifree/src/test/scala/derifree/fd/lvfit.scala @@ -56,7 +56,7 @@ class LVFitSuite extends munit.FunSuite: PureObservation(1.10, t, vol, 0.01) ) ) - val settings = Settings(3, 0.01, 1.0) + val settings = Settings(0.01, 1.0, Settings.Knots.Fixed(3)) val result = fitter.fitPureObservations(obs, settings).toTry.get val clue = @@ -95,7 +95,7 @@ class LVFitSuite extends munit.FunSuite: ) ) - val settings = Settings(3, 0.01, 1.0) + val settings = Settings(0.01, 1.0, Settings.Knots.Fixed(3)) val result = fitter.fitPureObservations(obs, settings).toTry.get diff --git a/derifree/src/test/scala/derifree/fd/lvfit2.scala b/derifree/src/test/scala/derifree/fd/lvfit2.scala index d913522..13e3d55 100644 --- a/derifree/src/test/scala/derifree/fd/lvfit2.scala +++ b/derifree/src/test/scala/derifree/fd/lvfit2.scala @@ -21,6 +21,7 @@ import cats.effect.IO import cats.syntax.all.* import derifree.dtos.etd.options.OptionQuote import derifree.dtos.etd.options.Snapshot +import derifree.fd.LocalVolFitter.Settings import derifree.testutils.* import scala.concurrent.duration.* @@ -51,7 +52,11 @@ class LVFitSuite2 extends munit.CatsEffectSuite: ) val lvFitter = LocalVolFitter.apply - val lvSettings = LocalVolFitter.Settings(5, 0.01, 3.0) + val lvSettings = LocalVolFitter.Settings( + 0.01, + 3.0, + Settings.Knots.Dynamic(3, 12, 5.0, 3, 0.001) + ) val initialBorrow = YieldCurve.zero[YearFraction] val divs = Nil @@ -109,6 +114,6 @@ class LVFitSuite2 extends munit.CatsEffectSuite: YearFraction.oneDay * 7 ) ) - _ <- IO.println(lv) + _ <- IO.println(s"#knots = ${lv.lvKnots.map(_.length).mkString(", ")}") yield () )