Skip to content

Commit

Permalink
Do not strip discriminator property in oneOf generation. (#268)
Browse files Browse the repository at this point in the history
* Do not strip discriminator property in oneOf generation.

* revert test and formatting changes

* correct test

* Change oneOf so that it defaults to the correct discriminator value
  • Loading branch information
cjbooms authored Mar 5, 2024
1 parent 1222719 commit 61af512
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data class ClassSettings(
NONE,
SUPER,
SUB,
ONE_OF,
}
}

Expand Down Expand Up @@ -102,11 +103,16 @@ object PropertyUtils {
property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey))
property.addValidationAnnotations(this, validationAnnotations)
}

ClassSettings.PolymorphyType.ONE_OF -> {
property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey))
property.addValidationAnnotations(this, validationAnnotations)
}
}

if (isDiscriminatorFieldWithSingleKnownValue(classSettings, schemaName)) {
this as PropertyInfo.Field
if (classSettings.polymorphyType == ClassSettings.PolymorphyType.SUB) {
if (classSettings.polymorphyType in listOf(ClassSettings.PolymorphyType.SUB, ClassSettings.PolymorphyType.ONE_OF)) {
property.initializer(name)
property.addAnnotation(JacksonMetadata.jacksonParameterAnnotation(oasKey))
val constructorParameter: ParameterSpec.Builder = ParameterSpec.builder(name, wrappedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.cjbooms.fabrikt.model.PropertyInfo.Companion.topLevelProperties
import com.cjbooms.fabrikt.model.SchemaInfo
import com.cjbooms.fabrikt.model.SourceApi
import com.cjbooms.fabrikt.model.toEnclosingSchemaInfo
import com.cjbooms.fabrikt.util.KaizenParserExtensions.findOneOfSuperInterface
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getDiscriminatorForInLinedObjectUnderAllOf
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSchemaRefName
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSuperType
Expand Down Expand Up @@ -164,8 +165,11 @@ class JacksonModelGenerator(
.filterNot { it.schema.isSimpleType() }
.filterNot { it.schema.isOneOfPolymorphicTypes() }
.flatMap {
val properties = it.schema.topLevelProperties(HTTP_SETTINGS, it.schema)
if (properties.isNotEmpty() || it.typeInfo is KotlinTypeInfo.Enum) {
val properties = it.schema.topLevelProperties(HTTP_SETTINGS, api, it.schema)
if (properties.isNotEmpty() ||
it.typeInfo is KotlinTypeInfo.Enum ||
it.schema.findOneOfSuperInterface(schemas.map { it.schema }).isNotEmpty()
) {
val primaryModel = buildPrimaryModel(api, it, properties, schemas)
val inlinedModels = buildInLinedModels(properties, it.schema, it.schema.getDocumentUrl())
listOf(primaryModel) + inlinedModels
Expand Down Expand Up @@ -196,7 +200,7 @@ class JacksonModelGenerator(
schemaInfo.schema.discriminator,
allSchemas,
schemaInfo.schema.oneOfSchemas,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)

schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) ->
Expand All @@ -207,7 +211,7 @@ class JacksonModelGenerator(
checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()),
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
allSchemas,
)

Expand All @@ -217,7 +221,7 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.discriminator,
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
allSchemas,
)

Expand All @@ -227,7 +231,7 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)

schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo)
Expand All @@ -236,40 +240,12 @@ class JacksonModelGenerator(
schemaName = schemaName,
properties = properties,
extensions = schemaInfo.schema.extensions,
oneOfInterfaces = findOneOfSuperInterface(allSchemas, schemaInfo, options),
oneOfInterfaces = schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)
}
}

private fun findOneOfSuperInterface(
allSchemas: List<SchemaInfo>,
schema: SchemaInfo,
options: Set<ModelCodeGenOptionType>,
): Set<SchemaInfo> {
if (SEALED_INTERFACES_FOR_ONE_OF !in options) {
return emptySet()
}
return allSchemas
.filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() }
.mapNotNull { info ->
info.schema.discriminator.mappings
.toList()
.find { (_, ref) ->
ref.endsWith("/${schema.name}")
}
?.let { (key, _) ->
Pair(key!!, info)
}
}
.map { (_, parent) ->
val field = parent.schema.discriminator.propertyName!!
if (!schema.schema.properties.containsKey(field)) {
throw IllegalArgumentException("schema $schema did not have discriminator property")
}
parent
}
.toSet()
}


private fun buildInLinedModels(
topLevelProperties: Collection<PropertyInfo>,
Expand All @@ -285,7 +261,7 @@ class JacksonModelGenerator(
if (it.isInherited) {
emptySet() // Rely on the parent definition
} else {
val props = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema)
val props = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema)
val currentModel = standardDataClass(
ModelNameRegistry.getOrRegister(it.schema, enclosingSchema.toEnclosingSchemaInfo()),
it.name,
Expand All @@ -308,7 +284,7 @@ class JacksonModelGenerator(
standardDataClass(
modelName = ModelNameRegistry.getOrRegister(it.schema, valueSuffix = it.schema.isInlinedTypedAdditionalProperties()),
schemaName = it.name,
properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema),
properties = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema),
extensions = it.schema.extensions,
oneOfInterfaces = emptySet(),
),
Expand Down Expand Up @@ -347,7 +323,7 @@ class JacksonModelGenerator(
?: enclosingSchema.toEnclosingSchemaInfo()
when {
items.isInlinedObjectDefinition() ->
items.topLevelProperties(HTTP_SETTINGS, enclosingSchema).let { props ->
items.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema).let { props ->
buildInLinedModels(
topLevelProperties = props,
enclosingSchema = enclosingSchema,
Expand Down Expand Up @@ -454,7 +430,7 @@ class JacksonModelGenerator(
standardDataClass(
modelName = ModelNameRegistry.getOrRegister(schema, valueSuffix = schema.isInlinedTypedAdditionalProperties()),
schemaName = schema.safeName(),
properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS),
properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3),
extensions = mapField.schema.extensions,
oneOfInterfaces = emptySet(),
)
Expand All @@ -467,24 +443,10 @@ class JacksonModelGenerator(
schemaName: String,
properties: Collection<PropertyInfo>,
extensions: Map<String, Any>,
oneOfInterfaces: Set<SchemaInfo>,
oneOfInterfaces: Set<Schema>,
): TypeSpec {
val filteredProperties = if (oneOfInterfaces.size == 1) {
val oneOfInterface = oneOfInterfaces.first()
val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName
val mappingCount =
oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") }
if (discriminatorProp != null && mappingCount == 1) {
properties.filterNot { it.name == discriminatorProp }
} else {
properties
}
} else {
properties
}

val name = generatedType(packages.base, modelName)
val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty()
val generateObject = properties.isEmpty()
val builder =
if (generateObject) {
TypeSpec.objectBuilder(name)
Expand All @@ -499,15 +461,23 @@ class JacksonModelGenerator(
.addCompanionObject()
for (oneOfInterface in oneOfInterfaces) {
classBuilder
.addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface.schema)))
.addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface)))
}

if (!generateObject) {
filteredProperties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
)
if (oneOfInterfaces.size == 1) {
properties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.ONE_OF, extensions.hasJsonMergePatchExtension),
)
} else {
properties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
)
}
}
return classBuilder.build()
}
Expand All @@ -519,7 +489,7 @@ class JacksonModelGenerator(
discriminator: Discriminator,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
): TypeSpec = with(FunSpec.constructorBuilder()) {
TypeSpec.classBuilder(generatedType(packages.base, modelName))
Expand Down Expand Up @@ -549,7 +519,7 @@ class JacksonModelGenerator(
discriminator: Discriminator,
allSchemas: List<SchemaInfo>,
members: List<Schema>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
): TypeSpec {
val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName))
.addModifiers(KModifier.SEALED)
Expand Down Expand Up @@ -589,7 +559,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSuperType(
Expand All @@ -609,7 +579,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
Expand Down Expand Up @@ -661,7 +631,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSubType(schemaName, properties, superType, extensions, oneOfSuperInterfaces).build()

Expand All @@ -670,7 +640,7 @@ class JacksonModelGenerator(
allProperties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
this.addSerializableInterface()
Expand Down
22 changes: 13 additions & 9 deletions src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeName
import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeType
import com.cjbooms.fabrikt.util.NormalisedString.camelCase
import com.cjbooms.fabrikt.util.NormalisedString.toEnumName
import com.reprezen.kaizen.oasparser.model3.OpenApi3
import com.reprezen.kaizen.oasparser.model3.Schema

sealed class PropertyInfo {
Expand All @@ -38,7 +39,7 @@ sealed class PropertyInfo {

val HTTP_SETTINGS = Settings()

fun Schema.topLevelProperties(settings: Settings, enclosingSchema: Schema? = null): Collection<PropertyInfo> {
fun Schema.topLevelProperties(settings: Settings, api: OpenApi3, enclosingSchema: Schema? = null): Collection<PropertyInfo> {
val results = mutableListOf<PropertyInfo>() +
allOfSchemas.flatMap {
it.topLevelProperties(
Expand All @@ -47,12 +48,13 @@ sealed class PropertyInfo {
enclosingSchema,
it
),
api,
this
)
} +
(if (oneOfSchemas.isEmpty()) emptyList() else listOf(OneOfAny(oneOfSchemas.first()))) +
anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), this) } +
getInLinedProperties(settings, enclosingSchema)
anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), api, this) } +
getInLinedProperties(settings, api, enclosingSchema)
return results.distinctBy { it.oasKey }
}

Expand All @@ -68,13 +70,14 @@ sealed class PropertyInfo {

private fun Schema.getInLinedProperties(
settings: Settings,
api: OpenApi3,
enclosingSchema: Schema? = null
): Collection<PropertyInfo> {
val mainProperties: List<PropertyInfo> = properties.map { property ->
when (property.value.safeType()) {
OasType.Array.type ->
ListField(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
property.key,
property.value,
settings.markAsInherited,
Expand All @@ -87,6 +90,7 @@ sealed class PropertyInfo {
if (property.value.isSimpleMapDefinition() || property.value.isSchemaLess())
MapField(
isRequired = isRequired(
api,
property,
settings.markReadWriteOnlyOptional,
settings.markAllOptional
Expand All @@ -99,7 +103,7 @@ sealed class PropertyInfo {
else if (property.value.isInlinedObjectDefinition())
ObjectInlinedField(
isRequired = isRequired(
property, settings.markReadWriteOnlyOptional, settings.markAllOptional
api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional
),
oasKey = property.key,
schema = property.value,
Expand All @@ -109,7 +113,7 @@ sealed class PropertyInfo {
)
else
ObjectRefField(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
property.key,
property.value,
settings.markAsInherited,
Expand All @@ -120,13 +124,13 @@ sealed class PropertyInfo {
null
} else {
Field(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
oasKey = property.key,
schema = property.value,
isInherited = settings.markAsInherited,
isPolymorphicDiscriminator = isDiscriminatorProperty(property),
isPolymorphicDiscriminator = isDiscriminatorProperty(api, property),
maybeDiscriminator = enclosingSchema?.let {
this.getKeyIfSingleDiscriminatorValue(property, it)
this.getKeyIfSingleDiscriminatorValue(api, property, it)
},
enclosingSchema = if (property.value.isInlinedEnumDefinition()) this else null
)
Expand Down
Loading

0 comments on commit 61af512

Please sign in to comment.