Skip to content

Commit

Permalink
feat: Add transitions for fix-sized arrays (#192)
Browse files Browse the repository at this point in the history
Fixes #180
  • Loading branch information
markehammons authored May 21, 2023
1 parent a6a73fc commit 50f141d
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 49 deletions.
4 changes: 4 additions & 0 deletions core/src/fr/hammons/slinc/SetSizeArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class SetSizeArray[A, B <: Int] private[slinc] (private val array: Array[A])
value: A
)(using 0 <= C =:= true, C < B =:= true): Unit = array(constValue[C]) = value

def zip[C](oArr: SetSizeArray[C, B]): SetSizeArray[(A, C), B] =
new SetSizeArray[(A, C), B](array.zip(oArr.array))
def foreach(fn: A => Unit) = array.foreach(fn)

object SetSizeArray:
class SetSizeArrayBuilderUnsafe[B <: Int]:
def apply[A](array: Array[A]): SetSizeArray[A, B] = new SetSizeArray(array)
Expand Down
15 changes: 13 additions & 2 deletions core/src/fr/hammons/slinc/TypeDescriptor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,20 @@ case class SetSizeArrayDescriptor(
override val argumentTransition
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
Inner
] = ???
] = arg =>
val mem = summon[Allocator].allocate(this, 1)
summon[ReadWriteModule].write(
mem,
Bytes(0),
this,
arg
)
mem.asAddress

override val returnTransition
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = ???
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] =
obj =>
val mem = summon[TransitionModule].addressReturn(obj)
summon[ReadWriteModule].read(mem, Bytes(0), this)

type Inner = SetSizeArray[contained.Inner, ?]
9 changes: 9 additions & 0 deletions core/test/resources/native/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,12 @@ EXPORTED struct_issue_175 i175_test(struct_issue_175 a, char left) {
}
return a;
}

EXPORTED int* i180_test(int my_array[5]) {
int i = 0;
while(i < 5) {
my_array[i] = my_array[i] * 2;
i++;
}
return my_array;
}
12 changes: 12 additions & 0 deletions core/test/src/fr/hammons/slinc/BindingSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite:
left: CChar
): I175_Struct

def i180_test(
input: SetSizeArray[CInt, 5]
): SetSizeArray[CInt, 5]

test("int_identity") {
val test = FSet.instance[TestLib]

Expand Down Expand Up @@ -186,3 +190,11 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite:
union.set(double)
val res = test.i175_test(I175_Struct(union), 0)
assertEquals(res.union.get[CDouble], double / 2)

property("issue 180 - can send and receive set size arrays to C functions"):
val test = FSet.instance[TestLib]
forAll(Gen.listOfN(5, Arbitrary.arbitrary[CInt])): list =>
val arr = SetSizeArray.fromArrayUnsafe[5](list.toArray)
val retArr = test.i180_test(arr)

retArr.zip(arr.map(_ * 2)).foreach(assertEquals(_, _))
121 changes: 106 additions & 15 deletions core/test/src/fr/hammons/slinc/TransferSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag
import scala.util.chaining.*

trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
ClassTag[ThreadException]
Expand All @@ -28,7 +29,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using

case class F(u: CUnion[(CInt, CFloat)]) derives Struct

case class G(arr: SetSizeArray[CLong, 2]) derives Struct
case class G(long: CLong, arr: SetSizeArray[CLong, 2]) derives Struct

test("can read and write jvm ints") {
Scope.global {
Expand Down Expand Up @@ -162,20 +163,16 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
}
}

test("varargs can be sent and retrieved"):
test("varargs can receive primitive types"):
Scope.confined {
val vaListForVaList = VarArgsBuilder(4).build
val vaList = VarArgsBuilder(
4.toByte,
5.toShort,
6,
7.toLong,
2f,
3d,
Null[Int],
A(1, 2),
CLong(4: Byte),
vaListForVaList
Null[Int]
).build

assertEquals(vaList.get[Byte], 4.toByte, "byte assert")
Expand All @@ -185,24 +182,118 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
assertEquals(vaList.get[Float], 2f, "float assert")
assertEquals(vaList.get[Double], 3d, "double assert")
assertEquals(
vaList.get[Ptr[Int]].mem.asAddress,
Null[Int].mem.asAddress,
vaList.get[Ptr[Int]],
Null[Int],
"ptr assert"
)
}

test("varargs can receive complex types".ignore):
Scope.confined {
val vaListForVaList = VarArgsBuilder(4).build
val vaList = VarArgsBuilder(
A(1, 2),
CLong(4),
A(3, 4),
SetSizeArray(1, 2, 3, 4),
vaListForVaList,
CUnion[(CInt, CFloat)].tap(_.set(5)),
// Null[Int],
A(3, 4)
).build

assertEquals(vaList.get[A], A(1, 2), "struct assert")
assertEquals(vaList.get[CLong], CLong(4: Byte), "alias assert")
assertEquals(vaList.get[VarArgs].get[CInt], 4)
assertEquals(vaList.get[A], A(3, 4))
assertEquals(
vaList.get[SetSizeArray[CInt, 4]].toSeq,
Seq(1, 2, 3, 4),
"set size array assert"
)
assertEquals(
vaListForVaList.get[VarArgs].get[Int],
4
)
assertEquals(
vaList.get[CUnion[(CLongLong, CFloat)]].get[CLongLong],
5L,
"cunion assert"
)
// assertEquals(
// vaList.get[Ptr[Int]],
// Null[Int]
// )
assertEquals(
vaList.get[A],
A(3, 4),
"struct assert 2"
)
}

test("varargs can be skipped"):
test("varargs can skip primitive types"):
Scope.confined {
val vaList = VarArgsBuilder(
4.toByte,
2f
4: Byte,
5: Short,
6,
7L,
2f,
3d,
Null[Int]
).build

val vaList2 = vaList.copy()

vaList.skip[Byte]
assertEquals(vaList.get[Float], 2f)
assertEquals(vaList.get[Short], 5: Short)
vaList.skip[Int]
assertEquals(vaList.get[Long], 7L)
vaList.skip[Float]
assertEquals(vaList.get[Double], 3d)
vaList.skip[Ptr[Int]]

assertEquals(vaList2.get[Byte], 4: Byte)
vaList2.skip[Short]
assertEquals(vaList2.get[Int], 6)
vaList2.skip[Long]
assertEquals(vaList2.get[Float], 2f)
vaList2.skip[Double]
assertEquals(vaList2.get[Ptr[Int]], Null[Int])
}

test("varargs can skip complex types".ignore):
Scope.confined {
val vaListForVaList = VarArgsBuilder(4, 5, 6).build
val vaList = VarArgsBuilder(
A(1, 2),
CLong(4),
vaListForVaList,
CUnion[(CInt, CFloat)].tap(_.set(5)),
SetSizeArray(1, 2, 3, 4)
).build

val vaList2 = vaList.copy()

assertEquals(vaList.get[A], A(1, 2), "struct assert")
vaList.skip[CLong]
val vaList3 = vaList.get[VarArgs]
assertEquals(
List(vaList3.get[Int], vaList3.get[Int], vaList3.get[Int]),
List(4, 5, 6),
"varargs assert"
)
vaList.skip[CUnion[(CInt, CFloat)]]
assertEquals(
vaList.get[SetSizeArray[Int, 4]].toSeq,
Seq(1, 2, 3, 4),
"set size array assert"
)

vaList2.skip[A]
assertEquals(vaList2.get[CLong], CLong(4))
vaList2.skip[VarArgs]
assertEquals(vaList2.get[CUnion[(CInt, CFloat)]].get[Int], 5)
vaList2.skip[SetSizeArray[Int, 4]]
}

test("varargs can be copied and reread"):
Expand Down Expand Up @@ -373,7 +464,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using
}

test("can copy G to native memory and back"):
val g = G(SetSizeArray(CLong(1), CLong(2)))
val g = G(CLong(5), SetSizeArray(CLong(1), CLong(2)))

Scope.confined {
val ptr = Ptr.copy(g)
Expand Down
9 changes: 9 additions & 0 deletions j17/src/fr/hammons/slinc/Allocator17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ object Allocator17:
case ms: MemorySegment => ms
case _ => throw Error("base of mem was not J17 MemorySegment!!")
)
case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) =>
LinkageModule17.tempScope(alloc ?=>
builder.vargFromAddress(
C_POINTER,
transitionModule17
.methodArgument(ssad, s, alloc)
.asInstanceOf[Addressable]
)
)
case (a, d) =>
throw Error(
s"Unsupported type descriptor/data pairing for VarArgs: $a - $d"
Expand Down
38 changes: 24 additions & 14 deletions j17/src/fr/hammons/slinc/VarArgs17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fr.hammons.slinc
import jdk.incubator.foreign.CLinker.VaList
import jdk.incubator.foreign.CLinker.{C_INT, C_LONG_LONG, C_DOUBLE, C_POINTER}
import jdk.incubator.foreign.SegmentAllocator
import jdk.incubator.foreign.GroupLayout
import fr.hammons.slinc.modules.{
LinkageModule17,
descriptorModule17,
Expand All @@ -20,7 +21,8 @@ class VarArgs17(args: VaList) extends VarArgs:
case LongDescriptor => Long.box(args.vargAsLong(C_LONG_LONG))
case FloatDescriptor => Float.box(args.vargAsDouble(C_DOUBLE).toFloat)
case DoubleDescriptor => Double.box(args.vargAsDouble(C_DOUBLE))
case PtrDescriptor => args.vargAsAddress(C_POINTER).nn
case PtrDescriptor | _: SetSizeArrayDescriptor | VaListDescriptor =>
args.vargAsAddress(C_POINTER).nn
case sd: StructDescriptor =>
LinkageModule17.tempScope(alloc ?=>
args
Expand All @@ -30,26 +32,34 @@ class VarArgs17(args: VaList) extends VarArgs:
)
.nn
)
case AliasDescriptor(real) => get(real)
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
case CUnionDescriptor(possibleTypes) => get(possibleTypes.maxBy(_.size))
case AliasDescriptor(real) => get(real)
case cud: CUnionDescriptor =>
LinkageModule17.tempScope(alloc ?=>
args
.vargAsSegment(
descriptorModule17.toMemoryLayout(cud).asInstanceOf[GroupLayout],
alloc.base.asInstanceOf[SegmentAllocator]
)
.nn
)
def get[A](using d: DescriptorOf[A]): A =
transitionModule17.methodReturn[A](d.descriptor, get(d.descriptor))

private def skip(td: TypeDescriptor): Unit =
td match
case ByteDescriptor => args.skip(C_INT)
case ShortDescriptor => args.skip(C_INT)
case IntDescriptor => args.skip(C_INT)
case LongDescriptor => args.skip(C_LONG_LONG)
case FloatDescriptor => args.skip(C_DOUBLE)
case DoubleDescriptor => args.skip(C_DOUBLE)
case PtrDescriptor => args.skip(C_POINTER)
case ByteDescriptor => args.skip(C_INT)
case ShortDescriptor => args.skip(C_INT)
case IntDescriptor => args.skip(C_INT)
case LongDescriptor => args.skip(C_LONG_LONG)
case FloatDescriptor => args.skip(C_DOUBLE)
case DoubleDescriptor => args.skip(C_DOUBLE)
case PtrDescriptor | _: SetSizeArrayDescriptor => args.skip(C_POINTER)
case sd: StructDescriptor =>
args.skip(descriptorModule17.toGroupLayout(sd))
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case CUnionDescriptor(possibleTypes) => skip(possibleTypes.maxBy(_.size))
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case cud: CUnionDescriptor =>
args.skip(descriptorModule17.toMemoryLayout(cud))

def skip[A](using dO: DescriptorOf[A]): Unit = skip(dO.descriptor)

Expand Down
15 changes: 10 additions & 5 deletions j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import jdk.incubator.foreign.{
MemorySegment,
GroupLayout,
CLinker,
ValueLayout
ValueLayout,
SequenceLayout
}, CLinker.C_POINTER
import scala.collection.concurrent.TrieMap
import fr.hammons.slinc.types.{arch, os, OS, Arch}
Expand All @@ -26,11 +27,10 @@ given descriptorModule17: DescriptorModule with
case FloatDescriptor => classOf[Float]
case DoubleDescriptor => classOf[Double]
case PtrDescriptor => classOf[MemoryAddress]
case _: StructDescriptor | _: CUnionDescriptor |
_: SetSizeArrayDescriptor =>
case _: StructDescriptor | _: CUnionDescriptor =>
classOf[MemorySegment]
case VaListDescriptor => classOf[MemoryAddress]
case ad: AliasDescriptor[?] => toCarrierType(ad.real)
case VaListDescriptor | _: SetSizeArrayDescriptor => classOf[MemoryAddress]
case ad: AliasDescriptor[?] => toCarrierType(ad.real)

def genLayoutList(
layouts: Seq[MemoryLayout],
Expand Down Expand Up @@ -123,6 +123,11 @@ given descriptorModule17: DescriptorModule with
case CUnionDescriptor(possibleTypes) =>
MemoryLayout.unionLayout(possibleTypes.map(toMemoryLayout).toSeq*).nn

def toDowncallLayout(td: TypeDescriptor): MemoryLayout = toMemoryLayout(
td
) match
case _: SequenceLayout => C_POINTER.nn
case o => o
def toMemoryLayout(smd: StructMemberDescriptor): MemoryLayout =
toMemoryLayout(smd.descriptor).withName(smd.name).nn

Expand Down
4 changes: 2 additions & 2 deletions j17/src/fr/hammons/slinc/modules/LinkageModule17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ object LinkageModule17 extends LinkageModule:
varArgs.view.map(_.use[DescriptorOf](d ?=> _ => d.descriptor))
val fdConstructor = descriptor.returnDescriptor match
case None => FunctionDescriptor.ofVoid(_*)
case Some(value) => FunctionDescriptor.of(toMemoryLayout(value), _*)
case Some(value) => FunctionDescriptor.of(toDowncallLayout(value), _*)

val fd = fdConstructor(
descriptor.inputDescriptors.view
.map(toMemoryLayout)
.map(toDowncallLayout)
.concat(variadicDescriptors.map(toMemoryLayout).map(CLinker.asVarArg))
.toSeq
)
Expand Down
10 changes: 10 additions & 0 deletions j19/src/fr/hammons/slinc/Allocator19.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ class Allocator19(
case ms: MemorySegment => ms
case _ => throw Error("Illegal datatype")
)

case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) =>
LinkageModule19.tempScope(alloc ?=>
builder.addVarg(
ValueLayout.ADDRESS,
transitionModule19
.methodArgument(ssad, s, alloc)
.asInstanceOf[Addressable]
)
)
case (td, d) =>
throw Error(s"Unsupported datatype for $td - $d")

Expand Down
Loading

0 comments on commit 50f141d

Please sign in to comment.