Skip to content

Commit

Permalink
combining trail to represent delimited continuation
Browse files Browse the repository at this point in the history
  • Loading branch information
butterunderflow committed Nov 26, 2024
1 parent 10214c0 commit ba44c23
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 76 deletions.
163 changes: 90 additions & 73 deletions src/main/scala/wasm/MiniWasmFX.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,28 @@ case class EvaluatorFX(module: ModuleInstance) {
import Primtives._
implicit val m: ModuleInstance = module

trait ContTrait[A] {
def apply(stack: Stack, trail1: List[ContTrait[A]], mcont: MCont[A]): A
}

type Stack = List[Value]
type Cont[A] = (Stack, MCont[A]) => A
type Cont[A] = ContTrait[A]
type MCont[A] = Stack => A
type Handler[A] = Stack => A

def init[Ans](s: Stack, trail1: List[Cont[Ans]], mkont: MCont[Ans]): Ans = {
trail1 match {
case k1 :: trail1 => k1(s, trail1, mkont)
case Nil => mkont(s)
}
}

def +[Ans](k1: Cont[Ans], k2: Cont[Ans]): Cont[Ans] = {
(s, trail1, mkont) => k1(s, k2 :: trail1, mkont)
}

// Only used for resumable try-catch (need refactoring):
case class TCContV[A](k: (Stack, Cont[A], MCont[A]) => A) extends Value {
case class TCContV[A](k: (Stack, Cont[A], List[Cont[A]], MCont[A]) => A) extends Value {
def tipe(implicit m: ModuleInstance): ValueType = ???
}

Expand All @@ -27,8 +42,9 @@ case class EvaluatorFX(module: ModuleInstance) {
stack: List[Value],
frame: Frame,
kont: Cont[Ans],
trail1: List[Cont[Ans]],
mkont: MCont[Ans],
trail: List[Cont[Ans]],
trail2: List[Cont[Ans]],
h: Handler[Ans],
isTail: Boolean): Ans = {
module.funcs(funcIndex) match {
Expand All @@ -39,24 +55,24 @@ case class EvaluatorFX(module: ModuleInstance) {
val newFrame = Frame(ArrayBuffer(frameLocals: _*))
if (isTail)
// when tail call, share the continuation for returning with the callee
eval(body, List(), newFrame, kont, mkont, List(kont), h)
eval(body, List(), newFrame, kont, trail1, mkont, List(kont), h)
else {
val restK: Cont[Ans] = (retStack, mkont) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, mkont, trail, h)
// We make a new trail by `restK`, since function creates a new block to escape
val restK: Cont[Ans] = (retStack, trail1, mkont) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail1, mkont, trail2, h)
// We make a new trail2 by `restK`, since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, restK, mkont, List(restK), h)
eval(body, List(), newFrame, restK, trail1, mkont, List(restK), h)
}
case Import("console", "log", _) =>
// println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case Import("spectest", "print_i32", _) =>
// println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex")
case _ => throw new Exception(s"Definition at $funcIndex is not callable")
}
Expand All @@ -66,10 +82,11 @@ case class EvaluatorFX(module: ModuleInstance) {
stack: List[Value],
frame: Frame,
kont: Cont[Ans],
trail1: List[Cont[Ans]],
mkont: MCont[Ans],
trail: List[Cont[Ans]],
trail2: List[Cont[Ans]],
h: Handler[Ans]): Ans = {
if (insts.isEmpty) return kont(stack, mkont)
if (insts.isEmpty) return kont(stack, trail1, mkont)

val inst = insts.head
val rest = insts.tail
Expand All @@ -78,23 +95,23 @@ case class EvaluatorFX(module: ModuleInstance) {
// println(s"inst: ${inst} \t | ${frame.locals} | ${stack.reverse}" )

inst match {
case Drop => eval(rest, stack.tail, frame, kont, mkont, trail, h)
case Drop => eval(rest, stack.tail, frame, kont, trail1, mkont, trail2, h)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
eval(rest, value :: newStack, frame, kont, mkont, trail, h)
eval(rest, value :: newStack, frame, kont, trail1, mkont, trail2, h)
case LocalGet(i) =>
eval(rest, frame.locals(i) :: stack, frame, kont, mkont, trail, h)
eval(rest, frame.locals(i) :: stack, frame, kont, trail1, mkont, trail2, h)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, stack, frame, kont, mkont, trail, h)
eval(rest, stack, frame, kont, trail1, mkont, trail2, h)
case GlobalGet(i) =>
eval(rest, module.globals(i).value :: stack, frame, kont, mkont, trail, h)
eval(rest, module.globals(i).value :: stack, frame, kont, trail1, mkont, trail2, h)
case GlobalSet(i) =>
val value :: newStack = stack
module.globals(i).ty match {
Expand All @@ -103,127 +120,128 @@ case class EvaluatorFX(module: ModuleInstance) {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case MemorySize =>
eval(rest, I32V(module.memory.head.size) :: stack, frame, kont, mkont, trail, h)
eval(rest, I32V(module.memory.head.size) :: stack, frame, kont, trail1, mkont, trail2, h)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) =>
eval(rest, I32V(-1) :: newStack, frame, kont, mkont, trail, h)
eval(rest, I32V(-1) :: newStack, frame, kont, trail1, mkont, trail2, h)
case _ =>
eval(rest, I32V(oldSize) :: newStack, frame, kont, mkont, trail, h)
eval(rest, I32V(oldSize) :: newStack, frame, kont, trail1, mkont, trail2, h)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(module, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
module.memory.head.fill(offset, size, value.toByte)
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(module, 0, src, n) || memOutOfBound(module, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
module.memory.head.copy(dest, src, n)
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
}
case Const(n) => eval(rest, n :: stack, frame, kont, mkont, trail, h)
case Const(n) => eval(rest, n :: stack, frame, kont, trail1, mkont, trail2, h)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, mkont, trail, h)
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail1, mkont, trail2, h)
case Unary(op) =>
val v :: newStack = stack
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, mkont, trail, h)
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail1, mkont, trail2, h)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, mkont, trail, h)
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail1, mkont, trail2, h)
case Test(op) =>
val v :: newStack = stack
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, mkont, trail, h)
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail1, mkont, trail2, h)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
module.memory(0).storeInt(addr + offset, v)
eval(rest, newStack, frame, kont, mkont, trail, h)
eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = module.memory(0).loadInt(addr + offset)
eval(rest, I32V(value) :: newStack, frame, kont, mkont, trail, h)
eval(rest, I32V(value) :: newStack, frame, kont, trail1, mkont, trail2, h)
case Nop =>
eval(rest, stack, frame, kont, mkont, trail, h)
eval(rest, stack, frame, kont, trail1, mkont, trail2, h)
case Unreachable => throw Trap()
case Block(ty, inner) =>
val funcTy = getFuncType(ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack, mkont1) => {
val restK: Cont[Ans] = (retStack, trail1, mkont1) => {
// kont -> mkont -> mkont1
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont1, trail, h)
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont1, trail2, h)
}
eval(inner, inputs, frame, restK, mkont, restK :: trail, h)
eval(inner, inputs, frame, restK, trail1, mkont, restK :: trail2, h)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val funcTy = getFuncType(ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack, mkont) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont, trail, h)
def loop(retStack: List[Value], mkont: MCont[Ans]): Ans =
eval(inner, retStack.take(funcTy.inps.size), frame, restK, mkont, loop _ :: trail, h)
loop(inputs, mkont)
val restK: Cont[Ans] = (retStack, trail1, mkont) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont, trail2, h)
def loop(retStack: List[Value], trail1: List[Cont[Ans]], mkont: MCont[Ans]): Ans =

eval(inner, retStack.take(funcTy.inps.size), frame, restK, trail1, mkont, (loop _ : Cont[Ans]):: trail2, h)
loop(inputs, trail1, mkont)
case If(ty, thn, els) =>
val funcTy = getFuncType(ty)
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val (inputs, restStack) = newStack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack, mkont) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, mkont, trail, h)
eval(inner, inputs, frame, restK, mkont, restK :: trail, h)
val restK: Cont[Ans] = (retStack, trail1, mkont) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail1, mkont, trail2, h)
eval(inner, inputs, frame, restK, trail1, mkont, restK :: trail2, h)
case Br(label) =>
trail(label)(stack, mkont) // s => ().asInstanceOf[Ans]) //mkont)
trail2(label)(stack, trail1, mkont) // s => ().asInstanceOf[Ans]) //mkont)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
if (cond != 0) trail(label)(newStack, mkont)
else eval(rest, newStack, frame, kont, mkont, trail, h)
if (cond != 0) trail2(label)(newStack, trail1, mkont)
else eval(rest, newStack, frame, kont, trail1, mkont, trail2, h)
case BrTable(labels, default) =>
val I32V(cond) :: newStack = stack
val goto = if (cond < labels.length) labels(cond) else default
trail(goto)(newStack, mkont)
case Return => trail.last(stack, mkont)
case Call(f) => evalCall(f, rest, stack, frame, kont, mkont, trail, h, false)
case ReturnCall(f) => evalCall(f, rest, stack, frame, kont, mkont, trail, h, true)
trail2(goto)(newStack, trail1, mkont)
case Return => trail2.last(stack, trail1, mkont)
case Call(f) => evalCall(f, rest, stack, frame, kont, trail1, mkont, trail2, h, false)
case ReturnCall(f) => evalCall(f, rest, stack, frame, kont, trail1, mkont, trail2, h, true)
case RefFunc(f) =>
// TODO: RefFuncV stores an applicable function, instead of a syntactic structure
eval(rest, RefFuncV(f) :: stack, frame, kont, mkont, trail, h)
eval(rest, RefFuncV(f) :: stack, frame, kont, trail1, mkont, trail2, h)
case CallRef(ty) =>
val RefFuncV(f) :: newStack = stack
evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false)
evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false)

// resumable try-catch exception handling:
case TryCatch(es1, es2) =>
val join: MCont[Ans] = (newStack) => eval(rest, stack, frame, kont, mkont, trail, h)
// push trail1 to join point
val join: MCont[Ans] = (newStack) => eval(rest, stack, frame, kont, trail1, mkont, trail2, h)
// the `restK` for catch block (es2) is the join point
// the restK simply applies the meta-continuation, this is the same the [nil] case
// where we fall back to join point
val idK: Cont[Ans] = (s, m) => m(s)
val newHandler: Handler[Ans] = (newStack) => eval(es2, newStack, frame, idK, join, trail, h)
eval(es1, List(), frame, idK, join, trail, newHandler)
// val idK: Cont[Ans] = (s, m) => m(s)
val newHandler: Handler[Ans] = (newStack) => eval(es2, newStack, frame, init: Cont[Ans], List(), join, trail2, h)
eval(es1, List(), frame, init: Cont[Ans], List(), join, trail2, newHandler)
case Resume0() =>
val (resume: TCContV[Ans]) :: newStack = stack
val k: Cont[Ans] = (s, m) => eval(rest, newStack /*!*/, frame, kont, m, trail, h)
resume.k(List(), k, mkont)
val k: Cont[Ans] = (s, trail1, m) => eval(rest, newStack /*!*/, frame, kont, trail1, m, trail2, h)
resume.k(List(), k, trail1, mkont)
case Throw() =>
val err :: newStack = stack
// kont composed with k
// kont composed with k1 and trail1
// note that kr doesn't use the stack at all
// it only takes the err value
def kr(s: Stack, k1: Cont[Ans], m: MCont[Ans]): Ans = {
val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1))
eval(rest, newStack /*!*/, frame, kontK, m /*vs mkont?*/, trail, h)
def kr(s: Stack, k1: Cont[Ans], newTrail1: List[Cont[Ans]], m: MCont[Ans]): Ans = {
eval(rest, newStack /*!*/, frame, kont, trail1 ++ List(k1) ++ newTrail1, m /*vs mkont?*/, trail2, h)
}
h(List(err, TCContV(kr)))

Expand All @@ -232,15 +250,14 @@ case class EvaluatorFX(module: ModuleInstance) {
val RefFuncV(f) :: newStack = stack
// should be similar to the contiuantion thrown by `throw`

// val k: Cont[Ans] = (s, mk) => evalCall(f, List(), s, frame, idK, mk, trail, h, false)
// val k: Cont[Ans] = (s, mk) => evalCall(f, List(), s, frame, idK, mk, trail2, h, false)
// TODO: where should kont go?
// TODO: this implementation is not right
def kr(s: Stack, k1: Cont[Ans], mk: MCont[Ans]): Ans = {
val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1))
evalCall(f, List(), s, frame, kontK, mk, trail, h, false)
def kr(s: Stack, k1: Cont[Ans], trail1: List[Cont[Ans]], mk: MCont[Ans]): Ans = {
evalCall(f, List(), s, frame, k1, trail1, mk, trail2, h, false)
}

eval(rest, TCContV(kr) :: newStack, frame, kont, mkont, trail, h)
eval(rest, TCContV(kr) :: newStack, frame, kont, trail1, mkont, trail2, h)
// TODO: implement the following
case Suspend(tag_id) => {
// println(s"${RED}Unimplimented Suspending tag $tag_id")
Expand All @@ -258,8 +275,8 @@ case class EvaluatorFX(module: ModuleInstance) {
val (inputs, restStack) = newStack.splitAt(inps.size)

if (handler.length == 0) {
val k: Cont[Ans] = (s, m) => eval(rest, newStack, frame, kont, m, trail, h)
f.k(inputs, k, mkont)
val k: Cont[Ans] = (s, trail1, m) => eval(rest, newStack, frame, kont, trail1, m, trail2, h)
f.k(inputs, k, List(), mkont)
} else {
// TODO: attempt single tag first
throw new Exception("tags not supported")
Expand All @@ -283,11 +300,11 @@ case class EvaluatorFX(module: ModuleInstance) {

case CallRef(ty) =>
val RefFuncV(f) :: newStack = stack
evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false)
evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false)

case CallRef(ty) =>
val RefFuncV(f) :: newStack = stack
evalCall(f, rest, newStack, frame, kont, mkont, trail, h, false)
evalCall(f, rest, newStack, frame, kont, trail1, mkont, trail2, h, false)

case _ =>
println(inst)
Expand Down Expand Up @@ -323,11 +340,11 @@ case class EvaluatorFX(module: ModuleInstance) {
}
if (instrs.isEmpty) println("Warning: nothing is executed")
val handler0: Handler[Ans] = stack => throw new Exception(s"Uncaught exception: $stack")
eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, mhalt, List(halt), handler0)
eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, List(), mhalt, List(halt), handler0)
}

def evalTop(m: ModuleInstance): Unit = {
val halt: Cont[Unit] = (stack, m) => m(stack)
val halt: Cont[Unit] = init
evalTop(halt, stack => ())
}
}
4 changes: 2 additions & 2 deletions src/main/scala/wasm/MiniWasmScript.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ sealed class ScriptRunner {
type Cont = evaluator.Cont[evaluator.Stack]
type MCont = evaluator.MCont[evaluator.Stack]
type Handler = evaluator.Handler[evaluator.Stack]
val k: Cont = (retStack, m) => m(retStack)
val k: Cont = evaluator.init;
val mk: MCont = (retStack) => retStack
val h0: Handler = stack => throw new Exception(s"Uncaught exception: $stack")
// TODO: change this back to Evaluator if we are just testing original stuff
val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, mk, List(k), h0)
val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, List(), mk, List(k), h0)
assert(actual == expect)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/genwasym/TestFx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestFx extends FunSuite {
val evaluator = EvaluatorFX(ModuleInstance(module))
type Cont = evaluator.Cont[Unit]
type MCont = evaluator.MCont[Unit]
val haltK: Cont = (stack, m) => m(stack)
val haltK: Cont = evaluator.init;
val haltMK: MCont = (stack) => {
//println(s"halt cont: $stack")
expected match {
Expand Down

0 comments on commit ba44c23

Please sign in to comment.