diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index dfca3849..0b4611b8 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -9,6 +9,7 @@ object TestSuites { TestSuite("testsuite.core.InterfaceCall"), TestSuite("testsuite.core.AsInstanceOfTest"), TestSuite("testsuite.core.ClassOfTest"), + TestSuite("testsuite.core.ClosureTest"), TestSuite("testsuite.core.GetClassTest"), TestSuite("testsuite.core.JSInteropTest"), TestSuite("testsuite.core.HijackedClassesDispatchTest"), diff --git a/loader.mjs b/loader.mjs index 26a85592..1a41b08d 100644 --- a/loader.mjs +++ b/loader.mjs @@ -63,6 +63,9 @@ const scalaJSHelpers = { // Closure closure: (f, data) => f.bind(void 0, data), + closureThis: (f, data) => function(...args) { return f(this, data, ...args); }, + closureRest: (f, data, n) => ((...args) => f(data, ...args.slice(0, n), args.slice(n))), + closureThisRest: (f, data, n) => function(...args) { return f(this, data, ...args.slice(0, n), args.slice(n)); }, // Strings emptyString: () => "", diff --git a/test-suite/src/main/scala/testsuite/core/ClosureTest.scala b/test-suite/src/main/scala/testsuite/core/ClosureTest.scala new file mode 100644 index 00000000..15691fa9 --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/ClosureTest.scala @@ -0,0 +1,48 @@ +package testsuite.core + +import scala.scalajs.js + +import testsuite.Assert.ok + +object ClosureTest { + def main(): Unit = { + testClosure() + testClosureThis() + + // TODO We cannot test closures with ...rest params yet because they need Seq's + + testGiveToActualJSCode() + } + + def testClosure(): Unit = { + def makeClosure(x: Int, y: String): js.Function2[Boolean, String, String] = + (z, w) => s"$x $y $z $w" + + val f = makeClosure(5, "foo") + ok(f(true, "bar") == "5 foo true bar") + } + + def testClosureThis(): Unit = { + def makeClosure(x: Int, y: String): js.ThisFunction2[Any, Boolean, String, String] = + (ths, z, w) => s"$ths $x $y $z $w" + + val f = makeClosure(5, "foo") + ok(f(new Obj, true, "bar") == "Obj 5 foo true bar") + } + + def testGiveToActualJSCode(): Unit = { + val arr = js.Array(2, 3, 5, 7, 11) + val f: js.Function1[Int, Int] = x => x * 2 + val result = arr.asInstanceOf[js.Dynamic].map(f).asInstanceOf[js.Array[Int]] + ok(result.length == 5) + ok(result(0) == 4) + ok(result(1) == 6) + ok(result(2) == 10) + ok(result(3) == 14) + ok(result(4) == 22) + } + + class Obj { + override def toString(): String = "Obj" + } +} diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 4e827aa1..d0e2c814 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -141,6 +141,7 @@ private class WasmExpressionBuilder private ( case t: IRTrees.JSGlobalRef => genJSGlobalRef(t) case t: IRTrees.JSTypeOfGlobalRef => genJSTypeOfGlobalRef(t) case t: IRTrees.JSLinkingInfo => genJSLinkingInfo(t) + case t: IRTrees.Closure => genClosure(t) case _ => println(tree) @@ -156,7 +157,6 @@ private class WasmExpressionBuilder private ( // case IRTrees.NewArray(pos) => // case IRTrees.Match(tpe) => // case IRTrees.Throw(pos) => - // case IRTrees.Closure(pos) => // case IRTrees.RecordSelect(tpe) => // case IRTrees.TryFinally(pos) => // case IRTrees.JSImportMeta(pos) => @@ -1578,4 +1578,76 @@ private class WasmExpressionBuilder private ( instrs += CALL(FuncIdx(WasmFunctionName.jsLinkingInfo)) IRTypes.AnyType } + + private def genClosure(tree: IRTrees.Closure): IRTypes.Type = { + implicit val ctx = this.ctx + + val hasThis = !tree.arrow + val dataStructType = ctx.getClosureDataStructType(tree.captureParams.map(_.ptpe)) + + // Define the function where captures are reified as a `__captureData` argument. + val closureFuncName = fctx.genInnerFuncName() + locally { + val receiverParam = + if (!hasThis) None + else Some(WasmLocal(WasmLocalName.receiver, Types.WasmAnyRef, isParameter = true)) + + val captureDataParam = WasmLocal( + WasmLocalName("__captureData"), + Types.WasmRefType(Types.WasmHeapType.Type(dataStructType.name)), + isParameter = true + ) + + val paramLocals = (tree.params ::: tree.restParam.toList).map { param => + val typ = TypeTransformer.transformType(param.ptpe) + WasmLocal(WasmLocalName.fromIR(param.name.name), typ, isParameter = true) + } + val resultTyps = TypeTransformer.transformResultType(IRTypes.AnyType) + + implicit val fctx = WasmFunctionContext( + enclosingClassName = None, + closureFuncName, + receiverParam, + captureDataParam :: paramLocals, + resultTyps + ) + + val captureDataLocalIdx = fctx.paramIndices.head + + // Extract the fields of captureData in individual locals + for ((captureParam, index) <- tree.captureParams.zipWithIndex) { + val local = fctx.addLocal( + captureParam.name.name, + TypeTransformer.transformType(captureParam.ptpe) + ) + fctx.instrs += LOCAL_GET(captureDataLocalIdx) + fctx.instrs += STRUCT_GET(TypeIdx(dataStructType.name), StructFieldIdx(index)) + fctx.instrs += LOCAL_SET(local) + } + + // Now transform the body + WasmExpressionBuilder.generateIRBody(tree.body, IRTypes.AnyType) + + fctx.buildAndAddToContext() + } + + // Put a reference to the function on the stack + instrs += ctx.refFuncWithDeclaration(closureFuncName) + + // Evaluate the capture values and instantiate the capture data struct + for ((param, value) <- tree.captureParams.zip(tree.captureValues)) + genTree(value, param.ptpe) + instrs += STRUCT_NEW(TypeIdx(dataStructType.name)) + + // Call the appropriate helper + val helper = (hasThis, tree.restParam.isDefined) match { + case (false, false) => WasmFunctionName.closure + case (true, false) => WasmFunctionName.closureThis + case (false, true) => WasmFunctionName.closureRest + case (true, true) => WasmFunctionName.closureThisRest + } + instrs += CALL(FuncIdx(helper)) + + IRTypes.AnyType + } } diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala index 42ce0585..7628230f 100644 --- a/wasm/src/main/scala/wasm4s/Names.scala +++ b/wasm/src/main/scala/wasm4s/Names.scala @@ -156,6 +156,9 @@ object Names { def typeTest(primRef: IRTypes.PrimRef): WasmFunctionName = helper("t" + primRef.charCode) val closure = helper("closure") + val closureThis = helper("closureThis") + val closureRest = helper("closureRest") + val closureThisRest = helper("closureThisRest") val emptyString = helper("emptyString") val stringLength = helper("stringLength") @@ -246,6 +249,9 @@ object Names { def apply(name: WasmTypeName.WasmITableTypeName) = new WasmFieldName(name.name) def apply(name: IRNames.MethodName) = new WasmFieldName(name.nameString) def apply(name: WasmFunctionName) = new WasmFieldName(name.name) + + def captureParam(i: Int): WasmFieldName = new WasmFieldName("c" + i) + val vtable = new WasmFieldName("vtable") val itable = new WasmFieldName("itable") val itables = new WasmFieldName("itables") @@ -329,6 +335,8 @@ object Names { object WasmStructTypeName { def apply(name: IRNames.ClassName) = new WasmStructTypeName(name.nameString) + def captureData(index: Int): WasmStructTypeName = new WasmStructTypeName("captureData__" + index) + val typeData = new WasmStructTypeName("typeData") } diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala index bf83a79f..c542f023 100644 --- a/wasm/src/main/scala/wasm4s/WasmContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmContext.scala @@ -111,9 +111,12 @@ trait ReadOnlyWasmContext { trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmContext => protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int] protected val constantStringGlobals = LinkedHashMap.empty[String, WasmGlobalName] + protected val closureDataTypes = LinkedHashMap.empty[List[IRTypes.Type], WasmStructType] private var nextConstantStringIndex: Int = 1 + private var nextClosureDataTypeIndex: Int = 1 + def addFunction(fun: WasmFunction): Unit protected def addGlobal(g: WasmGlobal): Unit protected def addFuncDeclaration(name: WasmFunctionName): Unit @@ -159,6 +162,19 @@ trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmCont WasmInstr.GLOBAL_GET(WasmImmediate.GlobalIdx(globalName)) } + def getClosureDataStructType(captureParamTypes: List[IRTypes.Type]): WasmStructType = { + closureDataTypes.getOrElse(captureParamTypes, { + val fields: List[WasmStructField] = + for ((tpe, i) <- captureParamTypes.zipWithIndex) yield + WasmStructField(WasmFieldName.captureParam(i), TypeTransformer.transformType(tpe)(this), isMutable = false) + val structTypeName = WasmStructTypeName.captureData(nextClosureDataTypeIndex) + nextClosureDataTypeIndex += 1 + val structType = WasmStructType(structTypeName, fields, superType = None) + addGCType(structType) + structType + }) + } + def refFuncWithDeclaration(name: WasmFunctionName): WasmInstr.REF_FUNC = { addFuncDeclaration(name) WasmInstr.REF_FUNC(WasmImmediate.FuncIdx(name)) @@ -227,6 +243,21 @@ class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef), List(WasmRefType.any) ) + addHelperImport( + WasmFunctionName.closureThis, + List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef), + List(WasmRefType.any) + ) + addHelperImport( + WasmFunctionName.closureRest, + List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef), + List(WasmRefType.any) + ) + addHelperImport( + WasmFunctionName.closureThisRest, + List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef), + List(WasmRefType.any) + ) addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any)) addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32)) diff --git a/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala b/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala index c072e1e0..81f480e2 100644 --- a/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala @@ -14,7 +14,7 @@ import wasm.wasm4s.WasmInstr._ import wasm.ir2wasm.TypeTransformer class WasmFunctionContext private ( - ctx: WasmContext, + ctx: FunctionTypeWriterWasmContext, val enclosingClassName: Option[IRNames.ClassName], val functionName: WasmFunctionName, _receiver: Option[WasmLocal], @@ -23,6 +23,7 @@ class WasmFunctionContext private ( ) { private var cnt = 0 private var labelIdx = 0 + private var innerFuncIdx = 0 val locals = new WasmSymbolTable[WasmLocalName, WasmLocal]() @@ -80,6 +81,15 @@ class WasmFunctionContext private ( def addSyntheticLocal(typ: WasmType): LocalIdx = addLocal(genSyntheticLocalName(), typ) + def genInnerFuncName(): WasmFunctionName = { + val innerName = WasmFunctionName( + functionName.namespace, + functionName.simpleName + "__c" + innerFuncIdx + ) + innerFuncIdx += 1 + innerName + } + // Helpers to build structured control flow def ifThenElse(blockType: BlockType)(thenp: => Unit)(elsep: => Unit): Unit = { @@ -177,7 +187,7 @@ object WasmFunctionContext { receiver: Option[WasmLocal], params: List[WasmLocal], resultTypes: List[WasmType] - )(implicit ctx: WasmContext): WasmFunctionContext = { + )(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = { new WasmFunctionContext(ctx, enclosingClassName, name, receiver, params, resultTypes) } @@ -187,7 +197,7 @@ object WasmFunctionContext { receiverTyp: Option[WasmType], paramDefs: List[IRTrees.ParamDef], resultType: IRTypes.Type - )(implicit ctx: WasmContext): WasmFunctionContext = { + )(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = { val receiver = receiverTyp.map { typ => WasmLocal(WasmLocalName.receiver, typ, isParameter = true) } @@ -205,7 +215,7 @@ object WasmFunctionContext { name: WasmFunctionName, params: List[(String, WasmType)], resultTypes: List[WasmType] - )(implicit ctx: WasmContext): WasmFunctionContext = { + )(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = { val paramLocals = params.map { param => WasmLocal(WasmLocalName.fromStr(param._1), param._2, isParameter = true) }