From 7258994968a416b0e74686486d182b1f7f874269 Mon Sep 17 00:00:00 2001 From: Steve Buzzard Date: Thu, 16 May 2019 18:54:08 -0400 Subject: [PATCH] =?UTF-8?q?Added=20further=20hackery=20to=20dtrace's=20ina?= =?UTF-8?q?dequate=20stack=20safety=20mechanism=20s=E2=80=A6=20(#79)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added further hackery to dtrace's inadequate stack safety mechanism since manufactored applicative type of parallel isn't a monad and thus can't directly use our original hack. Fixes immediate stack overflow issue with parTraverse. Opened an issue to fix properly but that'll take longer/need some more thought. Fixes issue #77 * bump up the soe test iterations from 100000 to 500000 --- .../cedi/dtrace/TypeclassLawTests.scala | 23 +++++++++++ .../com/ccadllc/cedi/dtrace/TraceT.scala | 40 +++++++++++++------ .../com/ccadllc/cedi/dtrace/dtrace.scala | 12 ++++++ 3 files changed, 63 insertions(+), 12 deletions(-) diff --git a/core/jvm/src/test/scala/com/ccadllc/cedi/dtrace/TypeclassLawTests.scala b/core/jvm/src/test/scala/com/ccadllc/cedi/dtrace/TypeclassLawTests.scala index 1c728ba..f6f121c 100644 --- a/core/jvm/src/test/scala/com/ccadllc/cedi/dtrace/TypeclassLawTests.scala +++ b/core/jvm/src/test/scala/com/ccadllc/cedi/dtrace/TypeclassLawTests.scala @@ -35,6 +35,8 @@ import java.nio.charset.Charset import org.scalactic.source import org.scalacheck._ import org.scalatest.prop.Checkers +import org.scalatest.OptionValues._ +import org.scalatest.TryValues._ import org.scalatest.{ FunSuite, Matchers, Tag } import org.typelevel.discipline.Laws @@ -177,6 +179,27 @@ class TypeclassLawTests extends FunSuite with Matchers with Checkers with Discip () } + testAsync("parMap2 should be stack safe") { testC => + implicit val cs = testC.contextShift[IO] + val count = 500000 + val tasks = (0 until count).map(_ => TraceIO(1)) + val sum = tasks.foldLeft(TraceIO(0))((acc, t) => (acc, t).parMapN(_ + _)) + val f = sum.trace(tc).unsafeToFuture() + testC.tick() + f.value shouldBe Some(Success(count)) + () + } + + testAsync("parTraverse should be stack safe") { testC => + implicit val cs = testC.contextShift[IO] + val count = 500000 + val numbers = (0 until count).toVector + val f = numbers.parTraverse(i => TraceIO.pure(i + 1)).trace(tc).unsafeToFuture() + testC.tick() + f.value.value.success.value.sum shouldBe numbers.map(_ + 1).sum + () + } + private def checkAllAsync(name: String, f: TestContext => Laws#RuleSet): Unit = { val testC = TestContext() val ruleSet = f(testC) diff --git a/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/TraceT.scala b/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/TraceT.scala index b5051f0..cb1deb0 100644 --- a/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/TraceT.scala +++ b/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/TraceT.scala @@ -728,13 +728,21 @@ private[dtrace] sealed trait TraceTParallelInstance extends TraceTNonEmptyParall */ protected class ParallelTraceT[M[_], F[_]](implicit P: Parallel[M, F], M: Monad[M], F: Applicative[F]) extends NonEmptyParallelTraceT[M, F] with Parallel[TraceT[M, ?], TraceT[F, ?]] { override def applicative: Applicative[TraceT[F, ?]] = new Applicative[TraceT[F, ?]] { - override def map[A, B](ta: TraceT[F, A])(f: A => B): TraceT[F, B] = TraceT.suspendEffect { tc => - P.applicative.map(ta.toEffect(tc))(f) - } + + override def ap[A, B](tab: TraceT[F, A => B])(ta: TraceT[F, A]): TraceT[F, B] = + map2(tab, ta)(_(_)) + + override def map[A, B](ta: TraceT[F, A])(f: A => B): TraceT[F, B] = + parAppSuspendEffect(tc => P.applicative.map(ta.toEffect(tc))(f)) + + override def map2[A, B, Z](ta: TraceT[F, A], tb: TraceT[F, B])(f: (A, B) => Z): TraceT[F, Z] = + parAppSuspendEffect(tc => P.applicative.map2(ta.toEffect(tc), tb.toEffect(tc))(f)) + + override def product[A, B](ta: TraceT[F, A], tab: TraceT[F, B]): TraceT[F, (A, B)] = + map2(ta, tab)((_, _)) + override def pure[A](a: A): TraceT[F, A] = TraceT.toTraceT(P.applicative.pure(a)) - override def ap[A, B](tab: TraceT[F, A => B])(ta: TraceT[F, A]): TraceT[F, B] = TraceT.suspendEffect { tc => - P.applicative.ap(tab.toEffect(tc))(ta.toEffect(tc)) - } + override def toString: String = "ParApplicative[TraceT[F, ?]]" } override def monad: Monad[TraceT[M, ?]] = new Monad[TraceT[M, ?]] { @@ -770,12 +778,18 @@ private[dtrace] sealed trait TraceTNonEmptyParallelInstance { */ protected class NonEmptyParallelTraceT[M[_], F[_]](implicit P: NonEmptyParallel[M, F], M: FlatMap[M], F: Apply[F]) extends NonEmptyParallel[TraceT[M, ?], TraceT[F, ?]] { def apply: Apply[TraceT[F, ?]] = new Apply[TraceT[F, ?]] { - override def ap[A, B](tab: TraceT[F, A => B])(ta: TraceT[F, A]): TraceT[F, B] = TraceT.suspendEffect { tc => - P.apply.ap(tab.toEffect(tc))(ta.toEffect(tc)) - } - override def map[A, B](ta: TraceT[F, A])(f: A => B): TraceT[F, B] = TraceT.suspendEffect { tc => - P.apply.map(ta.toEffect(tc))(f) - } + override def ap[A, B](tab: TraceT[F, A => B])(ta: TraceT[F, A]): TraceT[F, B] = + map2(tab, ta)(_(_)) + + override def map[A, B](ta: TraceT[F, A])(f: A => B): TraceT[F, B] = + parAppSuspendEffect(tc => P.apply.map(ta.toEffect(tc))(f)) + + override def map2[A, B, Z](ta: TraceT[F, A], tb: TraceT[F, B])(f: (A, B) => Z): TraceT[F, Z] = + parAppSuspendEffect(tc => P.apply.map2(ta.toEffect(tc), tb.toEffect(tc))(f)) + + override def product[A, B](ta: TraceT[F, A], tab: TraceT[F, B]): TraceT[F, (A, B)] = + map2(ta, tab)((_, _)) + override def toString: String = "ParApply[TraceT[F, ?]]" } def flatMap: FlatMap[TraceT[M, ?]] = new FlatMap[TraceT[M, ?]] { @@ -803,5 +817,7 @@ private[dtrace] sealed trait TraceTNonEmptyParallelInstance { P.parallel(tma.toEffect(translate(tc, P.sequential))) } } + protected def parAppSuspendEffect[A](action: TraceContext[F] => F[A]): TraceT[F, A] = + parallel(TraceT.suspendEffect[M, A](tcm => P.sequential(action(translate(tcm, P.parallel))))) } } diff --git a/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/dtrace.scala b/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/dtrace.scala index 4df0ed2..517a04a 100644 --- a/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/dtrace.scala +++ b/core/shared/src/main/scala/com/ccadllc/cedi/dtrace/dtrace.scala @@ -53,6 +53,18 @@ package object dtrace { IO.Par.unwrap(tiop.toEffect(translate(tc, P.parallel))) } } + /** + * Creates a simple, noncancelable `TraceIO[A]` instance that + * executes an asynchronous process on evaluation. + * + * The given function is being injected with a side-effectful + * callback for signaling the final result of an asynchronous + * process. + * + * @param k is a function that should be called with a + * callback for signaling the result once it is ready + */ + def async[A](cb: (Either[Throwable, A] => Unit) => Unit): TraceIO[A] = toTraceIO(IO.async(cb)) /** * Ask for the current `TraceContext[IO]` in a `TraceIO`.