diff --git a/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt b/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt index 232464e5..7a9aefe8 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt @@ -33,7 +33,9 @@ enum class ModelCodeGenOptionType(val description: String) { QUARKUS_REFLECTION("This option adds @RegisterForReflection to the generated models. Requires dependency \"'io.quarkus:quarkus-core:+\""), MICRONAUT_INTROSPECTION("This option adds @Introspected to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""), MICRONAUT_REFLECTION("This option adds @ReflectiveAccess to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""), - INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models."); + INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models."), + SEALED_INTERFACES_FOR_ONE_OF("This option enables the generation of interfaces for discriminated oneOf types"), + ; override fun toString() = "`${super.toString()}` - $description" } diff --git a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt index 56ad9f51..712e0e8e 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt @@ -1,6 +1,7 @@ package com.cjbooms.fabrikt.generators.model import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType +import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF import com.cjbooms.fabrikt.configurations.Packages import com.cjbooms.fabrikt.generators.ClassSettings import com.cjbooms.fabrikt.generators.GeneratorUtils.toClassName @@ -33,6 +34,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedEnumDefinition import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedObjectDefinition import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfPolymorphicTypes +import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSubType import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSuperType import com.cjbooms.fabrikt.util.KaizenParserExtensions.isSimpleType @@ -179,6 +181,13 @@ class JacksonModelGenerator( val modelName = schemaInfo.name.toModelClassName() val schemaName = schemaInfo.name return when { + schemaInfo.schema.isOneOfSuperInterface() && SEALED_INTERFACES_FOR_ONE_OF in options -> oneOfSuperInterface( + modelName, + schemaInfo.schema.discriminator, + allSchemas, + schemaInfo.schema.oneOfSchemas, + findOneOfSuperInterface(allSchemas, schemaInfo, options), + ) schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) -> polymorphicSuperSubType( modelName, @@ -187,6 +196,7 @@ class JacksonModelGenerator( checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()), schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) }, schemaInfo.schema.extensions, + findOneOfSuperInterface(allSchemas, schemaInfo, options), allSchemas, ) @@ -196,6 +206,7 @@ class JacksonModelGenerator( properties, schemaInfo.schema.discriminator, schemaInfo.schema.extensions, + findOneOfSuperInterface(allSchemas, schemaInfo, options), allSchemas, ) @@ -205,13 +216,44 @@ class JacksonModelGenerator( properties, schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) }, schemaInfo.schema.extensions, + findOneOfSuperInterface(allSchemas, schemaInfo, options), ) schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo) - else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions) + else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions, findOneOfSuperInterface(allSchemas, schemaInfo, options)) } } + private fun findOneOfSuperInterface( + allSchemas: List, + schema: SchemaInfo, + options: Set, + ): Set { + if (SEALED_INTERFACES_FOR_ONE_OF !in options) { + return emptySet() + } + return allSchemas + .filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() } + .mapNotNull { info -> + info.schema.discriminator.mappings + .toList() + .find { (_, ref) -> + ref.endsWith("/${schema.name}") + } + ?.let { (key, _) -> + Pair(key!!, info) + } + } + .map { (_, parent) -> + val field = parent.schema.discriminator.propertyName!! + if (!schema.schema.properties.containsKey(field)) { + throw IllegalArgumentException("schema $schema did not have discriminator property") + } + parent + } + .toSet() + } + private fun buildInLinedModels( topLevelProperties: Collection, enclosingSchema: Schema, @@ -230,6 +272,7 @@ class JacksonModelGenerator( it.name, props, it.schema.extensions, + oneOfInterfaces = emptySet(), ) val inlinedModels = buildInLinedModels(props, enclosingSchema, apiDocUrl) inlinedModels + currentModel @@ -247,6 +290,7 @@ class JacksonModelGenerator( schemaName = it.name, properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema), extensions = it.schema.extensions, + oneOfInterfaces = emptySet(), ), ) } else { @@ -274,6 +318,7 @@ class JacksonModelGenerator( schemaName = it.name, properties = props, extensions = it.schema.extensions, + oneOfInterfaces = emptySet(), ) } @@ -377,6 +422,7 @@ class JacksonModelGenerator( schemaName = schema.safeName(), properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS), extensions = mapField.schema.extensions, + oneOfInterfaces = emptySet(), ) } else { null @@ -387,19 +433,41 @@ class JacksonModelGenerator( schemaName: String, properties: Collection, extensions: Map, + oneOfInterfaces: Set, ): TypeSpec { - val classBuilder = TypeSpec.classBuilder(generatedType(packages.base, modelName)) + val filteredProperties = if (oneOfInterfaces.size == 1) { + val oneOfInterface = oneOfInterfaces.first() + val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName + val mappingCount = oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") } + if (discriminatorProp != null && mappingCount == 1) { + properties.filterNot { it.name == discriminatorProp } + } else properties + } else properties + + val name = generatedType(packages.base, modelName) + val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty() + val builder = + if (generateObject) TypeSpec.objectBuilder(name) + else TypeSpec.classBuilder(name) + val classBuilder = builder .addSerializableInterface() .addQuarkusReflectionAnnotation() .addMicronautIntrospectedAnnotation() .addMicronautReflectionAnnotation() .addCompanionObject() - properties.addToClass( + for (oneOfInterface in oneOfInterfaces) { + classBuilder + .addSuperinterface(generatedType(packages.base, oneOfInterface.schema.toModelClassName())) + } + + if (!generateObject) { + filteredProperties.addToClass( modelName = modelName, schemaName = schemaName, classBuilder = classBuilder, classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension), ) + } return classBuilder.build() } @@ -410,6 +478,7 @@ class JacksonModelGenerator( discriminator: Discriminator, superType: SchemaInfo, extensions: Map, + oneOfSuperInterfaces: Set, allSchemas: List, ): TypeSpec = with(FunSpec.constructorBuilder()) { TypeSpec.classBuilder(generatedType(packages.base, modelName)) @@ -419,6 +488,7 @@ class JacksonModelGenerator( properties.filter(PropertyInfo::isInherited), superType, extensions, + oneOfSuperInterfaces, this, ) .buildPolymorphicSuperType( @@ -427,21 +497,62 @@ class JacksonModelGenerator( properties.filterNot(PropertyInfo::isInherited), discriminator, extensions, + oneOfSuperInterfaces, allSchemas, this, ) .build() } + private fun oneOfSuperInterface( + modelName: String, + discriminator: Discriminator, + allSchemas: List, + members: List, + oneOfSuperInterfaces: Set, + ): TypeSpec { + val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName)) + .addModifiers(KModifier.SEALED) + .addAnnotation(basePolymorphicType(discriminator.propertyName)) + + for (oneOfSuperInterface in oneOfSuperInterfaces) { + interfaceBuilder.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name)) + } + + val membersAndMappingsConsistent = members.all { member -> + discriminator.mappings.any { (_, ref) -> ref.endsWith("/${member.name}") } + } + + if (!membersAndMappingsConsistent) { + throw IllegalArgumentException("members and mappings are not consistent for oneOf super interface $modelName!") + } + + val mappings = discriminator.mappings + .mapValues { (_, value) -> + allSchemas.find { value.endsWith("/${it.name}") }!! + } + .mapValues { (_, value) -> + toModelType(packages.base, KotlinTypeInfo.from(value.schema, value.name)) + } + + interfaceBuilder.addAnnotation(polymorphicSubTypes(mappings, enumDiscriminator = null)) + .addQuarkusReflectionAnnotation() + .addMicronautIntrospectedAnnotation() + .addMicronautReflectionAnnotation() + + return interfaceBuilder.build() + } + private fun polymorphicSuperType( modelName: String, schemaName: String, properties: Collection, discriminator: Discriminator, extensions: Map, + oneOfSuperInterfaces: Set, allSchemas: List, ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName)) - .buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, allSchemas) + .buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, oneOfSuperInterfaces, allSchemas) .build() private fun TypeSpec.Builder.buildPolymorphicSuperType( @@ -450,6 +561,7 @@ class JacksonModelGenerator( properties: Collection, discriminator: Discriminator, extensions: Map, + oneOfSuperInterfaces: Set, allSchemas: List, constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(), ): TypeSpec.Builder { @@ -457,6 +569,10 @@ class JacksonModelGenerator( .addAnnotation(basePolymorphicType(discriminator.propertyName)) .modifiers.remove(KModifier.DATA) + for (oneOfSuperInterface in oneOfSuperInterfaces) { + this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name)) + } + val subTypes = allSchemas .filter { model -> model.schema.allOfSchemas.any { allOfRef -> @@ -498,8 +614,9 @@ class JacksonModelGenerator( properties: Collection, superType: SchemaInfo, extensions: Map, + oneOfSuperInterfaces: Set, ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName)) - .buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions).build() + .buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions, oneOfSuperInterfaces).build() private fun TypeSpec.Builder.buildPolymorphicSubType( modelName: String, @@ -507,6 +624,7 @@ class JacksonModelGenerator( allProperties: Collection, superType: SchemaInfo, extensions: Map, + oneOfSuperInterfaces: Set, constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(), ): TypeSpec.Builder { this.addSerializableInterface() @@ -518,6 +636,10 @@ class JacksonModelGenerator( toModelType(packages.base, KotlinTypeInfo.from(superType.schema, superType.name)), ) + for (oneOfSuperInterface in oneOfSuperInterfaces) { + this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name)) + } + val properties = superType.schema.getDiscriminatorForInLinedObjectUnderAllOf()?.let { discriminator -> allProperties.filterNot { when (it) { diff --git a/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt b/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt index 95128aa6..72279514 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt @@ -6,6 +6,7 @@ import com.cjbooms.fabrikt.model.OasType.Companion.toOasType import com.cjbooms.fabrikt.util.KaizenParserExtensions.getEnumValues import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties import com.cjbooms.fabrikt.util.KaizenParserExtensions.isNotDefined +import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface import com.cjbooms.fabrikt.util.KaizenParserExtensions.toMapValueClassName import com.cjbooms.fabrikt.util.KaizenParserExtensions.toModelClassName import com.cjbooms.fabrikt.util.NormalisedString.toModelClassName @@ -92,7 +93,9 @@ sealed class KotlinTypeInfo(val modelKClass: KClass<*>, val generatedModelClassN from(schema.additionalPropertiesSchema, "", enclosingName) ) OasType.Any -> AnyType - OasType.OneOfAny -> AnyType + OasType.OneOfAny -> + if (schema.isOneOfSuperInterface()) Object(schema.toModelClassName(enclosingName.toModelClassName())) + else AnyType } private fun getOverridableDateTimeType(): KotlinTypeInfo { diff --git a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt index ac864a36..4924c24a 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt @@ -102,7 +102,7 @@ object KaizenParserExtensions { "additionalProperties" && properties?.isEmpty() != true && !isSimpleType() fun Schema.isSimpleType(): Boolean = - (simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition() + !isOneOfSuperInterface() && ((simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition()) private fun Schema.isObjectType() = OasType.Object.type == type @@ -210,6 +210,9 @@ object KaizenParserExtensions { fun Schema.isOneOfPolymorphicTypes() = this.oneOfSchemas?.firstOrNull()?.allOfSchemas?.firstOrNull() != null + fun Schema.isOneOfSuperInterface() = + discriminator != null && discriminator.propertyName != null && oneOfSchemas.isNotEmpty() + fun OpenApi3.basePath(): String = servers .firstOrNull() diff --git a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt index 780adea9..553fbd74 100644 --- a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt +++ b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt @@ -51,6 +51,8 @@ class ModelGeneratorTest { "responsesSchema", "webhook", "instantDateTime", + "singleAllOf", + "discriminatedOneOf", ) @BeforeEach @@ -79,6 +81,7 @@ class ModelGeneratorTest { val models = JacksonModelGenerator( Packages(basePackage), sourceApi, + setOf(ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF), ).generate().toSingleFile() assertThat(models).isEqualTo(expectedModels) diff --git a/src/test/resources/examples/discriminatedOneOf/api.yaml b/src/test/resources/examples/discriminatedOneOf/api.yaml new file mode 100644 index 00000000..ab2e0a23 --- /dev/null +++ b/src/test/resources/examples/discriminatedOneOf/api.yaml @@ -0,0 +1,40 @@ +openapi: 3.0.0 +info: +paths: +components: + schemas: + SomeObj: + type: object + required: + - state + properties: + state: + $ref: '#/components/schemas/State' + State: + oneOf: + - $ref: '#/components/schemas/StateA' + - $ref: '#/components/schemas/StateB' + discriminator: + propertyName: status + mapping: + a: '#/components/schemas/StateA' + b: '#/components/schemas/StateB' + Status: + type: string + enum: + - a + - b + StateA: + type: object + required: + - status + properties: + status: + $ref: '#/components/schemas/Status' + StateB: + type: object + required: + - status + properties: + status: + $ref: '#/components/schemas/Status' diff --git a/src/test/resources/examples/discriminatedOneOf/models/Models.kt b/src/test/resources/examples/discriminatedOneOf/models/Models.kt new file mode 100644 index 00000000..d1e4fb1f --- /dev/null +++ b/src/test/resources/examples/discriminatedOneOf/models/Models.kt @@ -0,0 +1,53 @@ +package examples.discriminatedOneOf.models + +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonSubTypes +import com.fasterxml.jackson.annotation.JsonTypeInfo +import com.fasterxml.jackson.annotation.JsonValue +import javax.validation.Valid +import javax.validation.constraints.NotNull +import kotlin.String +import kotlin.collections.Map + +data class SomeObj( + @param:JsonProperty("state") + @get:JsonProperty("state") + @get:NotNull + @get:Valid + val state: State +) + +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.EXISTING_PROPERTY, + property = "status", + visible = true +) +@JsonSubTypes( + JsonSubTypes.Type(value = StateA::class, name = "a"), + JsonSubTypes.Type( + value = + StateB::class, + name = "b" + ) +) +sealed interface State + +object StateA : State + +object StateB : State + +enum class Status( + @JsonValue + val value: String +) { + A("a"), + + B("b"); + + companion object { + private val mapping: Map = values().associateBy(Status::value) + + fun fromValue(value: String): Status? = mapping[value] + } +}