From ba44c2340cadf18d42eddaacff5d65b15ed87e49 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 27 Nov 2024 01:27:46 +0800 Subject: [PATCH] combining trail to represent delimited continuation --- src/main/scala/wasm/MiniWasmFX.scala | 163 +++++++++++++---------- src/main/scala/wasm/MiniWasmScript.scala | 4 +- src/test/scala/genwasym/TestFx.scala | 2 +- 3 files changed, 93 insertions(+), 76 deletions(-) diff --git a/src/main/scala/wasm/MiniWasmFX.scala b/src/main/scala/wasm/MiniWasmFX.scala index bd3de9c8..8746804d 100644 --- a/src/main/scala/wasm/MiniWasmFX.scala +++ b/src/main/scala/wasm/MiniWasmFX.scala @@ -12,13 +12,28 @@ case class EvaluatorFX(module: ModuleInstance) { import Primtives._ implicit val m: ModuleInstance = module + trait ContTrait[A] { + def apply(stack: Stack, trail1: List[ContTrait[A]], mcont: MCont[A]): A + } + type Stack = List[Value] - type Cont[A] = (Stack, MCont[A]) => A + type Cont[A] = ContTrait[A] type MCont[A] = Stack => A type Handler[A] = Stack => A + def init[Ans](s: Stack, trail1: List[Cont[Ans]], mkont: MCont[Ans]): Ans = { + trail1 match { + case k1 :: trail1 => k1(s, trail1, mkont) + case Nil => mkont(s) + } + } + + def +[Ans](k1: Cont[Ans], k2: Cont[Ans]): Cont[Ans] = { + (s, trail1, mkont) => k1(s, k2 :: trail1, mkont) + } + // Only used for resumable try-catch (need refactoring): - case class TCContV[A](k: (Stack, Cont[A], MCont[A]) => A) extends Value { + case class TCContV[A](k: (Stack, Cont[A], List[Cont[A]], MCont[A]) => A) extends Value { def tipe(implicit m: ModuleInstance): ValueType = ??? } @@ -27,8 +42,9 @@ case class EvaluatorFX(module: ModuleInstance) { stack: List[Value], frame: Frame, kont: Cont[Ans], + trail1: List[Cont[Ans]], mkont: MCont[Ans], - trail: List[Cont[Ans]], + trail2: List[Cont[Ans]], h: Handler[Ans], isTail: Boolean): Ans = { module.funcs(funcIndex) match { @@ -39,24 +55,24 @@ case class EvaluatorFX(module: ModuleInstance) { val newFrame = Frame(ArrayBuffer(frameLocals: _*)) if (isTail) // when tail call, share the continuation for returning with the callee - eval(body, List(), newFrame, kont, mkont, List(kont), h) + eval(body, List(), newFrame, kont, trail1, mkont, List(kont), h) else { - val restK: Cont[Ans] = (retStack, mkont) => - eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, mkont, trail, h) - // We make a new trail by `restK`, since function creates a new block to escape + val restK: Cont[Ans] = (retStack, trail1, mkont) => + eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail1, mkont, trail2, h) + // We make a new trail2 by `restK`, since function creates a new block to escape // (more or less like `return`) - eval(body, List(), newFrame, restK, mkont, List(restK), h) + eval(body, List(), newFrame, restK, trail1, mkont, List(restK), h) } case Import("console", "log", _) => // println(s"[DEBUG] current stack: $stack") val I32V(v) :: newStack = stack println(v) - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case Import("spectest", "print_i32", _) => // println(s"[DEBUG] current stack: $stack") val I32V(v) :: newStack = stack println(v) - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -66,10 +82,11 @@ case class EvaluatorFX(module: ModuleInstance) { stack: List[Value], frame: Frame, kont: Cont[Ans], + trail1: List[Cont[Ans]], mkont: MCont[Ans], - trail: List[Cont[Ans]], + trail2: List[Cont[Ans]], h: Handler[Ans]): Ans = { - if (insts.isEmpty) return kont(stack, mkont) + if (insts.isEmpty) return kont(stack, trail1, mkont) val inst = insts.head val rest = insts.tail @@ -78,23 +95,23 @@ case class EvaluatorFX(module: ModuleInstance) { // println(s"inst: ${inst} \t | ${frame.locals} | ${stack.reverse}" ) inst match { - case Drop => eval(rest, stack.tail, frame, kont, mkont, trail, h) + case Drop => eval(rest, stack.tail, frame, kont, trail1, mkont, trail2, h) case Select(_) => val I32V(cond) :: v2 :: v1 :: newStack = stack val value = if (cond == 0) v1 else v2 - eval(rest, value :: newStack, frame, kont, mkont, trail, h) + eval(rest, value :: newStack, frame, kont, trail1, mkont, trail2, h) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, kont, mkont, trail, h) + eval(rest, frame.locals(i) :: stack, frame, kont, trail1, mkont, trail2, h) case LocalSet(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case LocalTee(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, stack, frame, kont, mkont, trail, h) + eval(rest, stack, frame, kont, trail1, mkont, trail2, h) case GlobalGet(i) => - eval(rest, module.globals(i).value :: stack, frame, kont, mkont, trail, h) + eval(rest, module.globals(i).value :: stack, frame, kont, trail1, mkont, trail2, h) case GlobalSet(i) => val value :: newStack = stack module.globals(i).ty match { @@ -103,18 +120,18 @@ case class EvaluatorFX(module: ModuleInstance) { case GlobalType(_, true) => throw new Exception("Invalid type") case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case MemorySize => - eval(rest, I32V(module.memory.head.size) :: stack, frame, kont, mkont, trail, h) + eval(rest, I32V(module.memory.head.size) :: stack, frame, kont, trail1, mkont, trail2, h) case MemoryGrow => val I32V(delta) :: newStack = stack val mem = module.memory.head val oldSize = mem.size mem.grow(delta) match { case Some(e) => - eval(rest, I32V(-1) :: newStack, frame, kont, mkont, trail, h) + eval(rest, I32V(-1) :: newStack, frame, kont, trail1, mkont, trail2, h) case _ => - eval(rest, I32V(oldSize) :: newStack, frame, kont, mkont, trail, h) + eval(rest, I32V(oldSize) :: newStack, frame, kont, trail1, mkont, trail2, h) } case MemoryFill => val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack @@ -122,7 +139,7 @@ case class EvaluatorFX(module: ModuleInstance) { throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`? else { module.memory.head.fill(offset, size, value.toByte) - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) } case MemoryCopy => val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack @@ -130,100 +147,101 @@ case class EvaluatorFX(module: ModuleInstance) { throw new Exception("Out of bounds memory access") else { module.memory.head.copy(dest, src, n) - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) } - case Const(n) => eval(rest, n :: stack, frame, kont, mkont, trail, h) + case Const(n) => eval(rest, n :: stack, frame, kont, trail1, mkont, trail2, h) case Binary(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, mkont, trail, h) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail1, mkont, trail2, h) case Unary(op) => val v :: newStack = stack - eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, mkont, trail, h) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail1, mkont, trail2, h) case Compare(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, mkont, trail, h) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail1, mkont, trail2, h) case Test(op) => val v :: newStack = stack - eval(rest, evalTestOp(op, v) :: newStack, frame, kont, mkont, trail, h) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail1, mkont, trail2, h) case Store(StoreOp(align, offset, ty, None)) => val I32V(v) :: I32V(addr) :: newStack = stack module.memory(0).storeInt(addr + offset, v) - eval(rest, newStack, frame, kont, mkont, trail, h) + eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case Load(LoadOp(align, offset, ty, None, None)) => val I32V(addr) :: newStack = stack val value = module.memory(0).loadInt(addr + offset) - eval(rest, I32V(value) :: newStack, frame, kont, mkont, trail, h) + eval(rest, I32V(value) :: newStack, frame, kont, trail1, mkont, trail2, h) case Nop => - eval(rest, stack, frame, kont, mkont, trail, h) + eval(rest, stack, frame, kont, trail1, mkont, trail2, h) case Unreachable => throw Trap() case Block(ty, inner) => val funcTy = getFuncType(ty) val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK: Cont[Ans] = (retStack, mkont1) => { + val restK: Cont[Ans] = (retStack, trail1, mkont1) => { // kont -> mkont -> mkont1 - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont1, trail, h) + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont1, trail2, h) } - eval(inner, inputs, frame, restK, mkont, restK :: trail, h) + eval(inner, inputs, frame, restK, trail1, mkont, restK :: trail2, h) case Loop(ty, inner) => // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. val funcTy = getFuncType(ty) val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK: Cont[Ans] = (retStack, mkont) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont, trail, h) - def loop(retStack: List[Value], mkont: MCont[Ans]): Ans = - eval(inner, retStack.take(funcTy.inps.size), frame, restK, mkont, loop _ :: trail, h) - loop(inputs, mkont) + val restK: Cont[Ans] = (retStack, trail1, mkont) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont, trail2, h) + def loop(retStack: List[Value], trail1: List[Cont[Ans]], mkont: MCont[Ans]): Ans = + + eval(inner, retStack.take(funcTy.inps.size), frame, restK, trail1, mkont, (loop _ : Cont[Ans]):: trail2, h) + loop(inputs, trail1, mkont) case If(ty, thn, els) => val funcTy = getFuncType(ty) val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) - val restK: Cont[Ans] = (retStack, mkont) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont, trail, h) - eval(inner, inputs, frame, restK, mkont, restK :: trail, h) + val restK: Cont[Ans] = (retStack, trail1, mkont) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont, trail2, h) + eval(inner, inputs, frame, restK, trail1, mkont, restK :: trail2, h) case Br(label) => - trail(label)(stack, mkont) // s => ().asInstanceOf[Ans]) //mkont) + trail2(label)(stack, trail1, mkont) // s => ().asInstanceOf[Ans]) //mkont) case BrIf(label) => val I32V(cond) :: newStack = stack - if (cond != 0) trail(label)(newStack, mkont) - else eval(rest, newStack, frame, kont, mkont, trail, h) + if (cond != 0) trail2(label)(newStack, trail1, mkont) + else eval(rest, newStack, frame, kont, trail1, mkont, trail2, h) case BrTable(labels, default) => val I32V(cond) :: newStack = stack val goto = if (cond < labels.length) labels(cond) else default - trail(goto)(newStack, mkont) - case Return => trail.last(stack, mkont) - case Call(f) => evalCall(f, rest, stack, frame, kont, mkont, trail, h, false) - case ReturnCall(f) => evalCall(f, rest, stack, frame, kont, mkont, trail, h, true) + trail2(goto)(newStack, trail1, mkont) + case Return => trail2.last(stack, trail1, mkont) + case Call(f) => evalCall(f, rest, stack, frame, kont, trail1, mkont, trail2, h, false) + case ReturnCall(f) => evalCall(f, rest, stack, frame, kont, trail1, mkont, trail2, h, true) case RefFunc(f) => // TODO: RefFuncV stores an applicable function, instead of a syntactic structure - eval(rest, RefFuncV(f) :: stack, frame, kont, mkont, trail, h) + eval(rest, RefFuncV(f) :: stack, frame, kont, trail1, mkont, trail2, h) case CallRef(ty) => val RefFuncV(f) :: newStack = stack - evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false) + evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false) // resumable try-catch exception handling: case TryCatch(es1, es2) => - val join: MCont[Ans] = (newStack) => eval(rest, stack, frame, kont, mkont, trail, h) + // push trail1 to join point + val join: MCont[Ans] = (newStack) => eval(rest, stack, frame, kont, trail1, mkont, trail2, h) // the `restK` for catch block (es2) is the join point // the restK simply applies the meta-continuation, this is the same the [nil] case // where we fall back to join point - val idK: Cont[Ans] = (s, m) => m(s) - val newHandler: Handler[Ans] = (newStack) => eval(es2, newStack, frame, idK, join, trail, h) - eval(es1, List(), frame, idK, join, trail, newHandler) + // val idK: Cont[Ans] = (s, m) => m(s) + val newHandler: Handler[Ans] = (newStack) => eval(es2, newStack, frame, init: Cont[Ans], List(), join, trail2, h) + eval(es1, List(), frame, init: Cont[Ans], List(), join, trail2, newHandler) case Resume0() => val (resume: TCContV[Ans]) :: newStack = stack - val k: Cont[Ans] = (s, m) => eval(rest, newStack /*!*/, frame, kont, m, trail, h) - resume.k(List(), k, mkont) + val k: Cont[Ans] = (s, trail1, m) => eval(rest, newStack /*!*/, frame, kont, trail1, m, trail2, h) + resume.k(List(), k, trail1, mkont) case Throw() => val err :: newStack = stack - // kont composed with k + // kont composed with k1 and trail1 // note that kr doesn't use the stack at all // it only takes the err value - def kr(s: Stack, k1: Cont[Ans], m: MCont[Ans]): Ans = { - val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1)) - eval(rest, newStack /*!*/, frame, kontK, m /*vs mkont?*/, trail, h) + def kr(s: Stack, k1: Cont[Ans], newTrail1: List[Cont[Ans]], m: MCont[Ans]): Ans = { + eval(rest, newStack /*!*/, frame, kont, trail1 ++ List(k1) ++ newTrail1, m /*vs mkont?*/, trail2, h) } h(List(err, TCContV(kr))) @@ -232,15 +250,14 @@ case class EvaluatorFX(module: ModuleInstance) { val RefFuncV(f) :: newStack = stack // should be similar to the contiuantion thrown by `throw` - // val k: Cont[Ans] = (s, mk) => evalCall(f, List(), s, frame, idK, mk, trail, h, false) + // val k: Cont[Ans] = (s, mk) => evalCall(f, List(), s, frame, idK, mk, trail2, h, false) // TODO: where should kont go? // TODO: this implementation is not right - def kr(s: Stack, k1: Cont[Ans], mk: MCont[Ans]): Ans = { - val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1)) - evalCall(f, List(), s, frame, kontK, mk, trail, h, false) + def kr(s: Stack, k1: Cont[Ans], trail1: List[Cont[Ans]], mk: MCont[Ans]): Ans = { + evalCall(f, List(), s, frame, k1, trail1, mk, trail2, h, false) } - eval(rest, TCContV(kr) :: newStack, frame, kont, mkont, trail, h) + eval(rest, TCContV(kr) :: newStack, frame, kont, trail1, mkont, trail2, h) // TODO: implement the following case Suspend(tag_id) => { // println(s"${RED}Unimplimented Suspending tag $tag_id") @@ -258,8 +275,8 @@ case class EvaluatorFX(module: ModuleInstance) { val (inputs, restStack) = newStack.splitAt(inps.size) if (handler.length == 0) { - val k: Cont[Ans] = (s, m) => eval(rest, newStack, frame, kont, m, trail, h) - f.k(inputs, k, mkont) + val k: Cont[Ans] = (s, trail1, m) => eval(rest, newStack, frame, kont, trail1, m, trail2, h) + f.k(inputs, k, List(), mkont) } else { // TODO: attempt single tag first throw new Exception("tags not supported") @@ -283,11 +300,11 @@ case class EvaluatorFX(module: ModuleInstance) { case CallRef(ty) => val RefFuncV(f) :: newStack = stack - evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false) + evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false) case CallRef(ty) => val RefFuncV(f) :: newStack = stack - evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false) + evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false) case _ => println(inst) @@ -323,11 +340,11 @@ case class EvaluatorFX(module: ModuleInstance) { } if (instrs.isEmpty) println("Warning: nothing is executed") val handler0: Handler[Ans] = stack => throw new Exception(s"Uncaught exception: $stack") - eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, mhalt, List(halt), handler0) + eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, List(), mhalt, List(halt), handler0) } def evalTop(m: ModuleInstance): Unit = { - val halt: Cont[Unit] = (stack, m) => m(stack) + val halt: Cont[Unit] = init evalTop(halt, stack => ()) } } diff --git a/src/main/scala/wasm/MiniWasmScript.scala b/src/main/scala/wasm/MiniWasmScript.scala index 01d3a469..45bd3aec 100644 --- a/src/main/scala/wasm/MiniWasmScript.scala +++ b/src/main/scala/wasm/MiniWasmScript.scala @@ -31,11 +31,11 @@ sealed class ScriptRunner { type Cont = evaluator.Cont[evaluator.Stack] type MCont = evaluator.MCont[evaluator.Stack] type Handler = evaluator.Handler[evaluator.Stack] - val k: Cont = (retStack, m) => m(retStack) + val k: Cont = evaluator.init; val mk: MCont = (retStack) => retStack val h0: Handler = stack => throw new Exception(s"Uncaught exception: $stack") // TODO: change this back to Evaluator if we are just testing original stuff - val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, mk, List(k), h0) + val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, List(), mk, List(k), h0) assert(actual == expect) } } diff --git a/src/test/scala/genwasym/TestFx.scala b/src/test/scala/genwasym/TestFx.scala index 5800475c..73a9d513 100644 --- a/src/test/scala/genwasym/TestFx.scala +++ b/src/test/scala/genwasym/TestFx.scala @@ -28,7 +28,7 @@ class TestFx extends FunSuite { val evaluator = EvaluatorFX(ModuleInstance(module)) type Cont = evaluator.Cont[Unit] type MCont = evaluator.MCont[Unit] - val haltK: Cont = (stack, m) => m(stack) + val haltK: Cont = evaluator.init; val haltMK: MCont = (stack) => { //println(s"halt cont: $stack") expected match {