Skip to content

Commit

Permalink
Add remaining visitor implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aihuaxu committed Feb 20, 2025
1 parent 7f39dae commit 0d732ae
Show file tree
Hide file tree
Showing 18 changed files with 93 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ public Schema map(Schema map, Schema value) {
return map;
}

/**
 * Visitor callback for a variant schema: hands back the incoming {@code variant} schema
 * unchanged. The already-visited {@code fields} results are accepted but not used.
 */
@Override
public Schema variant(Schema variant, List<Schema> fields) {
return variant;
}

@Override
public Schema primitive(Schema primitive) {
return primitive;
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/iceberg/avro/Avro.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private enum Codec {

static {
LogicalTypes.register(LogicalMap.NAME, schema -> LogicalMap.get());
LogicalTypes.register(Variant.NAME, schema -> Variant.get());
LogicalTypes.register(VariantLogicalType.NAME, schema -> VariantLogicalType.get());
DEFAULT_MODEL.addLogicalTypeConversion(new Conversions.DecimalConversion());
DEFAULT_MODEL.addLogicalTypeConversion(new UUIDConversion());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ public static <T, F> T visit(Schema schema, AvroCustomOrderSchemaVisitor<T, F> v

visitor.recordLevels.push(name);

if (schema.getLogicalType() instanceof Variant) {
return visitor.variant(schema);
}

List<Schema.Field> fields = schema.getFields();
List<String> names = Lists.newArrayListWithExpectedSize(fields.size());
List<Supplier<F>> results = Lists.newArrayListWithExpectedSize(fields.size());
Expand All @@ -51,7 +47,13 @@ public static <T, F> T visit(Schema schema, AvroCustomOrderSchemaVisitor<T, F> v

visitor.recordLevels.pop();

return visitor.record(schema, names, Iterables.transform(results, Supplier::get));
if (schema.getLogicalType() instanceof VariantLogicalType) {
Preconditions.checkArgument(
AvroSchemaUtil.isVariantSchema(schema), "Invalid variant record: %s", schema);
return visitor.variant(schema);
} else {
return visitor.record(schema, names, Iterables.transform(results, Supplier::get));
}

case UNION:
List<Schema> types = schema.getTypes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private AvroEncoderUtil() {}

static {
LogicalTypes.register(LogicalMap.NAME, schema -> LogicalMap.get());
LogicalTypes.register(Variant.NAME, schema -> Variant.get());
LogicalTypes.register(VariantLogicalType.NAME, schema -> VariantLogicalType.get());
}

private static final byte[] MAGIC_BYTES = new byte[] {(byte) 0xC2, (byte) 0x01};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ static boolean isVariantSchema(Schema schema) {
return schema.getType() == RECORD
&& schema.getFields().size() == 2
&& schema.getField("metadata") != null
&& schema.getField("value") != null;
&& schema.getField("metadata").schema().getType() == Schema.Type.BYTES
&& schema.getField("value") != null
&& schema.getField("value").schema().getType() == Schema.Type.BYTES;
}

static Schema createMap(int keyId, Schema keySchema, int valueId, Schema valueSchema) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ public abstract class AvroSchemaVisitor<T> {
public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
switch (schema.getType()) {
case RECORD:
if (schema.getLogicalType() instanceof Variant) {
return visitor.variant(schema);
}

// check to make sure this hasn't been visited before
String name = schema.getFullName();
Preconditions.checkState(
Expand All @@ -50,7 +46,13 @@ public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {

visitor.recordLevels.pop();

return visitor.record(schema, names, results);
if (schema.getLogicalType() instanceof VariantLogicalType) {
Preconditions.checkArgument(
AvroSchemaUtil.isVariantSchema(schema), "Invalid variant record: %s", schema);
return visitor.variant(schema, results);
} else {
return visitor.record(schema, names, results);
}

case UNION:
List<Schema> types = schema.getTypes();
Expand Down Expand Up @@ -107,7 +109,7 @@ public T map(Schema map, T value) {
return null;
}

public T variant(Schema variant) {
public T variant(Schema variant, List<T> fields) {
throw new UnsupportedOperationException("Unsupported type: variant");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public ValueWriter<?> map(Schema map, ValueWriter<?> valueWriter) {
return ValueWriters.map(ValueWriters.strings(), valueWriter);
}

/**
 * Builds the writer for a variant schema by delegating to {@code createRecordWriter} with the
 * writers already produced for the variant's fields. The {@code variant} schema itself is not
 * consulted here.
 *
 * <p>NOTE(review): {@code fields} is presumably the pair of writers for the variant's
 * {@code metadata} and {@code value} fields -- confirm against the visitor that calls this.
 */
@Override
public ValueWriter<?> variant(Schema variant, List<ValueWriter<?>> fields) {
return createRecordWriter(fields);
}

@Override
public ValueWriter<?> primitive(Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/HasIds.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public Boolean union(Schema union, Iterable<Boolean> options) {
return Iterables.any(options, Boolean.TRUE::equals);
}

/**
 * Visitor callback for a variant schema. Always answers {@code false}; the schema is not
 * inspected.
 *
 * <p>NOTE(review): presumably a variant record's own fields never carry field IDs, so there is
 * nothing to detect here -- confirm.
 */
@Override
public Boolean variant(Schema variant) {
return false;
}

@Override
public Boolean primitive(Schema primitive) {
return false;
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/MissingIds.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ public Boolean union(Schema union, Iterable<Boolean> options) {
return Iterables.any(options, Boolean.TRUE::equals);
}

/**
 * Visitor callback for a variant schema. Always answers {@code false} without looking at the
 * schema.
 *
 * <p>NOTE(review): presumably IDs are not expected on a variant record's own fields, so a
 * variant can never count as "missing" one -- confirm.
 */
@Override
public Boolean variant(Schema variant) {
return false;
}

@Override
public Boolean primitive(Schema primitive) {
// a primitive node cannot be missing an ID because Iceberg does not assign primitive node IDs in the first
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/RemoveIds.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ public Schema array(Schema array, Schema element) {
return result;
}

/**
 * Visitor callback for a variant schema: returns the same {@code variant} schema instance
 * that was passed in, ignoring the visited {@code fields} results.
 *
 * <p>NOTE(review): because the schema is returned as-is rather than rebuilt, any properties
 * already attached to it are kept -- confirm a variant record never carries ID properties.
 */
@Override
public Schema variant(Schema variant, List<Schema> fields) {
return variant;
}

@Override
public Schema primitive(Schema primitive) {
return Schema.create(primitive.getType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public Type map(Schema map, Type valueType) {
}

@Override
public Type variant(Schema variant) {
public Type variant(Schema variant, List<Type> fieldTypes) {
return Types.VariantType.get();
}

Expand Down
8 changes: 4 additions & 4 deletions core/src/main/java/org/apache/iceberg/avro/TypeToSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public Schema struct(Types.StructType struct, List<Schema> fieldSchemas) {
Integer fieldId = fieldIds.peek();
String recordName = namesFunction.apply(fieldId, struct);
if (recordName == null) {
recordName = "r" + fieldId;
recordName = fieldId != null ? "r" + fieldId : "table";
}

Schema recordSchema = lookupSchema(struct, recordName);
Expand Down Expand Up @@ -188,8 +188,8 @@ public Schema map(Types.MapType map, Schema keySchema, Schema valueSchema) {
}

@Override
public Schema variant() {
String recordName = "r" + fieldIds.peek();
public Schema variant(Types.VariantType variant) {
String recordName = fieldIds.peek() != null ? "r" + fieldIds.peek() : "variant";
Schema schema =
Schema.createRecord(
recordName,
Expand All @@ -199,7 +199,7 @@ public Schema variant() {
List.of(
new Schema.Field("metadata", BINARY_SCHEMA),
new Schema.Field("value", BINARY_SCHEMA)));
return Variant.get().addToSchema(schema);
return VariantLogicalType.get().addToSchema(schema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
import org.apache.avro.Schema;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;

public class Variant extends LogicalType {
public class VariantLogicalType extends LogicalType {
static final String NAME = "variant";
private static final Variant INSTANCE = new Variant();
private static final VariantLogicalType INSTANCE = new VariantLogicalType();

static Variant get() {
static VariantLogicalType get() {
return INSTANCE;
}

private Variant() {
private VariantLogicalType() {
super(NAME);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static Schema variant(String name) {
name,
new Schema.Field("metadata", Schema.create(Schema.Type.BYTES), null, null),
new Schema.Field("value", Schema.create(Schema.Type.BYTES), null, null));
return Variant.get().addToSchema(schema);
return VariantLogicalType.get().addToSchema(schema);
}

static Schema.Field addId(int id, Schema.Field field) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,23 @@ public void testInferredMapping() throws IOException {
assertThat(projected).isEqualTo(record);
}

/**
 * Checks that applying a name mapping to an ID-less Avro schema containing a variant field
 * produces the same Avro schema as converting the Iceberg schema directly.
 */
@Test
public void testVariantNameMapping() {
// Iceberg schema with a long column (id 0) and a variant column (id 1).
org.apache.iceberg.Schema icebergSchema =
new org.apache.iceberg.Schema(
Types.NestedField.required(0, "id", Types.LongType.get()),
Types.NestedField.required(1, "var", Types.VariantType.get()));

// Start from the Avro form produced by RemoveIds and verify it reports no IDs.
org.apache.avro.Schema avroSchema = RemoveIds.removeIds(icebergSchema);
assertThat(AvroSchemaUtil.hasIds(avroSchema)).as("Expect schema has no ids").isFalse();

// Map the column names "id" and "var" to field IDs 0 and 1.
NameMapping nameMapping =
NameMapping.of(
MappedField.of(0, ImmutableList.of("id")), MappedField.of(1, ImmutableList.of("var")));
// The mapped schema must equal the direct conversion of the Iceberg struct.
org.apache.avro.Schema mappedSchema = AvroSchemaUtil.applyNameMapping(avroSchema, nameMapping);
assertThat(mappedSchema).isEqualTo(AvroSchemaUtil.convert(icebergSchema.asStruct()));
}

@Test
@Override
public void testAvroArrayAsLogicalMap() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Collections;
import org.apache.avro.SchemaBuilder;
import org.apache.iceberg.Schema;
import org.apache.iceberg.types.Types;
import org.junit.jupiter.api.Test;

public class TestAvroSchemaProjection {
Expand Down Expand Up @@ -152,56 +153,20 @@ public void projectWithMapSchemaChanged() {
}

@Test
public void projectWithVariantSchemaChanged() {
final org.apache.avro.Schema currentAvroSchema =
SchemaBuilder.record("myrecord")
.fields()
.name("f11")
.type()
.nullable()
.intType()
.noDefault()
.endRecord();

final org.apache.avro.Schema variantSchema =
SchemaBuilder.record("v")
.fields()
.name("metadata")
.type()
.bytesType()
.noDefault()
.name("value")
.type()
.bytesType()
.noDefault()
.endRecord();
Variant.get().addToSchema(variantSchema);
public void projectWithVariantType() {
Schema icebergSchema =
new Schema(
Types.NestedField.required(0, "id", Types.LongType.get()),
Types.NestedField.required(1, "data", Types.VariantType.get()));

final org.apache.avro.Schema updatedAvroSchema =
SchemaBuilder.record("myrecord")
.fields()
.name("f11")
.type()
.nullable()
.intType()
.noDefault()
.name("f12")
.type(variantSchema)
.noDefault()
.endRecord();

final Schema currentIcebergSchema = AvroSchemaUtil.toIceberg(currentAvroSchema);

// Getting the node ID in updatedAvroSchema allocated by converting into iceberg schema and back
final org.apache.avro.Schema idAllocatedUpdatedAvroSchema =
AvroSchemaUtil.convert(AvroSchemaUtil.toIceberg(updatedAvroSchema).asStruct());

final org.apache.avro.Schema projectedAvroSchema =
org.apache.avro.Schema projectedSchema =
AvroSchemaUtil.buildAvroProjection(
idAllocatedUpdatedAvroSchema, currentIcebergSchema, Collections.emptyMap());

assertThat(AvroSchemaUtil.missingIds(projectedAvroSchema))
.as("Result of buildAvroProjection is missing some IDs")
.isFalse();
AvroSchemaUtil.convert(icebergSchema.asStruct()),
icebergSchema.select("data"),
Collections.emptyMap());
assertThat(projectedSchema.getField("id")).isNull();
org.apache.avro.Schema variantSchema = projectedSchema.getField("data").schema();
assertThat(variantSchema.getLogicalType()).isEqualTo(VariantLogicalType.get());
assertThat(AvroSchemaUtil.isVariantSchema(variantSchema)).isTrue();
}
}
7 changes: 5 additions & 2 deletions core/src/test/java/org/apache/iceberg/avro/TestHasIds.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

public class TestHasIds {
@Test
public void test() {
public void testRemoveIdsHasIds() {
Schema schema =
new Schema(
Types.NestedField.required(0, "id", Types.LongType.get()),
Expand All @@ -39,7 +39,10 @@ public void test() {
Types.StringType.get(),
Types.StructType.of(
Types.NestedField.required(1, "lat", Types.FloatType.get()),
Types.NestedField.optional(2, "long", Types.FloatType.get())))));
Types.NestedField.optional(2, "long", Types.FloatType.get())))),
Types.NestedField.required(
8, "types", Types.ListType.ofRequired(9, Types.StringType.get())),
Types.NestedField.required(10, "data", Types.VariantType.get()));

org.apache.avro.Schema avroSchema = RemoveIds.removeIds(schema);
assertThat(AvroSchemaUtil.hasIds(avroSchema)).as("Avro schema should not have IDs").isFalse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void testPrimitiveTypes() {
Schema.create(Schema.Type.BYTES),
LogicalTypes.decimal(9, 4)
.addToSchema(Schema.createFixed("decimal_9_4", null, null, 4)),
variant("rnull"));
variant("variant"));

for (int i = 0; i < primitives.size(); i += 1) {
Type type = primitives.get(i);
Expand Down Expand Up @@ -386,6 +386,7 @@ public void testVariantConversion() {

for (int id : Lists.newArrayList(1, 2)) {
org.apache.avro.Schema variantSchema = avroSchema.getField("variantCol" + id).schema();
assertThat(variantSchema.getName()).isEqualTo("r" + id);
assertThat(variantSchema.getType()).isEqualTo(org.apache.avro.Schema.Type.RECORD);
assertThat(variantSchema.getFields().size()).isEqualTo(2);
assertThat(variantSchema.getField("metadata").schema().getType())
Expand Down

0 comments on commit 0d732ae

Please sign in to comment.