Skip to content

Commit

Permalink
Updated RecordEncoderGenerator for function returns
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Apr 28, 2024
1 parent 505ccb1 commit 1bab255
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,40 @@ class RecordEncoderGenerator {
appendLine("class ${kclass.java.simpleName}Encoder(schema: Schema) : Encoder<${kclass.java.simpleName}> {")
appendLine()
kclass.declaredMemberProperties.forEach { property ->
appendLine(" private val ${property.name}Encoder = ${encoderVal(property)}")
appendLine(" private val ${property.name}Schema = schema.getField(\"${property.name}\").schema()")
appendLine(" private val ${property.name}Pos = schema.getField(\"${property.name}\").pos()")
appendLine(" private val ${property.name}Schema = schema.getField(\"${property.name}\").schema()")
appendLine(" private val ${property.name}Pos = schema.getField(\"${property.name}\").pos()")
appendLine(" private val ${property.name}Encode = ${encode(property)}")
}
appendLine()
appendLine(" override fun encode(schema: Schema, value: ${kclass.java.simpleName}): GenericRecord {")
appendLine(" val record = GenericData.Record(schema)")
appendLine(" override fun encode(schema: Schema): (${kclass.java.simpleName}) -> GenericRecord {")
appendLine(" return { value ->")
appendLine(" val record = GenericData.Record(schema)")
kclass.declaredMemberProperties.forEach { property ->
appendLine(" record.put(${property.name}Pos, ${encoderInvocation(property)})")
appendLine(" record.put(${property.name}Pos, ${encoderInvocation(property)})")
}
appendLine(" return record")
appendLine(" record")
appendLine(" }")
appendLine(" }")
appendLine("}")
}
}

private fun encoderVal(property: KProperty1<out Any, *>): String {
private fun encode(property: KProperty1<out Any, *>): String {
val baseEncoder = encoderFor(property.returnType)
return if (property.returnType.isMarkedNullable) "NullEncoder($baseEncoder)" else baseEncoder
val wrapped = if (property.returnType.isMarkedNullable) "NullEncoder($baseEncoder)" else baseEncoder
return "$wrapped.encode(${property.name}Schema)"
}

private fun encoderInvocation(property: KProperty1<out Any, *>): String {
val getSchema = "${property.name}Schema"
val getValue = "value.${property.name}"

return when (property.returnType.classifier) {
Boolean::class -> getValue
Double::class -> getValue
Float::class -> getValue
Int::class -> getValue
Long::class ->getValue
Long::class -> getValue
String::class -> getValue
else -> "${property.name}Encoder.encode($getSchema, $getValue)"
else -> "${property.name}Encode.invoke($getValue)"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class Serde<T : Any>(
companion object {

/**
* Creates a [Schema] reflectively from the given [kclass] using a [ReflectionSchemaBuilder].
* Creates a [Schema], [Encoder] and [Decoder] reflectively from the given [kclass]
* using a [ReflectionSchemaBuilder].
*/
operator fun <T : Any> invoke(
kclass: KClass<T>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.sksamuel.centurion.avro.encoders

import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import org.apache.avro.Schema
import org.apache.avro.util.Utf8

class ArrayEncoderTest : FunSpec({

test("encoding list of strings") {
val schema = Schema.createArray(Schema.create(Schema.Type.STRING))
ArrayEncoder(StringEncoder).encode(schema).invoke(arrayOf("foo", "bar")) shouldBe listOf(Utf8("foo"), Utf8("bar"))
}

})
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ import com.sksamuel.centurion.avro.encoders.Wine
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe

class RecordEncoderGeneratorTest : FunSpec({
data class MyFoo(
val b: Boolean,
val s: String?,
val c: Long,
val d: Double,
val i: Int,
val f: Float,
val sets: Set<String>,
val lists: List<Int>,
val maps: Map<String, Double>,
val wine: Wine?,
)

data class MyFoo(
val b: Boolean,
val s: String?,
val c: Long,
val d: Double,
val i: Int,
val f: Float,
val sets: Set<String>,
val lists: List<Int>,
val maps: Map<String, Double>,
val wine: Wine?,
)
class RecordEncoderGeneratorTest : FunSpec({

test("simple encoder") {
RecordEncoderGenerator().generate(MyFoo::class).trim() shouldBe """
Expand All @@ -34,50 +34,52 @@ import org.apache.avro.generic.GenericRecord
*/
class MyFooEncoder(schema: Schema) : Encoder<MyFoo> {
private val bEncoder = BooleanEncoder
private val bSchema = schema.getField("b").schema()
private val bPos = schema.getField("b").pos()
private val cEncoder = LongEncoder
private val cSchema = schema.getField("c").schema()
private val cPos = schema.getField("c").pos()
private val dEncoder = DoubleEncoder
private val dSchema = schema.getField("d").schema()
private val dPos = schema.getField("d").pos()
private val fEncoder = FloatEncoder
private val fSchema = schema.getField("f").schema()
private val fPos = schema.getField("f").pos()
private val iEncoder = IntEncoder
private val iSchema = schema.getField("i").schema()
private val iPos = schema.getField("i").pos()
private val listsEncoder = ListEncoder(IntEncoder)
private val listsSchema = schema.getField("lists").schema()
private val listsPos = schema.getField("lists").pos()
private val mapsEncoder = MapEncoder(StringEncoder, DoubleEncoder)
private val mapsSchema = schema.getField("maps").schema()
private val mapsPos = schema.getField("maps").pos()
private val sEncoder = NullEncoder(StringEncoder)
private val sSchema = schema.getField("s").schema()
private val sPos = schema.getField("s").pos()
private val setsEncoder = SetEncoder(StringEncoder)
private val setsSchema = schema.getField("sets").schema()
private val setsPos = schema.getField("sets").pos()
private val wineEncoder = NullEncoder(EnumEncoder())
private val wineSchema = schema.getField("wine").schema()
private val winePos = schema.getField("wine").pos()
private val bSchema = schema.getField("b").schema()
private val bPos = schema.getField("b").pos()
private val bEncode = BooleanEncoder.encode(bSchema)
private val cSchema = schema.getField("c").schema()
private val cPos = schema.getField("c").pos()
private val cEncode = LongEncoder.encode(cSchema)
private val dSchema = schema.getField("d").schema()
private val dPos = schema.getField("d").pos()
private val dEncode = DoubleEncoder.encode(dSchema)
private val fSchema = schema.getField("f").schema()
private val fPos = schema.getField("f").pos()
private val fEncode = FloatEncoder.encode(fSchema)
private val iSchema = schema.getField("i").schema()
private val iPos = schema.getField("i").pos()
private val iEncode = IntEncoder.encode(iSchema)
private val listsSchema = schema.getField("lists").schema()
private val listsPos = schema.getField("lists").pos()
private val listsEncode = ListEncoder(IntEncoder).encode(listsSchema)
private val mapsSchema = schema.getField("maps").schema()
private val mapsPos = schema.getField("maps").pos()
private val mapsEncode = MapEncoder(StringEncoder, DoubleEncoder).encode(mapsSchema)
private val sSchema = schema.getField("s").schema()
private val sPos = schema.getField("s").pos()
private val sEncode = NullEncoder(StringEncoder).encode(sSchema)
private val setsSchema = schema.getField("sets").schema()
private val setsPos = schema.getField("sets").pos()
private val setsEncode = SetEncoder(StringEncoder).encode(setsSchema)
private val wineSchema = schema.getField("wine").schema()
private val winePos = schema.getField("wine").pos()
private val wineEncode = NullEncoder(EnumEncoder()).encode(wineSchema)
override fun encode(schema: Schema, value: MyFoo): GenericRecord {
val record = GenericData.Record(schema)
record.put(bPos, value.b)
record.put(cPos, value.c)
record.put(dPos, value.d)
record.put(fPos, value.f)
record.put(iPos, value.i)
record.put(listsPos, listsEncoder.encode(listsSchema, value.lists))
record.put(mapsPos, mapsEncoder.encode(mapsSchema, value.maps))
record.put(sPos, value.s)
record.put(setsPos, setsEncoder.encode(setsSchema, value.sets))
record.put(winePos, wineEncoder.encode(wineSchema, value.wine))
return record
override fun encode(schema: Schema): (MyFoo) -> GenericRecord {
return { value ->
val record = GenericData.Record(schema)
record.put(bPos, value.b)
record.put(cPos, value.c)
record.put(dPos, value.d)
record.put(fPos, value.f)
record.put(iPos, value.i)
record.put(listsPos, listsEncode.invoke(value.lists))
record.put(mapsPos, mapsEncode.invoke(value.maps))
record.put(sPos, value.s)
record.put(setsPos, setsEncode.invoke(value.sets))
record.put(winePos, wineEncode.invoke(value.wine))
record
}
}
}
""".trim()
Expand Down

0 comments on commit 1bab255

Please sign in to comment.