Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Implement support for Closures.
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrd committed Mar 18, 2024
1 parent 4a318ee commit f872b40
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 5 deletions.
1 change: 1 addition & 0 deletions cli/src/main/scala/TestSuites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ object TestSuites {
TestSuite("testsuite.core.InterfaceCall"),
TestSuite("testsuite.core.AsInstanceOfTest"),
TestSuite("testsuite.core.ClassOfTest"),
TestSuite("testsuite.core.ClosureTest"),
TestSuite("testsuite.core.GetClassTest"),
TestSuite("testsuite.core.JSInteropTest"),
TestSuite("testsuite.core.HijackedClassesDispatchTest"),
Expand Down
3 changes: 3 additions & 0 deletions loader.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ const scalaJSHelpers = {

// Closure
closure: (f, data) => f.bind(void 0, data),
closureThis: (f, data) => function(...args) { return f(this, data, ...args); },
closureRest: (f, data, n) => ((...args) => f(data, ...args.slice(0, n), args.slice(n))),
closureThisRest: (f, data, n) => function(...args) { return f(this, data, ...args.slice(0, n), args.slice(n)); },

// Strings
emptyString: () => "",
Expand Down
48 changes: 48 additions & 0 deletions test-suite/src/main/scala/testsuite/core/ClosureTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package testsuite.core

import scala.scalajs.js

import testsuite.Assert.ok

object ClosureTest {
def main(): Unit = {
testClosure()
testClosureThis()

// TODO We cannot test closures with ...rest params yet because they need Seq's

testGiveToActualJSCode()
}

def testClosure(): Unit = {
def makeClosure(x: Int, y: String): js.Function2[Boolean, String, String] =
(z, w) => s"$x $y $z $w"

val f = makeClosure(5, "foo")
ok(f(true, "bar") == "5 foo true bar")
}

def testClosureThis(): Unit = {
def makeClosure(x: Int, y: String): js.ThisFunction2[Any, Boolean, String, String] =
(ths, z, w) => s"$ths $x $y $z $w"

val f = makeClosure(5, "foo")
ok(f(new Obj, true, "bar") == "Obj 5 foo true bar")
}

def testGiveToActualJSCode(): Unit = {
val arr = js.Array(2, 3, 5, 7, 11)
val f: js.Function1[Int, Int] = x => x * 2
val result = arr.asInstanceOf[js.Dynamic].map(f).asInstanceOf[js.Array[Int]]
ok(result.length == 5)
ok(result(0) == 4)
ok(result(1) == 6)
ok(result(2) == 10)
ok(result(3) == 14)
ok(result(4) == 22)
}

class Obj {
override def toString(): String = "Obj"
}
}
74 changes: 73 additions & 1 deletion wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ private class WasmExpressionBuilder private (
case t: IRTrees.JSGlobalRef => genJSGlobalRef(t)
case t: IRTrees.JSTypeOfGlobalRef => genJSTypeOfGlobalRef(t)
case t: IRTrees.JSLinkingInfo => genJSLinkingInfo(t)
case t: IRTrees.Closure => genClosure(t)

case _ =>
println(tree)
Expand All @@ -156,7 +157,6 @@ private class WasmExpressionBuilder private (
// case IRTrees.NewArray(pos) =>
// case IRTrees.Match(tpe) =>
// case IRTrees.Throw(pos) =>
// case IRTrees.Closure(pos) =>
// case IRTrees.RecordSelect(tpe) =>
// case IRTrees.TryFinally(pos) =>
// case IRTrees.JSImportMeta(pos) =>
Expand Down Expand Up @@ -1578,4 +1578,76 @@ private class WasmExpressionBuilder private (
instrs += CALL(FuncIdx(WasmFunctionName.jsLinkingInfo))
IRTypes.AnyType
}

private def genClosure(tree: IRTrees.Closure): IRTypes.Type = {
implicit val ctx = this.ctx

val hasThis = !tree.arrow
val dataStructType = ctx.getClosureDataStructType(tree.captureParams.map(_.ptpe))

// Define the function where captures are reified as a `__captureData` argument.
val closureFuncName = fctx.genInnerFuncName()
locally {
val receiverParam =
if (!hasThis) None
else Some(WasmLocal(WasmLocalName.receiver, Types.WasmAnyRef, isParameter = true))

val captureDataParam = WasmLocal(
WasmLocalName("__captureData"),
Types.WasmRefType(Types.WasmHeapType.Type(dataStructType.name)),
isParameter = true
)

val paramLocals = (tree.params ::: tree.restParam.toList).map { param =>
val typ = TypeTransformer.transformType(param.ptpe)
WasmLocal(WasmLocalName.fromIR(param.name.name), typ, isParameter = true)
}
val resultTyps = TypeTransformer.transformResultType(IRTypes.AnyType)

implicit val fctx = WasmFunctionContext(
enclosingClassName = None,
closureFuncName,
receiverParam,
captureDataParam :: paramLocals,
resultTyps
)

val captureDataLocalIdx = fctx.paramIndices.head

// Extract the fields of captureData in individual locals
for ((captureParam, index) <- tree.captureParams.zipWithIndex) {
val local = fctx.addLocal(
captureParam.name.name,
TypeTransformer.transformType(captureParam.ptpe)
)
fctx.instrs += LOCAL_GET(captureDataLocalIdx)
fctx.instrs += STRUCT_GET(TypeIdx(dataStructType.name), StructFieldIdx(index))
fctx.instrs += LOCAL_SET(local)
}

// Now transform the body
WasmExpressionBuilder.generateIRBody(tree.body, IRTypes.AnyType)

fctx.buildAndAddToContext()
}

// Put a reference to the function on the stack
instrs += ctx.refFuncWithDeclaration(closureFuncName)

// Evaluate the capture values and instantiate the capture data struct
for ((param, value) <- tree.captureParams.zip(tree.captureValues))
genTree(value, param.ptpe)
instrs += STRUCT_NEW(TypeIdx(dataStructType.name))

// Call the appropriate helper
val helper = (hasThis, tree.restParam.isDefined) match {
case (false, false) => WasmFunctionName.closure
case (true, false) => WasmFunctionName.closureThis
case (false, true) => WasmFunctionName.closureRest
case (true, true) => WasmFunctionName.closureThisRest
}
instrs += CALL(FuncIdx(helper))

IRTypes.AnyType
}
}
8 changes: 8 additions & 0 deletions wasm/src/main/scala/wasm4s/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ object Names {
def typeTest(primRef: IRTypes.PrimRef): WasmFunctionName = helper("t" + primRef.charCode)

val closure = helper("closure")
val closureThis = helper("closureThis")
val closureRest = helper("closureRest")
val closureThisRest = helper("closureThisRest")

val emptyString = helper("emptyString")
val stringLength = helper("stringLength")
Expand Down Expand Up @@ -246,6 +249,9 @@ object Names {
def apply(name: WasmTypeName.WasmITableTypeName) = new WasmFieldName(name.name)
def apply(name: IRNames.MethodName) = new WasmFieldName(name.nameString)
def apply(name: WasmFunctionName) = new WasmFieldName(name.name)

def captureParam(i: Int): WasmFieldName = new WasmFieldName("c" + i)

val vtable = new WasmFieldName("vtable")
val itable = new WasmFieldName("itable")
val itables = new WasmFieldName("itables")
Expand Down Expand Up @@ -329,6 +335,8 @@ object Names {
object WasmStructTypeName {
def apply(name: IRNames.ClassName) = new WasmStructTypeName(name.nameString)

def captureData(index: Int): WasmStructTypeName = new WasmStructTypeName("captureData__" + index)

val typeData = new WasmStructTypeName("typeData")
}

Expand Down
31 changes: 31 additions & 0 deletions wasm/src/main/scala/wasm4s/WasmContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,12 @@ trait ReadOnlyWasmContext {
trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmContext =>
protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int]
protected val constantStringGlobals = LinkedHashMap.empty[String, WasmGlobalName]
protected val closureDataTypes = LinkedHashMap.empty[List[IRTypes.Type], WasmStructType]

private var nextConstantStringIndex: Int = 1
private var nextClosureDataTypeIndex: Int = 1

def addFunction(fun: WasmFunction): Unit
protected def addGlobal(g: WasmGlobal): Unit
protected def addFuncDeclaration(name: WasmFunctionName): Unit

Expand Down Expand Up @@ -159,6 +162,19 @@ trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmCont
WasmInstr.GLOBAL_GET(WasmImmediate.GlobalIdx(globalName))
}

def getClosureDataStructType(captureParamTypes: List[IRTypes.Type]): WasmStructType = {
closureDataTypes.getOrElse(captureParamTypes, {
val fields: List[WasmStructField] =
for ((tpe, i) <- captureParamTypes.zipWithIndex) yield
WasmStructField(WasmFieldName.captureParam(i), TypeTransformer.transformType(tpe)(this), isMutable = false)
val structTypeName = WasmStructTypeName.captureData(nextClosureDataTypeIndex)
nextClosureDataTypeIndex += 1
val structType = WasmStructType(structTypeName, fields, superType = None)
addGCType(structType)
structType
})
}

def refFuncWithDeclaration(name: WasmFunctionName): WasmInstr.REF_FUNC = {
addFuncDeclaration(name)
WasmInstr.REF_FUNC(WasmImmediate.FuncIdx(name))
Expand Down Expand Up @@ -227,6 +243,21 @@ class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
List(WasmRefType.any)
)
addHelperImport(
WasmFunctionName.closureThis,
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
List(WasmRefType.any)
)
addHelperImport(
WasmFunctionName.closureRest,
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
List(WasmRefType.any)
)
addHelperImport(
WasmFunctionName.closureThisRest,
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
List(WasmRefType.any)
)

addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any))
addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32))
Expand Down
18 changes: 14 additions & 4 deletions wasm/src/main/scala/wasm4s/WasmFunctionContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import wasm.wasm4s.WasmInstr._
import wasm.ir2wasm.TypeTransformer

class WasmFunctionContext private (
ctx: WasmContext,
ctx: FunctionTypeWriterWasmContext,
val enclosingClassName: Option[IRNames.ClassName],
val functionName: WasmFunctionName,
_receiver: Option[WasmLocal],
Expand All @@ -23,6 +23,7 @@ class WasmFunctionContext private (
) {
private var cnt = 0
private var labelIdx = 0
private var innerFuncIdx = 0

val locals = new WasmSymbolTable[WasmLocalName, WasmLocal]()

Expand Down Expand Up @@ -80,6 +81,15 @@ class WasmFunctionContext private (
def addSyntheticLocal(typ: WasmType): LocalIdx =
addLocal(genSyntheticLocalName(), typ)

def genInnerFuncName(): WasmFunctionName = {
val innerName = WasmFunctionName(
functionName.namespace,
functionName.simpleName + "__c" + innerFuncIdx
)
innerFuncIdx += 1
innerName
}

// Helpers to build structured control flow

def ifThenElse(blockType: BlockType)(thenp: => Unit)(elsep: => Unit): Unit = {
Expand Down Expand Up @@ -177,7 +187,7 @@ object WasmFunctionContext {
receiver: Option[WasmLocal],
params: List[WasmLocal],
resultTypes: List[WasmType]
)(implicit ctx: WasmContext): WasmFunctionContext = {
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
new WasmFunctionContext(ctx, enclosingClassName, name, receiver, params, resultTypes)
}

Expand All @@ -187,7 +197,7 @@ object WasmFunctionContext {
receiverTyp: Option[WasmType],
paramDefs: List[IRTrees.ParamDef],
resultType: IRTypes.Type
)(implicit ctx: WasmContext): WasmFunctionContext = {
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
val receiver = receiverTyp.map { typ =>
WasmLocal(WasmLocalName.receiver, typ, isParameter = true)
}
Expand All @@ -205,7 +215,7 @@ object WasmFunctionContext {
name: WasmFunctionName,
params: List[(String, WasmType)],
resultTypes: List[WasmType]
)(implicit ctx: WasmContext): WasmFunctionContext = {
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
val paramLocals = params.map { param =>
WasmLocal(WasmLocalName.fromStr(param._1), param._2, isParameter = true)
}
Expand Down

0 comments on commit f872b40

Please sign in to comment.