Skip to content

Commit

Permalink
Merge pull request #33 from zappolowski/bug/19-enum-default-value-han…
Browse files Browse the repository at this point in the history
…dling

Fix input type code generation for enum fields with default values
  • Loading branch information
paulbakker authored Feb 17, 2021
2 parents ef07ca1 + 5abae8a commit 03692e0
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ import com.netflix.graphql.dgs.codegen.shouldSkip
import com.squareup.javapoet.*
import graphql.language.*
import graphql.language.TypeName
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.LocalTime
import java.time.OffsetDateTime
import javax.lang.model.element.Modifier

class DataTypeGenerator(config: CodeGenConfig) : BaseDataTypeGenerator(config.packageNameTypes, config) {
Expand Down Expand Up @@ -58,26 +54,33 @@ class InputTypeGenerator(config: CodeGenConfig) : BaseDataTypeGenerator(config.p
val name = definition.name

val fieldDefinitions = definition.inputValueDefinitions.map {
var defaultValue: Any
if (it.defaultValue != null) {
defaultValue = when (it.defaultValue) {
is BooleanValue -> (it.defaultValue as BooleanValue).isValue
is IntValue -> (it.defaultValue as graphql.language.IntValue).value
is StringValue -> (it.defaultValue as graphql.language.StringValue).value
is FloatValue -> (it.defaultValue as graphql.language.FloatValue).value
else -> it.defaultValue
}
Field(it.name, typeUtils.findReturnType(it.type), defaultValue)
} else {
Field(it.name, typeUtils.findReturnType(it.type))
val defaultValue = it.defaultValue?.let { defVal ->
when (defVal) {
is BooleanValue -> CodeBlock.of("\$L", defVal.isValue)
is IntValue -> CodeBlock.of("\$L", defVal.value)
is StringValue -> CodeBlock.of("\$S", defVal.value)
is FloatValue -> CodeBlock.of("\$L", defVal.value)
is EnumValue -> CodeBlock.of("\$T.\$N", typeUtils.findReturnType(it.type), defVal.name)
is ArrayValue -> if(defVal.values.isEmpty()) CodeBlock.of("java.util.Collections.emptyList()") else CodeBlock.of("java.util.Arrays.asList(\$L)", defVal.values.map { v ->
when(v) {
is BooleanValue -> CodeBlock.of("\$L", v.isValue)
is IntValue -> CodeBlock.of("\$L", v.value)
is StringValue -> CodeBlock.of("\$S", v.value)
is FloatValue -> CodeBlock.of("\$L", v.value)
is EnumValue -> CodeBlock.of("\$L.\$N", ((it.type as ListType).type as TypeName).name, v.name)
else -> ""
}
}.joinToString())
else -> CodeBlock.of("\$L", defVal)
}
}

Field(it.name, typeUtils.findReturnType(it.type), defaultValue)
}.plus(extensions.flatMap { it.inputValueDefinitions }.map { Field(it.name, typeUtils.findReturnType(it.type)) })
return generate(name, emptyList(), fieldDefinitions, true)
}
}

internal data class Field(val name: String, val type: com.squareup.javapoet.TypeName, val initialValue: Any? = null)
internal data class Field(val name: String, val type: com.squareup.javapoet.TypeName, val initialValue: CodeBlock? = null)

abstract class BaseDataTypeGenerator(internal val packageName: String, config: CodeGenConfig) {
internal val typeUtils = TypeUtils(packageName, config)
Expand Down Expand Up @@ -290,20 +293,12 @@ abstract class BaseDataTypeGenerator(internal val packageName: String, config: C
}

private fun addFieldWithGetterAndSetter(returnType: com.squareup.javapoet.TypeName?, fieldDefinition: Field, javaType: TypeSpec.Builder) {
if (fieldDefinition.initialValue != null) {
var initializerBlock = if (fieldDefinition.type.toString().contains("String")) {
"\"${fieldDefinition.initialValue}\""
} else {
"${fieldDefinition.initialValue}"
}
val field = FieldSpec.builder(fieldDefinition.type, fieldDefinition.name).addModifiers(Modifier.PRIVATE)
.initializer(initializerBlock)
.build()
javaType.addField(field)
val field = if (fieldDefinition.initialValue != null) {
FieldSpec.builder(fieldDefinition.type, fieldDefinition.name).addModifiers(Modifier.PRIVATE).initializer(fieldDefinition.initialValue).build()
} else {
val field = FieldSpec.builder(returnType, ReservedKeywordSanitizer.sanitize(fieldDefinition.name)).addModifiers(Modifier.PRIVATE).build()
javaType.addField(field)
FieldSpec.builder(returnType, ReservedKeywordSanitizer.sanitize(fieldDefinition.name)).addModifiers(Modifier.PRIVATE).build()
}
javaType.addField(field)

val getterName = "get${fieldDefinition.name[0].toUpperCase()}${fieldDefinition.name.substring(1)}"
javaType.addMethod(MethodSpec.methodBuilder(getterName).addModifiers(Modifier.PUBLIC).returns(returnType).addStatement("return \$N", ReservedKeywordSanitizer.sanitize(fieldDefinition.name)).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package com.netflix.graphql.dgs.codegen.generators.java
import com.netflix.graphql.dgs.codegen.CodeGenConfig
import com.netflix.graphql.dgs.codegen.CodeGenResult
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import graphql.language.*


Expand Down Expand Up @@ -48,7 +49,7 @@ class EntitiesRepresentationTypeGenerator(val config: CodeGenConfig): BaseDataTy
}
var result = CodeGenResult()
// generate representations of entity types that have @key, including the __typename field, and the key fields
val typeName = Field("__typename", ClassName.get(String::class.java), definition.name)
val typeName = Field("__typename", ClassName.get(String::class.java), CodeBlock.of("\$S", definition.name))
val fieldDefinitions = definition.fieldDefinitions
.filter {
keyFields.containsKey(it.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import com.netflix.graphql.dgs.codegen.KotlinCodeGenResult
import com.netflix.graphql.dgs.codegen.filterSkipped
import com.netflix.graphql.dgs.codegen.shouldSkip
import com.squareup.kotlinpoet.*
import com.squareup.kotlinpoet.TypeName
import graphql.language.*


class KotlinDataTypeGenerator(private val config: CodeGenConfig, private val document: Document): AbstractKotlinDataTypeGenerator(config.packageNameTypes, config) {
fun generate(definition: ObjectTypeDefinition, extensions: List<ObjectTypeExtensionDefinition>): KotlinCodeGenResult {
if(definition.shouldSkip()) {
Expand Down Expand Up @@ -53,33 +53,42 @@ class KotlinInputTypeGenerator(private val config: CodeGenConfig, private val do
fun generate(definition: InputObjectTypeDefinition, extensions: List<InputObjectTypeExtensionDefinition>): KotlinCodeGenResult {

val fields = definition.inputValueDefinitions
.filter(ReservedKeywordFilter.filterInvalidNames)
.map {
val defaultValue: Any
if (it.defaultValue != null) {
defaultValue = when (it.defaultValue) {
is BooleanValue -> (it.defaultValue as BooleanValue).isValue
is IntValue -> (it.defaultValue as IntValue).value
is StringValue -> (it.defaultValue as StringValue).value
is FloatValue -> (it.defaultValue as FloatValue).value
else -> it.defaultValue
}

Field(it.name, typeUtils.findReturnType(it.type), typeUtils.isNullable(it.type), defaultValue)
} else {
Field(it.name, typeUtils.findReturnType(it.type), typeUtils.isNullable(it.type))
}
}.plus(extensions.flatMap { it.inputValueDefinitions }.map { Field(it.name, typeUtils.findReturnType(it.type), typeUtils.isNullable(it.type)) })
.filter(ReservedKeywordFilter.filterInvalidNames)
.map {
val type = typeUtils.findReturnType(it.type)
val defaultValue = it.defaultValue?.let { value -> generateCode(value, type) }
Field(it.name, type, typeUtils.isNullable(it.type), defaultValue)
}.plus(extensions.flatMap { it.inputValueDefinitions }.map { Field(it.name, typeUtils.findReturnType(it.type), typeUtils.isNullable(it.type)) })
val interfaces = emptyList<Type<*>>()
return generate(definition.name, fields, interfaces, true, document)
}

private fun generateCode(value: Value<Value<*>>, type: TypeName): CodeBlock =
when (value) {
is BooleanValue -> CodeBlock.of("%L", value.isValue)
is IntValue -> CodeBlock.of("%L", value.value)
is StringValue -> CodeBlock.of("%S", value.value)
is FloatValue -> CodeBlock.of("%L", value.value)
is EnumValue -> CodeBlock.of("%M", MemberName(type.className, value.name))
is ArrayValue ->
if (value.values.isEmpty()) CodeBlock.of("emptyList()")
else CodeBlock.of("listOf(%L)", value.values.joinToString { v -> generateCode(v, type).toString() })
else -> CodeBlock.of("%L", value)
}

private val TypeName.className: ClassName
get() = when (this) {
is ClassName -> this
is ParameterizedTypeName -> typeArguments[0].className
else -> TODO()
}

override fun getPackageName(): String {
return config.packageNameTypes
}
}

internal data class Field(val name: String, val type: com.squareup.kotlinpoet.TypeName, val nullable: Boolean, val default: Any? = null)
internal data class Field(val name: String, val type: com.squareup.kotlinpoet.TypeName, val nullable: Boolean, val default: CodeBlock? = null)

abstract class AbstractKotlinDataTypeGenerator(private val packageName: String, private val config: CodeGenConfig) {
protected val typeUtils = KotlinTypeUtils(packageName, config)
Expand All @@ -104,12 +113,7 @@ abstract class AbstractKotlinDataTypeGenerator(private val packageName: String,
.addAnnotation(jsonPropertyAnnotation(field.name))

if (field.default != null) {
val initializerBlock = if (field.type.toString().contains("String")) {
"\"${field.default}\""
} else {
"${field.default}"
}
parameterSpec.defaultValue(initializerBlock)
parameterSpec.defaultValue(field.default)
} else {
when (returnType) {
STRING -> if (field.nullable) parameterSpec.defaultValue("null")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.netflix.graphql.dgs.codegen.generators.kotlin

import com.netflix.graphql.dgs.codegen.CodeGenConfig
import com.netflix.graphql.dgs.codegen.KotlinCodeGenResult
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.LIST
import com.squareup.kotlinpoet.ParameterizedTypeName
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
Expand Down Expand Up @@ -47,7 +48,7 @@ class KotlinEntitiesRepresentationTypeGenerator(private val config: CodeGenConfi

var result = KotlinCodeGenResult()
// generate representations of entity types that have @key, including the __typename field, and the key fields
val typeName = Field("__typename", STRING, false, definition.name)
val typeName = Field("__typename", STRING, false, CodeBlock.of("%S", definition.name))
val fieldDefinitions= definition.fieldDefinitions
.filter {
keyFields.containsKey(it.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,151 @@ class CodeGenTest {
assertCompiles(dataTypes)
}

@Test
fun generateInputWithDefaultValueForEnum() {
val schema = """
enum Color {
red
}
input ColorFilter {
color: Color = red
}
""".trimIndent()

val (dataTypes, _, enumTypes) = CodeGen(CodeGenConfig(schemas = setOf(schema), packageName = basePackageName)).generate() as CodeGenResult
assertThat(dataTypes).hasSize(1)

val data = dataTypes[0]
assertThat(data.packageName).isEqualTo(typesPackageName)

val type = data.typeSpec
assertThat(type.name).isEqualTo("ColorFilter")

val fields = type.fieldSpecs
assertThat(fields).hasSize(1)

val colorField = fields[0]
assertThat(colorField.name).isEqualTo("color")
assertThat(colorField.type.toString()).isEqualTo("$typesPackageName.Color")
assertThat(colorField.initializer.toString()).isEqualTo("$typesPackageName.Color.red")

assertCompiles(enumTypes + dataTypes)
}

@Test
fun generateInputWithDefaultValueForArray() {
val schema = """
input SomeType {
names: [String] = []
}
""".trimIndent()

val (dataTypes) = CodeGen(CodeGenConfig(schemas = setOf(schema), packageName = basePackageName)).generate() as CodeGenResult
assertThat(dataTypes).hasSize(1)

val data = dataTypes[0]
assertThat(data.packageName).isEqualTo(typesPackageName)

val type = data.typeSpec
assertThat(type.name).isEqualTo("SomeType")

val fields = type.fieldSpecs
assertThat(fields).hasSize(1)

val colorField = fields[0]
assertThat(colorField.name).isEqualTo("names")
assertThat(colorField.initializer.toString()).isEqualTo("java.util.Collections.emptyList()")

assertCompiles(dataTypes)
}

@Test
fun generateInputWithDefaultStringValueForArray() {
val schema = """
input SomeType {
names: [String] = ["A", "B"]
}
""".trimIndent()

val (dataTypes) = CodeGen(CodeGenConfig(schemas = setOf(schema), packageName = basePackageName)).generate() as CodeGenResult
assertThat(dataTypes).hasSize(1)

val data = dataTypes[0]
assertThat(data.packageName).isEqualTo(typesPackageName)

val type = data.typeSpec
assertThat(type.name).isEqualTo("SomeType")

val fields = type.fieldSpecs
assertThat(fields).hasSize(1)

val colorField = fields[0]
assertThat(colorField.name).isEqualTo("names")
assertThat(colorField.initializer.toString()).isEqualTo("""java.util.Arrays.asList("A", "B")""")

assertCompiles(dataTypes)
}

@Test
fun generateInputWithDefaultIntValueForArray() {
val schema = """
input SomeType {
numbers: [Int] = [1, 2, 3]
}
""".trimIndent()

val (dataTypes) = CodeGen(CodeGenConfig(schemas = setOf(schema), packageName = basePackageName)).generate() as CodeGenResult
assertThat(dataTypes).hasSize(1)

val data = dataTypes[0]
assertThat(data.packageName).isEqualTo(typesPackageName)

val type = data.typeSpec
assertThat(type.name).isEqualTo("SomeType")

val fields = type.fieldSpecs
assertThat(fields).hasSize(1)

val colorField = fields[0]
assertThat(colorField.name).isEqualTo("numbers")
assertThat(colorField.initializer.toString()).isEqualTo("""java.util.Arrays.asList(1, 2, 3)""")

assertCompiles(dataTypes)
}

@Test
fun generateInputWithDefaultEnumValueForArray() {
val schema = """
input SomeType {
colors: [Color] = [red]
}
enum Color {
red,
blue
}
""".trimIndent()

val (dataTypes,_, enumTypes) = CodeGen(CodeGenConfig(schemas = setOf(schema), packageName = basePackageName, writeToFiles = true)).generate() as CodeGenResult
assertThat(dataTypes).hasSize(1)

val data = dataTypes[0]
assertThat(data.packageName).isEqualTo(typesPackageName)

val type = data.typeSpec
assertThat(type.name).isEqualTo("SomeType")

val fields = type.fieldSpecs
assertThat(fields).hasSize(1)

val colorField = fields[0]
assertThat(colorField.name).isEqualTo("colors")
assertThat(colorField.initializer.toString()).isEqualTo("""java.util.Arrays.asList(Color.red)""")

assertCompiles(dataTypes + enumTypes)
}

@Test
fun generateToInputStringMethodForInputTypes() {

Expand Down
Loading

0 comments on commit 03692e0

Please sign in to comment.