From b739d4ee3448cd628583dbe7563541075da5de8c Mon Sep 17 00:00:00 2001
From: Phillip Schichtel <pschichtel@users.noreply.github.com>
Date: Mon, 24 Apr 2023 11:53:53 +0200
Subject: [PATCH] Add support for disciminated oneOf's (#125)

---
 .../com/cjbooms/fabrikt/cli/CodeGenOptions.kt |   4 +-
 .../generators/model/JacksonModelGenerator.kt | 132 +++++++++++++++++-
 .../cjbooms/fabrikt/model/KotlinTypeInfo.kt   |   5 +-
 .../fabrikt/util/KaizenParserExtensions.kt    |   5 +-
 .../fabrikt/generators/ModelGeneratorTest.kt  |   3 +
 .../examples/discriminatedOneOf/api.yaml      |  40 ++++++
 .../discriminatedOneOf/models/Models.kt       |  53 +++++++
 7 files changed, 234 insertions(+), 8 deletions(-)
 create mode 100644 src/test/resources/examples/discriminatedOneOf/api.yaml
 create mode 100644 src/test/resources/examples/discriminatedOneOf/models/Models.kt

diff --git a/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt b/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt
index 232464e5..7a9aefe8 100644
--- a/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt
+++ b/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt
@@ -33,7 +33,9 @@ enum class ModelCodeGenOptionType(val description: String) {
     QUARKUS_REFLECTION("This option adds @RegisterForReflection to the generated models. Requires dependency \"'io.quarkus:quarkus-core:+\""),
     MICRONAUT_INTROSPECTION("This option adds @Introspected to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""),
     MICRONAUT_REFLECTION("This option adds @ReflectiveAccess to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""),
-    INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models.");
+    INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models."),
+    SEALED_INTERFACES_FOR_ONE_OF("This option enables the generation of interfaces for discriminated oneOf types"),
+    ;
 
     override fun toString() = "`${super.toString()}` - $description"
 }
diff --git a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt
index 56ad9f51..712e0e8e 100644
--- a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt
+++ b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt
@@ -1,6 +1,7 @@
 package com.cjbooms.fabrikt.generators.model
 
 import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType
+import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF
 import com.cjbooms.fabrikt.configurations.Packages
 import com.cjbooms.fabrikt.generators.ClassSettings
 import com.cjbooms.fabrikt.generators.GeneratorUtils.toClassName
@@ -33,6 +34,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedEnumDefinition
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedObjectDefinition
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfPolymorphicTypes
+import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSubType
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSuperType
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isSimpleType
@@ -179,6 +181,13 @@ class JacksonModelGenerator(
         val modelName = schemaInfo.name.toModelClassName()
         val schemaName = schemaInfo.name
         return when {
+            schemaInfo.schema.isOneOfSuperInterface() && SEALED_INTERFACES_FOR_ONE_OF in options -> oneOfSuperInterface(
+                modelName,
+                schemaInfo.schema.discriminator,
+                allSchemas,
+                schemaInfo.schema.oneOfSchemas,
+                findOneOfSuperInterface(allSchemas, schemaInfo, options),
+            )
             schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) ->
                 polymorphicSuperSubType(
                     modelName,
@@ -187,6 +196,7 @@ class JacksonModelGenerator(
                     checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()),
                     schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
                     schemaInfo.schema.extensions,
+                    findOneOfSuperInterface(allSchemas, schemaInfo, options),
                     allSchemas,
                 )
 
@@ -196,6 +206,7 @@ class JacksonModelGenerator(
                 properties,
                 schemaInfo.schema.discriminator,
                 schemaInfo.schema.extensions,
+                findOneOfSuperInterface(allSchemas, schemaInfo, options),
                 allSchemas,
             )
 
@@ -205,13 +216,44 @@ class JacksonModelGenerator(
                 properties,
                 schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
                 schemaInfo.schema.extensions,
+                findOneOfSuperInterface(allSchemas, schemaInfo, options),
             )
 
             schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo)
-            else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions)
+            else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions, findOneOfSuperInterface(allSchemas, schemaInfo, options))
         }
     }
 
+    private fun findOneOfSuperInterface(
+        allSchemas: List<SchemaInfo>,
+        schema: SchemaInfo,
+        options: Set<ModelCodeGenOptionType>,
+    ): Set<SchemaInfo> {
+        if (SEALED_INTERFACES_FOR_ONE_OF !in options) {
+            return emptySet()
+        }
+        return allSchemas
+            .filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() }
+            .mapNotNull { info ->
+                info.schema.discriminator.mappings
+                    .toList()
+                    .find { (_, ref) ->
+                        ref.endsWith("/${schema.name}")
+                    }
+                    ?.let { (key, _) ->
+                        Pair(key!!, info)
+                    }
+            }
+            .map { (_, parent) ->
+                val field = parent.schema.discriminator.propertyName!!
+                if (!schema.schema.properties.containsKey(field)) {
+                    throw IllegalArgumentException("schema $schema did not have discriminator property")
+                }
+                parent
+            }
+            .toSet()
+    }
+
     private fun buildInLinedModels(
         topLevelProperties: Collection<PropertyInfo>,
         enclosingSchema: Schema,
@@ -230,6 +272,7 @@ class JacksonModelGenerator(
                         it.name,
                         props,
                         it.schema.extensions,
+                        oneOfInterfaces = emptySet(),
                     )
                     val inlinedModels = buildInLinedModels(props, enclosingSchema, apiDocUrl)
                     inlinedModels + currentModel
@@ -247,6 +290,7 @@ class JacksonModelGenerator(
                                 schemaName = it.name,
                                 properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema),
                                 extensions = it.schema.extensions,
+                                oneOfInterfaces = emptySet(),
                             ),
                         )
                     } else {
@@ -274,6 +318,7 @@ class JacksonModelGenerator(
                                         schemaName = it.name,
                                         properties = props,
                                         extensions = it.schema.extensions,
+                                        oneOfInterfaces = emptySet(),
                                     )
                                 }
 
@@ -377,6 +422,7 @@ class JacksonModelGenerator(
                 schemaName = schema.safeName(),
                 properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS),
                 extensions = mapField.schema.extensions,
+                oneOfInterfaces = emptySet(),
             )
         } else {
             null
@@ -387,19 +433,41 @@ class JacksonModelGenerator(
         schemaName: String,
         properties: Collection<PropertyInfo>,
         extensions: Map<String, Any>,
+        oneOfInterfaces: Set<SchemaInfo>,
     ): TypeSpec {
-        val classBuilder = TypeSpec.classBuilder(generatedType(packages.base, modelName))
+        val filteredProperties = if (oneOfInterfaces.size == 1) {
+            val oneOfInterface = oneOfInterfaces.first()
+            val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName
+            val mappingCount = oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") }
+            if (discriminatorProp != null && mappingCount == 1) {
+                properties.filterNot { it.name == discriminatorProp }
+            } else properties
+        } else properties
+
+        val name = generatedType(packages.base, modelName)
+        val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty()
+        val builder =
+            if (generateObject) TypeSpec.objectBuilder(name)
+            else TypeSpec.classBuilder(name)
+        val classBuilder = builder
             .addSerializableInterface()
             .addQuarkusReflectionAnnotation()
             .addMicronautIntrospectedAnnotation()
             .addMicronautReflectionAnnotation()
             .addCompanionObject()
-        properties.addToClass(
+        for (oneOfInterface in oneOfInterfaces) {
+            classBuilder
+                .addSuperinterface(generatedType(packages.base, oneOfInterface.schema.toModelClassName()))
+        }
+
+        if (!generateObject) {
+            filteredProperties.addToClass(
             modelName = modelName,
             schemaName = schemaName,
             classBuilder = classBuilder,
             classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
         )
+        }
         return classBuilder.build()
     }
 
@@ -410,6 +478,7 @@ class JacksonModelGenerator(
         discriminator: Discriminator,
         superType: SchemaInfo,
         extensions: Map<String, Any>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
         allSchemas: List<SchemaInfo>,
     ): TypeSpec = with(FunSpec.constructorBuilder()) {
         TypeSpec.classBuilder(generatedType(packages.base, modelName))
@@ -419,6 +488,7 @@ class JacksonModelGenerator(
                 properties.filter(PropertyInfo::isInherited),
                 superType,
                 extensions,
+                oneOfSuperInterfaces,
                 this,
             )
             .buildPolymorphicSuperType(
@@ -427,21 +497,62 @@ class JacksonModelGenerator(
                 properties.filterNot(PropertyInfo::isInherited),
                 discriminator,
                 extensions,
+                oneOfSuperInterfaces,
                 allSchemas,
                 this,
             )
             .build()
     }
 
+    private fun oneOfSuperInterface(
+        modelName: String,
+        discriminator: Discriminator,
+        allSchemas: List<SchemaInfo>,
+        members: List<Schema>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
+    ): TypeSpec {
+        val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName))
+            .addModifiers(KModifier.SEALED)
+            .addAnnotation(basePolymorphicType(discriminator.propertyName))
+
+        for (oneOfSuperInterface in oneOfSuperInterfaces) {
+            interfaceBuilder.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
+        }
+
+        val membersAndMappingsConsistent = members.all { member ->
+            discriminator.mappings.any { (_, ref) -> ref.endsWith("/${member.name}") }
+        }
+
+        if (!membersAndMappingsConsistent) {
+            throw IllegalArgumentException("members and mappings are not consistent for oneOf super interface $modelName!")
+        }
+
+        val mappings = discriminator.mappings
+            .mapValues { (_, value) ->
+                allSchemas.find { value.endsWith("/${it.name}") }!!
+            }
+            .mapValues { (_, value) ->
+                toModelType(packages.base, KotlinTypeInfo.from(value.schema, value.name))
+            }
+
+        interfaceBuilder.addAnnotation(polymorphicSubTypes(mappings, enumDiscriminator = null))
+            .addQuarkusReflectionAnnotation()
+            .addMicronautIntrospectedAnnotation()
+            .addMicronautReflectionAnnotation()
+
+        return interfaceBuilder.build()
+    }
+
     private fun polymorphicSuperType(
         modelName: String,
         schemaName: String,
         properties: Collection<PropertyInfo>,
         discriminator: Discriminator,
         extensions: Map<String, Any>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
         allSchemas: List<SchemaInfo>,
     ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
-        .buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, allSchemas)
+        .buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, oneOfSuperInterfaces, allSchemas)
         .build()
 
     private fun TypeSpec.Builder.buildPolymorphicSuperType(
@@ -450,6 +561,7 @@ class JacksonModelGenerator(
         properties: Collection<PropertyInfo>,
         discriminator: Discriminator,
         extensions: Map<String, Any>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
         allSchemas: List<SchemaInfo>,
         constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
     ): TypeSpec.Builder {
@@ -457,6 +569,10 @@ class JacksonModelGenerator(
             .addAnnotation(basePolymorphicType(discriminator.propertyName))
             .modifiers.remove(KModifier.DATA)
 
+        for (oneOfSuperInterface in oneOfSuperInterfaces) {
+            this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
+        }
+
         val subTypes = allSchemas
             .filter { model ->
                 model.schema.allOfSchemas.any { allOfRef ->
@@ -498,8 +614,9 @@ class JacksonModelGenerator(
         properties: Collection<PropertyInfo>,
         superType: SchemaInfo,
         extensions: Map<String, Any>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
     ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
-        .buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions).build()
+        .buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions, oneOfSuperInterfaces).build()
 
     private fun TypeSpec.Builder.buildPolymorphicSubType(
         modelName: String,
@@ -507,6 +624,7 @@ class JacksonModelGenerator(
         allProperties: Collection<PropertyInfo>,
         superType: SchemaInfo,
         extensions: Map<String, Any>,
+        oneOfSuperInterfaces: Set<SchemaInfo>,
         constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
     ): TypeSpec.Builder {
         this.addSerializableInterface()
@@ -518,6 +636,10 @@ class JacksonModelGenerator(
                 toModelType(packages.base, KotlinTypeInfo.from(superType.schema, superType.name)),
             )
 
+        for (oneOfSuperInterface in oneOfSuperInterfaces) {
+            this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
+        }
+
         val properties = superType.schema.getDiscriminatorForInLinedObjectUnderAllOf()?.let { discriminator ->
             allProperties.filterNot {
                 when (it) {
diff --git a/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt b/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt
index 95128aa6..72279514 100644
--- a/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt
+++ b/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt
@@ -6,6 +6,7 @@ import com.cjbooms.fabrikt.model.OasType.Companion.toOasType
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.getEnumValues
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.isNotDefined
+import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.toMapValueClassName
 import com.cjbooms.fabrikt.util.KaizenParserExtensions.toModelClassName
 import com.cjbooms.fabrikt.util.NormalisedString.toModelClassName
@@ -92,7 +93,9 @@ sealed class KotlinTypeInfo(val modelKClass: KClass<*>, val generatedModelClassN
                         from(schema.additionalPropertiesSchema, "", enclosingName)
                     )
                 OasType.Any -> AnyType
-                OasType.OneOfAny -> AnyType
+                OasType.OneOfAny ->
+                    if (schema.isOneOfSuperInterface()) Object(schema.toModelClassName(enclosingName.toModelClassName()))
+                    else AnyType
             }
 
         private fun getOverridableDateTimeType(): KotlinTypeInfo {
diff --git a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt
index ac864a36..4924c24a 100644
--- a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt
+++ b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt
@@ -102,7 +102,7 @@ object KaizenParserExtensions {
         "additionalProperties" && properties?.isEmpty() != true && !isSimpleType()
 
     fun Schema.isSimpleType(): Boolean =
-        (simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition()
+        !isOneOfSuperInterface() && ((simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition())
 
     private fun Schema.isObjectType() = OasType.Object.type == type
 
@@ -210,6 +210,9 @@ object KaizenParserExtensions {
     fun Schema.isOneOfPolymorphicTypes() =
         this.oneOfSchemas?.firstOrNull()?.allOfSchemas?.firstOrNull() != null
 
+    fun Schema.isOneOfSuperInterface() =
+        discriminator != null && discriminator.propertyName != null && oneOfSchemas.isNotEmpty()
+
     fun OpenApi3.basePath(): String =
         servers
             .firstOrNull()
diff --git a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt
index 780adea9..553fbd74 100644
--- a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt
+++ b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt
@@ -51,6 +51,8 @@ class ModelGeneratorTest {
         "responsesSchema",
         "webhook",
         "instantDateTime",
+        "singleAllOf",
+        "discriminatedOneOf",
     )
 
     @BeforeEach
@@ -79,6 +81,7 @@ class ModelGeneratorTest {
         val models = JacksonModelGenerator(
             Packages(basePackage),
             sourceApi,
+            setOf(ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF),
         ).generate().toSingleFile()
 
         assertThat(models).isEqualTo(expectedModels)
diff --git a/src/test/resources/examples/discriminatedOneOf/api.yaml b/src/test/resources/examples/discriminatedOneOf/api.yaml
new file mode 100644
index 00000000..ab2e0a23
--- /dev/null
+++ b/src/test/resources/examples/discriminatedOneOf/api.yaml
@@ -0,0 +1,40 @@
+openapi: 3.0.0
+info:
+paths:
+components:
+  schemas:
+    SomeObj:
+      type: object
+      required:
+        - state
+      properties:
+        state:
+          $ref: '#/components/schemas/State'
+    State:
+      oneOf:
+        - $ref: '#/components/schemas/StateA'
+        - $ref: '#/components/schemas/StateB'
+      discriminator:
+        propertyName: status
+        mapping:
+          a: '#/components/schemas/StateA'
+          b: '#/components/schemas/StateB'
+    Status:
+      type: string
+      enum:
+        - a
+        - b
+    StateA:
+      type: object
+      required:
+        - status
+      properties:
+        status:
+          $ref: '#/components/schemas/Status'
+    StateB:
+      type: object
+      required:
+        - status
+      properties:
+        status:
+          $ref: '#/components/schemas/Status'
diff --git a/src/test/resources/examples/discriminatedOneOf/models/Models.kt b/src/test/resources/examples/discriminatedOneOf/models/Models.kt
new file mode 100644
index 00000000..d1e4fb1f
--- /dev/null
+++ b/src/test/resources/examples/discriminatedOneOf/models/Models.kt
@@ -0,0 +1,53 @@
+package examples.discriminatedOneOf.models
+
+import com.fasterxml.jackson.annotation.JsonProperty
+import com.fasterxml.jackson.annotation.JsonSubTypes
+import com.fasterxml.jackson.annotation.JsonTypeInfo
+import com.fasterxml.jackson.annotation.JsonValue
+import javax.validation.Valid
+import javax.validation.constraints.NotNull
+import kotlin.String
+import kotlin.collections.Map
+
+data class SomeObj(
+    @param:JsonProperty("state")
+    @get:JsonProperty("state")
+    @get:NotNull
+    @get:Valid
+    val state: State
+)
+
+@JsonTypeInfo(
+    use = JsonTypeInfo.Id.NAME,
+    include = JsonTypeInfo.As.EXISTING_PROPERTY,
+    property = "status",
+    visible = true
+)
+@JsonSubTypes(
+    JsonSubTypes.Type(value = StateA::class, name = "a"),
+    JsonSubTypes.Type(
+        value =
+        StateB::class,
+        name = "b"
+    )
+)
+sealed interface State
+
+object StateA : State
+
+object StateB : State
+
+enum class Status(
+    @JsonValue
+    val value: String
+) {
+    A("a"),
+
+    B("b");
+
+    companion object {
+        private val mapping: Map<String, Status> = values().associateBy(Status::value)
+
+        fun fromValue(value: String): Status? = mapping[value]
+    }
+}