From 61af5121e7f397dc1879b5f518e2e9e2c6b432af Mon Sep 17 00:00:00 2001 From: Conor Gallagher Date: Tue, 5 Mar 2024 15:29:34 +0000 Subject: [PATCH] Do not strip discriminator property in oneOf generation. (#268) * Do not strip discriminator property in oneOf generation. * revert test and formatting changes * correct test * Change oneOf so that it defaults to the correct discriminator value --- .../fabrikt/generators/PropertyUtils.kt | 8 +- .../generators/model/JacksonModelGenerator.kt | 106 +++++++----------- .../com/cjbooms/fabrikt/model/PropertyInfo.kt | 22 ++-- .../fabrikt/util/KaizenParserExtensions.kt | 53 ++++++++- .../fabrikt/generators/ModelGeneratorTest.kt | 3 + .../discriminatedOneOf/models/Models.kt | 14 ++- 6 files changed, 122 insertions(+), 84 deletions(-) diff --git a/src/main/kotlin/com/cjbooms/fabrikt/generators/PropertyUtils.kt b/src/main/kotlin/com/cjbooms/fabrikt/generators/PropertyUtils.kt index f23b6721..1a313d6a 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/generators/PropertyUtils.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/generators/PropertyUtils.kt @@ -21,6 +21,7 @@ data class ClassSettings( NONE, SUPER, SUB, + ONE_OF, } } @@ -102,11 +103,16 @@ object PropertyUtils { property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey)) property.addValidationAnnotations(this, validationAnnotations) } + + ClassSettings.PolymorphyType.ONE_OF -> { + property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey)) + property.addValidationAnnotations(this, validationAnnotations) + } } if (isDiscriminatorFieldWithSingleKnownValue(classSettings, schemaName)) { this as PropertyInfo.Field - if (classSettings.polymorphyType == ClassSettings.PolymorphyType.SUB) { + if (classSettings.polymorphyType in listOf(ClassSettings.PolymorphyType.SUB, ClassSettings.PolymorphyType.ONE_OF)) { property.initializer(name) property.addAnnotation(JacksonMetadata.jacksonParameterAnnotation(oasKey)) val constructorParameter: ParameterSpec.Builder = ParameterSpec.builder(name, wrappedType) diff --git a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt index aff2a0f9..2fc6953d 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/generators/model/JacksonModelGenerator.kt @@ -29,6 +29,7 @@ import com.cjbooms.fabrikt.model.PropertyInfo.Companion.topLevelProperties import com.cjbooms.fabrikt.model.SchemaInfo import com.cjbooms.fabrikt.model.SourceApi import com.cjbooms.fabrikt.model.toEnclosingSchemaInfo +import com.cjbooms.fabrikt.util.KaizenParserExtensions.findOneOfSuperInterface import com.cjbooms.fabrikt.util.KaizenParserExtensions.getDiscriminatorForInLinedObjectUnderAllOf import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSchemaRefName import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSuperType @@ -164,8 +165,11 @@ class JacksonModelGenerator( .filterNot { it.schema.isSimpleType() } .filterNot { it.schema.isOneOfPolymorphicTypes() } .flatMap { - val properties = it.schema.topLevelProperties(HTTP_SETTINGS, it.schema) - if (properties.isNotEmpty() || it.typeInfo is KotlinTypeInfo.Enum) { + val properties = it.schema.topLevelProperties(HTTP_SETTINGS, api, it.schema) + if (properties.isNotEmpty() || + it.typeInfo is KotlinTypeInfo.Enum || + it.schema.findOneOfSuperInterface(schemas.map { it.schema }).isNotEmpty() + ) { val primaryModel = buildPrimaryModel(api, it, properties, schemas) val inlinedModels = buildInLinedModels(properties, it.schema, it.schema.getDocumentUrl()) listOf(primaryModel) + inlinedModels @@ -196,7 +200,7 @@ class JacksonModelGenerator( schemaInfo.schema.discriminator, allSchemas, schemaInfo.schema.oneOfSchemas, - findOneOfSuperInterface(allSchemas, schemaInfo, options), + schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }), ) schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) -> @@ -207,7 +211,7 @@ class JacksonModelGenerator( checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()), schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) }, schemaInfo.schema.extensions, - findOneOfSuperInterface(allSchemas, schemaInfo, options), + schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }), allSchemas, ) @@ -217,7 +221,7 @@ class JacksonModelGenerator( properties, schemaInfo.schema.discriminator, schemaInfo.schema.extensions, - findOneOfSuperInterface(allSchemas, schemaInfo, options), + schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }), allSchemas, ) @@ -227,7 +231,7 @@ class JacksonModelGenerator( properties, schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) }, schemaInfo.schema.extensions, - findOneOfSuperInterface(allSchemas, schemaInfo, options), + schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }), ) schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo) @@ -236,40 +240,12 @@ class JacksonModelGenerator( schemaName = schemaName, properties = properties, extensions = schemaInfo.schema.extensions, - oneOfInterfaces = findOneOfSuperInterface(allSchemas, schemaInfo, options), + oneOfInterfaces = schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }), ) } } - private fun findOneOfSuperInterface( - allSchemas: List, - schema: SchemaInfo, - options: Set, - ): Set { - if (SEALED_INTERFACES_FOR_ONE_OF !in options) { - return emptySet() - } - return allSchemas - .filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() } - .mapNotNull { info -> - info.schema.discriminator.mappings - .toList() - .find { (_, ref) -> - ref.endsWith("/${schema.name}") - } - ?.let { (key, _) -> - Pair(key!!, info) - } - } - .map { (_, parent) -> - val field = parent.schema.discriminator.propertyName!! - if (!schema.schema.properties.containsKey(field)) { - throw IllegalArgumentException("schema $schema did not have discriminator property") - } - parent - } - .toSet() - } + private fun buildInLinedModels( topLevelProperties: Collection, @@ -285,7 +261,7 @@ class JacksonModelGenerator( if (it.isInherited) { emptySet() // Rely on the parent definition } else { - val props = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema) + val props = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema) val currentModel = standardDataClass( ModelNameRegistry.getOrRegister(it.schema, enclosingSchema.toEnclosingSchemaInfo()), it.name, @@ -308,7 +284,7 @@ class JacksonModelGenerator( standardDataClass( modelName = ModelNameRegistry.getOrRegister(it.schema, valueSuffix = it.schema.isInlinedTypedAdditionalProperties()), schemaName = it.name, - properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema), + properties = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema), extensions = it.schema.extensions, oneOfInterfaces = emptySet(), ), @@ -347,7 +323,7 @@ class JacksonModelGenerator( ?: enclosingSchema.toEnclosingSchemaInfo() when { items.isInlinedObjectDefinition() -> - items.topLevelProperties(HTTP_SETTINGS, enclosingSchema).let { props -> + items.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema).let { props -> buildInLinedModels( topLevelProperties = props, enclosingSchema = enclosingSchema, @@ -454,7 +430,7 @@ class JacksonModelGenerator( standardDataClass( modelName = ModelNameRegistry.getOrRegister(schema, valueSuffix = schema.isInlinedTypedAdditionalProperties()), schemaName = schema.safeName(), - properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS), + properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3), extensions = mapField.schema.extensions, oneOfInterfaces = emptySet(), ) @@ -467,24 +443,10 @@ class JacksonModelGenerator( schemaName: String, properties: Collection, extensions: Map, - oneOfInterfaces: Set, + oneOfInterfaces: Set, ): TypeSpec { - val filteredProperties = if (oneOfInterfaces.size == 1) { - val oneOfInterface = oneOfInterfaces.first() - val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName - val mappingCount = - oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") } - if (discriminatorProp != null && mappingCount == 1) { - properties.filterNot { it.name == discriminatorProp } - } else { - properties - } - } else { - properties - } - val name = generatedType(packages.base, modelName) - val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty() + val generateObject = properties.isEmpty() val builder = if (generateObject) { TypeSpec.objectBuilder(name) @@ -499,15 +461,23 @@ class JacksonModelGenerator( .addCompanionObject() for (oneOfInterface in oneOfInterfaces) { classBuilder - .addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface.schema))) + .addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface))) } if (!generateObject) { - filteredProperties.addToClass( - schemaName = schemaName, - classBuilder = classBuilder, - classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension), - ) + if (oneOfInterfaces.size == 1) { + properties.addToClass( + schemaName = schemaName, + classBuilder = classBuilder, + classType = ClassSettings(ClassSettings.PolymorphyType.ONE_OF, extensions.hasJsonMergePatchExtension), + ) + } else { + properties.addToClass( + schemaName = schemaName, + classBuilder = classBuilder, + classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension), + ) + } } return classBuilder.build() } @@ -519,7 +489,7 @@ class JacksonModelGenerator( discriminator: Discriminator, superType: SchemaInfo, extensions: Map, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, allSchemas: List, ): TypeSpec = with(FunSpec.constructorBuilder()) { TypeSpec.classBuilder(generatedType(packages.base, modelName)) @@ -549,7 +519,7 @@ class JacksonModelGenerator( discriminator: Discriminator, allSchemas: List, members: List, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, ): TypeSpec { val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName)) .addModifiers(KModifier.SEALED) @@ -589,7 +559,7 @@ class JacksonModelGenerator( properties: Collection, discriminator: Discriminator, extensions: Map, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, allSchemas: List, ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName)) .buildPolymorphicSuperType( @@ -609,7 +579,7 @@ class JacksonModelGenerator( properties: Collection, discriminator: Discriminator, extensions: Map, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, allSchemas: List, constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(), ): TypeSpec.Builder { @@ -661,7 +631,7 @@ class JacksonModelGenerator( properties: Collection, superType: SchemaInfo, extensions: Map, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, ): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName)) .buildPolymorphicSubType(schemaName, properties, superType, extensions, oneOfSuperInterfaces).build() @@ -670,7 +640,7 @@ class JacksonModelGenerator( allProperties: Collection, superType: SchemaInfo, extensions: Map, - oneOfSuperInterfaces: Set, + oneOfSuperInterfaces: Set, constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(), ): TypeSpec.Builder { this.addSerializableInterface() diff --git a/src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt b/src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt index b59827f9..8315e5c7 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt @@ -15,6 +15,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeName import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeType import com.cjbooms.fabrikt.util.NormalisedString.camelCase import com.cjbooms.fabrikt.util.NormalisedString.toEnumName +import com.reprezen.kaizen.oasparser.model3.OpenApi3 import com.reprezen.kaizen.oasparser.model3.Schema sealed class PropertyInfo { @@ -38,7 +39,7 @@ sealed class PropertyInfo { val HTTP_SETTINGS = Settings() - fun Schema.topLevelProperties(settings: Settings, enclosingSchema: Schema? = null): Collection { + fun Schema.topLevelProperties(settings: Settings, api: OpenApi3, enclosingSchema: Schema? = null): Collection { val results = mutableListOf() + allOfSchemas.flatMap { it.topLevelProperties( @@ -47,12 +48,13 @@ sealed class PropertyInfo { enclosingSchema, it ), + api, this ) } + (if (oneOfSchemas.isEmpty()) emptyList() else listOf(OneOfAny(oneOfSchemas.first()))) + - anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), this) } + - getInLinedProperties(settings, enclosingSchema) + anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), api, this) } + + getInLinedProperties(settings, api, enclosingSchema) return results.distinctBy { it.oasKey } } @@ -68,13 +70,14 @@ sealed class PropertyInfo { private fun Schema.getInLinedProperties( settings: Settings, + api: OpenApi3, enclosingSchema: Schema? = null ): Collection { val mainProperties: List = properties.map { property -> when (property.value.safeType()) { OasType.Array.type -> ListField( - isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional), + isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional), property.key, property.value, settings.markAsInherited, @@ -87,6 +90,7 @@ sealed class PropertyInfo { if (property.value.isSimpleMapDefinition() || property.value.isSchemaLess()) MapField( isRequired = isRequired( + api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional @@ -99,7 +103,7 @@ sealed class PropertyInfo { else if (property.value.isInlinedObjectDefinition()) ObjectInlinedField( isRequired = isRequired( - property, settings.markReadWriteOnlyOptional, settings.markAllOptional + api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional ), oasKey = property.key, schema = property.value, @@ -109,7 +113,7 @@ sealed class PropertyInfo { ) else ObjectRefField( - isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional), + isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional), property.key, property.value, settings.markAsInherited, @@ -120,13 +124,13 @@ sealed class PropertyInfo { null } else { Field( - isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional), + isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional), oasKey = property.key, schema = property.value, isInherited = settings.markAsInherited, - isPolymorphicDiscriminator = isDiscriminatorProperty(property), + isPolymorphicDiscriminator = isDiscriminatorProperty(api, property), maybeDiscriminator = enclosingSchema?.let { - this.getKeyIfSingleDiscriminatorValue(property, it) + this.getKeyIfSingleDiscriminatorValue(api, property, it) }, enclosingSchema = if (property.value.isInlinedEnumDefinition()) this else null ) diff --git a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt index 26d8fb08..6cab6c81 100644 --- a/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt +++ b/src/main/kotlin/com/cjbooms/fabrikt/util/KaizenParserExtensions.kt @@ -138,6 +138,7 @@ object KaizenParserExtensions { api.schemas.values.firstOrNull { it.name == safeName() } fun Schema.isRequired( + api: OpenApi3, prop: Map.Entry, markReadWriteOnlyOptional: Boolean, markAllOptional: Boolean, @@ -145,19 +146,50 @@ object KaizenParserExtensions { if (markAllOptional || (prop.value.isReadOnly && markReadWriteOnlyOptional) || (prop.value.isWriteOnly && markReadWriteOnlyOptional)) { false } else { - requiredFields.contains(prop.key) || isDiscriminatorProperty(prop) // A discriminator property should be required + requiredFields.contains(prop.key) || isDiscriminatorProperty(api, prop) // A discriminator property should be required } fun Schema.getSchemaRefName() = Overlay.of(this).jsonReference.split("/").last() - fun Schema.isDiscriminatorProperty(prop: Map.Entry): Boolean = - discriminator?.propertyName == prop.key + fun Schema.isDiscriminatorProperty(api: OpenApi3, prop: Map.Entry): Boolean = + discriminator?.propertyName == prop.key || + findOneOfSuperInterface(api.schemas.values.toList()).any { oneOf -> + oneOf.discriminator?.mappings?.values?.any { it.endsWith("/$name") } ?: false + } + + fun Schema.findOneOfSuperInterface(allSchemas: List): Set { + if (ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF !in MutableSettings.modelOptions()) { + return emptySet() + } + return allSchemas + .filter { it.discriminator != null && it.oneOfSchemas.isNotEmpty() } + .mapNotNull { schema -> + schema.discriminator.mappings + .toList() + .find { (_, ref) -> + ref.endsWith("/${name}") + } + ?.let { (key, _) -> + Pair(key!!, schema) + } + } + .map { (_, parent) -> + val field = parent.discriminator.propertyName!! + if (!properties.containsKey(field)) { + throw IllegalArgumentException("schema $name did not have discriminator property") + } + parent + } + .toSet() + } fun Schema.getKeyIfSingleDiscriminatorValue( + api: OpenApi3, prop: Map.Entry, enclosingSchema: Schema, ): Map? = - if (isDiscriminatorProperty(prop) && discriminator.mappingKeys(enclosingSchema).isNotEmpty()) { + if (isDiscriminatorProperty(api, prop)) { + val discriminator = findDiscriminator(api) discriminator.mappingKeys(enclosingSchema).map { if (prop.value.isEnumDefinition()) { it.key to PropertyInfo.DiscriminatorKey.EnumKey(it.key, it.value) @@ -169,6 +201,18 @@ object KaizenParserExtensions { null } + private fun Schema.findDiscriminator(api: OpenApi3): Discriminator { + val bestDiscriminator = if (this.hasDiscriminator()) { + this.discriminator + } else { + val oneOfDiscriminator = findOneOfSuperInterface(api.schemas.values.toList()).firstOrNull { oneOfInterface -> + oneOfInterface.hasDiscriminator() + }?.discriminator + oneOfDiscriminator ?: this.discriminator + } + return bestDiscriminator + } + fun Discriminator.mappingKeys(enclosingSchema: Schema): Map { val discriminatorMappings = mappings?.map { it.key to it.value.split("/").last() }?.toMap() return if (discriminatorMappings.isNullOrEmpty()) { @@ -187,6 +231,7 @@ object KaizenParserExtensions { } fun Schema.hasNoDiscriminator(): Boolean = this.discriminator.propertyName == null + fun Schema.hasDiscriminator(): Boolean = !hasNoDiscriminator() fun Schema.safeName(): String = when { diff --git a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt index 2b26c8c3..941b8980 100644 --- a/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt +++ b/src/test/kotlin/com/cjbooms/fabrikt/generators/ModelGeneratorTest.kt @@ -76,6 +76,9 @@ class ModelGeneratorTest { if (testCaseName == "instantDateTime") { MutableSettings.addOption(CodeGenTypeOverride.DATETIME_AS_INSTANT) } + if (testCaseName == "discriminatedOneOf") { + MutableSettings.addOption(ModelCodeGenOptionType.SEALED_INTERFACES_FOR_ONE_OF) + } val basePackage = "examples.${testCaseName.replace("/", ".")}" val apiLocation = javaClass.getResource("/examples/$testCaseName/api.yaml")!! val sourceApi = SourceApi(apiLocation.readText(), baseDir = Paths.get(apiLocation.toURI())) diff --git a/src/test/resources/examples/discriminatedOneOf/models/Models.kt b/src/test/resources/examples/discriminatedOneOf/models/Models.kt index 7b73d7f4..3b2ffdc4 100644 --- a/src/test/resources/examples/discriminatedOneOf/models/Models.kt +++ b/src/test/resources/examples/discriminatedOneOf/models/Models.kt @@ -33,9 +33,19 @@ public data class SomeObj( ) public sealed interface State -public object StateA : State +public data class StateA( + @get:JsonProperty("status") + @get:NotNull + @param:JsonProperty("status") + public val status: Status = Status.A, +) : State -public object StateB : State +public data class StateB( + @get:JsonProperty("status") + @get:NotNull + @param:JsonProperty("status") + public val status: Status = Status.B, +) : State public enum class Status( @JsonValue