Skip to content

Commit

Permalink
Add support for disciminated oneOf's (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
pschichtel authored Apr 24, 2023
1 parent 9bafc02 commit b739d4e
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/main/kotlin/com/cjbooms/fabrikt/cli/CodeGenOptions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ enum class ModelCodeGenOptionType(val description: String) {
QUARKUS_REFLECTION("This option adds @RegisterForReflection to the generated models. Requires dependency \"'io.quarkus:quarkus-core:+\""),
MICRONAUT_INTROSPECTION("This option adds @Introspected to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""),
MICRONAUT_REFLECTION("This option adds @ReflectiveAccess to the generated models. Requires dependency \"'io.micronaut:micronaut-core:+\""),
INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models.");
INCLUDE_COMPANION_OBJECT("This option adds a companion object to the generated models."),
SEALED_INTERFACES_FOR_ONE_OF("This option enables the generation of interfaces for discriminated oneOf types"),
;

override fun toString() = "`${super.toString()}` - $description"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.cjbooms.fabrikt.generators.model

import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType
import com.cjbooms.fabrikt.cli.ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF
import com.cjbooms.fabrikt.configurations.Packages
import com.cjbooms.fabrikt.generators.ClassSettings
import com.cjbooms.fabrikt.generators.GeneratorUtils.toClassName
Expand Down Expand Up @@ -33,6 +34,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedEnumDefinition
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedObjectDefinition
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfPolymorphicTypes
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSubType
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isPolymorphicSuperType
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isSimpleType
Expand Down Expand Up @@ -179,6 +181,13 @@ class JacksonModelGenerator(
val modelName = schemaInfo.name.toModelClassName()
val schemaName = schemaInfo.name
return when {
schemaInfo.schema.isOneOfSuperInterface() && SEALED_INTERFACES_FOR_ONE_OF in options -> oneOfSuperInterface(
modelName,
schemaInfo.schema.discriminator,
allSchemas,
schemaInfo.schema.oneOfSchemas,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
)
schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) ->
polymorphicSuperSubType(
modelName,
Expand All @@ -187,6 +196,7 @@ class JacksonModelGenerator(
checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()),
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
allSchemas,
)

Expand All @@ -196,6 +206,7 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.discriminator,
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
allSchemas,
)

Expand All @@ -205,13 +216,44 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
)

schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo)
else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions)
else -> standardDataClass(modelName, schemaName, properties, schemaInfo.schema.extensions, findOneOfSuperInterface(allSchemas, schemaInfo, options))
}
}

private fun findOneOfSuperInterface(
allSchemas: List<SchemaInfo>,
schema: SchemaInfo,
options: Set<ModelCodeGenOptionType>,
): Set<SchemaInfo> {
if (SEALED_INTERFACES_FOR_ONE_OF !in options) {
return emptySet()
}
return allSchemas
.filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() }
.mapNotNull { info ->
info.schema.discriminator.mappings
.toList()
.find { (_, ref) ->
ref.endsWith("/${schema.name}")
}
?.let { (key, _) ->
Pair(key!!, info)
}
}
.map { (_, parent) ->
val field = parent.schema.discriminator.propertyName!!
if (!schema.schema.properties.containsKey(field)) {
throw IllegalArgumentException("schema $schema did not have discriminator property")
}
parent
}
.toSet()
}

private fun buildInLinedModels(
topLevelProperties: Collection<PropertyInfo>,
enclosingSchema: Schema,
Expand All @@ -230,6 +272,7 @@ class JacksonModelGenerator(
it.name,
props,
it.schema.extensions,
oneOfInterfaces = emptySet(),
)
val inlinedModels = buildInLinedModels(props, enclosingSchema, apiDocUrl)
inlinedModels + currentModel
Expand All @@ -247,6 +290,7 @@ class JacksonModelGenerator(
schemaName = it.name,
properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema),
extensions = it.schema.extensions,
oneOfInterfaces = emptySet(),
),
)
} else {
Expand Down Expand Up @@ -274,6 +318,7 @@ class JacksonModelGenerator(
schemaName = it.name,
properties = props,
extensions = it.schema.extensions,
oneOfInterfaces = emptySet(),
)
}

Expand Down Expand Up @@ -377,6 +422,7 @@ class JacksonModelGenerator(
schemaName = schema.safeName(),
properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS),
extensions = mapField.schema.extensions,
oneOfInterfaces = emptySet(),
)
} else {
null
Expand All @@ -387,19 +433,41 @@ class JacksonModelGenerator(
schemaName: String,
properties: Collection<PropertyInfo>,
extensions: Map<String, Any>,
oneOfInterfaces: Set<SchemaInfo>,
): TypeSpec {
val classBuilder = TypeSpec.classBuilder(generatedType(packages.base, modelName))
val filteredProperties = if (oneOfInterfaces.size == 1) {
val oneOfInterface = oneOfInterfaces.first()
val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName
val mappingCount = oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") }
if (discriminatorProp != null && mappingCount == 1) {
properties.filterNot { it.name == discriminatorProp }
} else properties
} else properties

val name = generatedType(packages.base, modelName)
val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty()
val builder =
if (generateObject) TypeSpec.objectBuilder(name)
else TypeSpec.classBuilder(name)
val classBuilder = builder
.addSerializableInterface()
.addQuarkusReflectionAnnotation()
.addMicronautIntrospectedAnnotation()
.addMicronautReflectionAnnotation()
.addCompanionObject()
properties.addToClass(
for (oneOfInterface in oneOfInterfaces) {
classBuilder
.addSuperinterface(generatedType(packages.base, oneOfInterface.schema.toModelClassName()))
}

if (!generateObject) {
filteredProperties.addToClass(
modelName = modelName,
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
)
}
return classBuilder.build()
}

Expand All @@ -410,6 +478,7 @@ class JacksonModelGenerator(
discriminator: Discriminator,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
allSchemas: List<SchemaInfo>,
): TypeSpec = with(FunSpec.constructorBuilder()) {
TypeSpec.classBuilder(generatedType(packages.base, modelName))
Expand All @@ -419,6 +488,7 @@ class JacksonModelGenerator(
properties.filter(PropertyInfo::isInherited),
superType,
extensions,
oneOfSuperInterfaces,
this,
)
.buildPolymorphicSuperType(
Expand All @@ -427,21 +497,62 @@ class JacksonModelGenerator(
properties.filterNot(PropertyInfo::isInherited),
discriminator,
extensions,
oneOfSuperInterfaces,
allSchemas,
this,
)
.build()
}

private fun oneOfSuperInterface(
modelName: String,
discriminator: Discriminator,
allSchemas: List<SchemaInfo>,
members: List<Schema>,
oneOfSuperInterfaces: Set<SchemaInfo>,
): TypeSpec {
val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName))
.addModifiers(KModifier.SEALED)
.addAnnotation(basePolymorphicType(discriminator.propertyName))

for (oneOfSuperInterface in oneOfSuperInterfaces) {
interfaceBuilder.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
}

val membersAndMappingsConsistent = members.all { member ->
discriminator.mappings.any { (_, ref) -> ref.endsWith("/${member.name}") }
}

if (!membersAndMappingsConsistent) {
throw IllegalArgumentException("members and mappings are not consistent for oneOf super interface $modelName!")
}

val mappings = discriminator.mappings
.mapValues { (_, value) ->
allSchemas.find { value.endsWith("/${it.name}") }!!
}
.mapValues { (_, value) ->
toModelType(packages.base, KotlinTypeInfo.from(value.schema, value.name))
}

interfaceBuilder.addAnnotation(polymorphicSubTypes(mappings, enumDiscriminator = null))
.addQuarkusReflectionAnnotation()
.addMicronautIntrospectedAnnotation()
.addMicronautReflectionAnnotation()

return interfaceBuilder.build()
}

private fun polymorphicSuperType(
modelName: String,
schemaName: String,
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
allSchemas: List<SchemaInfo>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, allSchemas)
.buildPolymorphicSuperType(modelName, schemaName, properties, discriminator, extensions, oneOfSuperInterfaces, allSchemas)
.build()

private fun TypeSpec.Builder.buildPolymorphicSuperType(
Expand All @@ -450,13 +561,18 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
allSchemas: List<SchemaInfo>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
this.addModifiers(KModifier.SEALED)
.addAnnotation(basePolymorphicType(discriminator.propertyName))
.modifiers.remove(KModifier.DATA)

for (oneOfSuperInterface in oneOfSuperInterfaces) {
this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
}

val subTypes = allSchemas
.filter { model ->
model.schema.allOfSchemas.any { allOfRef ->
Expand Down Expand Up @@ -498,15 +614,17 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions).build()
.buildPolymorphicSubType(modelName, schemaName, properties, superType, extensions, oneOfSuperInterfaces).build()

private fun TypeSpec.Builder.buildPolymorphicSubType(
modelName: String,
schemaName: String,
allProperties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
this.addSerializableInterface()
Expand All @@ -518,6 +636,10 @@ class JacksonModelGenerator(
toModelType(packages.base, KotlinTypeInfo.from(superType.schema, superType.name)),
)

for (oneOfSuperInterface in oneOfSuperInterfaces) {
this.addSuperinterface(generatedType(packages.base, oneOfSuperInterface.name))
}

val properties = superType.schema.getDiscriminatorForInLinedObjectUnderAllOf()?.let { discriminator ->
allProperties.filterNot {
when (it) {
Expand Down
5 changes: 4 additions & 1 deletion src/main/kotlin/com/cjbooms/fabrikt/model/KotlinTypeInfo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.cjbooms.fabrikt.model.OasType.Companion.toOasType
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getEnumValues
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isInlinedTypedAdditionalProperties
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isNotDefined
import com.cjbooms.fabrikt.util.KaizenParserExtensions.isOneOfSuperInterface
import com.cjbooms.fabrikt.util.KaizenParserExtensions.toMapValueClassName
import com.cjbooms.fabrikt.util.KaizenParserExtensions.toModelClassName
import com.cjbooms.fabrikt.util.NormalisedString.toModelClassName
Expand Down Expand Up @@ -92,7 +93,9 @@ sealed class KotlinTypeInfo(val modelKClass: KClass<*>, val generatedModelClassN
from(schema.additionalPropertiesSchema, "", enclosingName)
)
OasType.Any -> AnyType
OasType.OneOfAny -> AnyType
OasType.OneOfAny ->
if (schema.isOneOfSuperInterface()) Object(schema.toModelClassName(enclosingName.toModelClassName()))
else AnyType
}

private fun getOverridableDateTimeType(): KotlinTypeInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ object KaizenParserExtensions {
"additionalProperties" && properties?.isEmpty() != true && !isSimpleType()

fun Schema.isSimpleType(): Boolean =
(simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition()
!isOneOfSuperInterface() && ((simpleTypes.contains(type) && !isEnumDefinition()) || isSimpleMapDefinition() || isSimpleOneOfAnyDefinition())

private fun Schema.isObjectType() = OasType.Object.type == type

Expand Down Expand Up @@ -210,6 +210,9 @@ object KaizenParserExtensions {
fun Schema.isOneOfPolymorphicTypes() =
this.oneOfSchemas?.firstOrNull()?.allOfSchemas?.firstOrNull() != null

fun Schema.isOneOfSuperInterface() =
discriminator != null && discriminator.propertyName != null && oneOfSchemas.isNotEmpty()

fun OpenApi3.basePath(): String =
servers
.firstOrNull()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class ModelGeneratorTest {
"responsesSchema",
"webhook",
"instantDateTime",
"singleAllOf",
"discriminatedOneOf",
)

@BeforeEach
Expand Down Expand Up @@ -79,6 +81,7 @@ class ModelGeneratorTest {
val models = JacksonModelGenerator(
Packages(basePackage),
sourceApi,
setOf(ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF),
).generate().toSingleFile()

assertThat(models).isEqualTo(expectedModels)
Expand Down
Loading

0 comments on commit b739d4e

Please sign in to comment.