Change variant visitor interface
aihuaxu committed Feb 22, 2025
1 parent 0598dbb commit 4351e8a
Showing 13 changed files with 50 additions and 43 deletions.
4 changes: 4 additions & 0 deletions api/src/main/java/org/apache/iceberg/variants/Variant.java
@@ -22,6 +22,10 @@

/** A variant metadata and value pair. */
public interface Variant {
String METADATA = "metadata";
String VALUE = "value";
String TYPED_VALUE = "typed_value";

/** Returns the metadata for all values in the variant. */
VariantMetadata metadata();

@@ -129,7 +129,7 @@ public Schema map(Schema map, Schema value) {
}

@Override
public Schema variant(Schema variant, List<Schema> fields) {
public Schema variant(Schema variant, Schema metadata, Schema value) {
return variant;
}

@@ -23,6 +23,7 @@
import org.apache.avro.Schema;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.variants.Variant;

public abstract class AvroSchemaVisitor<T> {
public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
@@ -49,7 +50,12 @@ public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
if (schema.getLogicalType() instanceof VariantLogicalType) {
Preconditions.checkArgument(
AvroSchemaUtil.isVariantSchema(schema), "Invalid variant record: %s", schema);
return visitor.variant(schema, results);

boolean isMetadataFirst = names.get(0).equals(Variant.METADATA);
return visitor.variant(
schema,
isMetadataFirst ? results.get(0) : results.get(1),
isMetadataFirst ? results.get(1) : results.get(0));
} else {
return visitor.record(schema, names, results);
}
@@ -109,7 +115,7 @@ public T map(Schema map, T value) {
return null;
}

public T variant(Schema variant, List<T> fields) {
public T variant(Schema variant, T metadataResult, T valueResult) {
throw new UnsupportedOperationException("Unsupported type: variant");
}

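
The dispatcher now resolves the metadata and value results by field name, so variant() implementations receive them in a fixed order regardless of the record's physical field layout. A minimal sketch of a visitor written against the new callback (not part of this commit; the class name and method bodies are illustrative):

import org.apache.avro.Schema;
import org.apache.iceberg.avro.AvroSchemaVisitor;

// Hypothetical visitor: renders a variant schema as a short description.
public class VariantDescriber extends AvroSchemaVisitor<String> {
  @Override
  public String variant(Schema variant, String metadataResult, String valueResult) {
    // metadataResult always corresponds to the "metadata" field and
    // valueResult to the "value" field, even if the schema stores them
    // in the opposite order.
    return "variant<metadata=" + metadataResult + ", value=" + valueResult + ">";
  }

  @Override
  public String primitive(Schema primitive) {
    return primitive.getType().getName();
  }
}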
@@ -64,11 +64,6 @@ public ValueWriter<?> map(Schema map, ValueWriter<?> valueWriter) {
return ValueWriters.map(ValueWriters.strings(), valueWriter);
}

@Override
public ValueWriter<?> variant(Schema variant, List<ValueWriter<?>> fields) {
return createRecordWriter(fields);
}

@Override
public ValueWriter<?> primitive(Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
15 changes: 6 additions & 9 deletions core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
@@ -55,7 +55,6 @@ Schema rootSchema(Schema record) {
}

@Override
@SuppressWarnings("checkstyle:CyclomaticComplexity")
public Schema record(Schema record, List<String> names, List<Schema> fields) {
// Then this should access the record's fields by name
List<Schema.Field> filteredFields = Lists.newArrayListWithExpectedSize(fields.size());
@@ -93,8 +92,7 @@ public Schema record(Schema record, List<String> names, List<Schema> fields) {
hasChange = true; // Sub-fields may be different
filteredFields.add(copyField(field, fieldSchema, fieldId));
} else {
if (isRecord(field.schema())
&& field.schema().getLogicalType() != VariantLogicalType.get()) {
if (isRecord(field.schema())) {
hasChange = true; // Sub-fields are now empty
filteredFields.add(copyField(field, makeEmptyCopy(field.schema()), fieldId));
} else {
@@ -262,7 +260,7 @@ private Schema mapWithIds(Schema map, Integer keyId, Integer valueId) {
}

@Override
public Schema variant(Schema variant, List<Schema> fields) {
public Schema variant(Schema variant, Schema metadata, Schema value) {
return null;
}

@@ -284,12 +282,11 @@ private static Schema copyRecord(Schema record, List<Schema.Field> newFields) {
return copy;
}

/* Checks whether the schema is a record but not a variant type */
private boolean isRecord(Schema field) {
if (AvroSchemaUtil.isOptionSchema(field)) {
return AvroSchemaUtil.fromOption(field).getType().equals(Schema.Type.RECORD);
} else {
return field.getType().equals(Schema.Type.RECORD);
}
Schema schema = AvroSchemaUtil.isOptionSchema(field) ? AvroSchemaUtil.fromOption(field) : field;
return schema.getType().equals(Schema.Type.RECORD)
&& !(schema.getLogicalType() instanceof VariantLogicalType);
}

private static Schema makeEmptyCopy(Schema field) {
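
The rewritten helper collapses the option-unwrapping into one expression and, crucially, stops treating variants as records. A sketch of the shape in question, using Avro's SchemaBuilder (illustrative only; the variant logical type itself is attached elsewhere in this package):

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class IsRecordSketch {
  // Physically, a variant is serialized as an Avro record of two byte
  // fields. Without VariantLogicalType attached, isRecord() would treat
  // this as an ordinary record and prune into its sub-fields; with the
  // logical type attached, the new check returns false and the variant
  // is kept whole.
  static final Schema PLAIN_RECORD =
      SchemaBuilder.record("var")
          .fields()
          .requiredBytes("metadata")
          .requiredBytes("value")
          .endRecord();
}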
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/iceberg/avro/RemoveIds.java
@@ -61,7 +61,7 @@ public Schema array(Schema array, Schema element) {
}

@Override
public Schema variant(Schema variant, List<Schema> fields) {
public Schema variant(Schema variant, Schema metadata, Schema value) {
return variant;
}

@@ -173,7 +173,7 @@ public Type map(Schema map, Type valueType) {
}

@Override
public Type variant(Schema variant, List<Type> fieldTypes) {
public Type variant(Schema variant, Type metadataType, Type valueType) {
return Types.VariantType.get();
}

@@ -94,7 +94,7 @@ public Schema struct(Types.StructType struct, List<Schema> fieldSchemas) {
Integer fieldId = fieldIds.peek();
String recordName = namesFunction.apply(fieldId, struct);
if (recordName == null) {
recordName = fieldId != null ? "r" + fieldId : "table";
recordName = "r" + fieldId;
}

Schema recordSchema = lookupSchema(struct, recordName);
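
Since struct() no longer falls back to "table" when the field ID is null, callers converting a top-level struct supply the root record name explicitly; the TestAvroNameMapping update further below reflects this. A sketch of the resulting call shape (the helper name is hypothetical):

import org.apache.iceberg.Schema;
import org.apache.iceberg.avro.AvroSchemaUtil;

public class ConvertSketch {
  // The caller names the root record ("table") rather than relying on
  // an implicit fallback inside the schema visitor.
  static org.apache.avro.Schema toAvro(Schema icebergSchema) {
    return AvroSchemaUtil.convert(icebergSchema.asStruct(), "table");
  }
}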
@@ -37,6 +37,7 @@
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.DecimalUtil;
import org.apache.iceberg.util.UUIDUtil;

public class ValueWriters {
private ValueWriters() {}
@@ -344,8 +344,8 @@ public void testInferredMapping() throws IOException {

@Test
public void testVariantNameMapping() {
org.apache.iceberg.Schema icebergSchema =
new org.apache.iceberg.Schema(
Schema icebergSchema =
new Schema(
Types.NestedField.required(0, "id", Types.LongType.get()),
Types.NestedField.required(1, "var", Types.VariantType.get()));

@@ -356,7 +356,7 @@ public void testVariantNameMapping() {
NameMapping.of(
MappedField.of(0, ImmutableList.of("id")), MappedField.of(1, ImmutableList.of("var")));
org.apache.avro.Schema mappedSchema = AvroSchemaUtil.applyNameMapping(avroSchema, nameMapping);
assertThat(mappedSchema).isEqualTo(AvroSchemaUtil.convert(icebergSchema.asStruct()));
assertThat(mappedSchema).isEqualTo(AvroSchemaUtil.convert(icebergSchema.asStruct(), "table"));
}

@Test
@@ -44,8 +44,8 @@ public class TestPruneColumns {
Types.NestedField.required(6, "lat", Types.FloatType.get()),
Types.NestedField.optional(7, "long", Types.FloatType.get()))),
Types.NestedField.required(
8, "types", Types.ListType.ofRequired(9, Types.StringType.get())),
Types.NestedField.required(10, "payload", Types.VariantType.get()))
8, "tags", Types.ListType.ofRequired(9, Types.StringType.get())),
Types.NestedField.optional(10, "payload", Types.VariantType.get()))
.asStruct());

@Test
@@ -97,7 +97,7 @@ public void testSelectList(int selectedId) {
Schema expected =
new Schema(
Types.NestedField.required(
8, "types", Types.ListType.ofRequired(9, Types.StringType.get())));
8, "tags", Types.ListType.ofRequired(9, Types.StringType.get())));
org.apache.avro.Schema prunedSchema =
AvroSchemaUtil.pruneColumns(TEST_SCHEMA, Sets.newHashSet(selectedId));
assertThat(prunedSchema).isEqualTo(AvroSchemaUtil.convert(expected.asStruct()));
@@ -106,7 +106,7 @@
@Test
public void testSelectVariant() {
Schema expected =
new Schema(Types.NestedField.required(10, "payload", Types.VariantType.get()));
new Schema(Types.NestedField.optional(10, "payload", Types.VariantType.get()));
org.apache.avro.Schema prunedSchema =
AvroSchemaUtil.pruneColumns(TEST_SCHEMA, Sets.newHashSet(10));
assertThat(prunedSchema).isEqualTo(AvroSchemaUtil.convert(expected.asStruct()));
@@ -21,17 +21,14 @@
import java.util.List;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.variants.Variant;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation.ListLogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;

public abstract class ParquetVariantVisitor<R> {
static final String METADATA = "metadata";
static final String VALUE = "value";
static final String TYPED_VALUE = "typed_value";

/**
* Handles the root variant column group.
*
@@ -164,9 +161,11 @@ public void afterField(Type type) {}

public static <R> R visit(GroupType type, ParquetVariantVisitor<R> visitor) {
Preconditions.checkArgument(
ParquetSchemaUtil.hasField(type, METADATA), "Invalid variant, missing metadata: %s", type);
ParquetSchemaUtil.hasField(type, Variant.METADATA),
"Invalid variant, missing metadata: %s",
type);

Type metadataType = type.getType(METADATA);
Type metadataType = type.getType(Variant.METADATA);
Preconditions.checkArgument(
isBinary(metadataType), "Invalid variant metadata, expecting BINARY: %s", metadataType);

@@ -180,8 +179,8 @@ public static <R> R visit(GroupType type, ParquetVariantVisitor<R> visitor) {

private static <R> R visitValue(GroupType valueGroup, ParquetVariantVisitor<R> visitor) {
R valueResult;
if (ParquetSchemaUtil.hasField(valueGroup, VALUE)) {
Type valueType = valueGroup.getType(VALUE);
if (ParquetSchemaUtil.hasField(valueGroup, Variant.VALUE)) {
Type valueType = valueGroup.getType(Variant.VALUE);
Preconditions.checkArgument(
isBinary(valueType), "Invalid variant value, expecting BINARY: %s", valueType);

Expand All @@ -190,15 +189,15 @@ private static <R> R visitValue(GroupType valueGroup, ParquetVariantVisitor<R> v
() -> visitor.serialized(valueType.asPrimitiveType()), valueType, visitor);
} else {
Preconditions.checkArgument(
ParquetSchemaUtil.hasField(valueGroup, TYPED_VALUE),
ParquetSchemaUtil.hasField(valueGroup, Variant.TYPED_VALUE),
"Invalid variant, missing both value and typed_value: %s",
valueGroup);

valueResult = null;
}

if (ParquetSchemaUtil.hasField(valueGroup, TYPED_VALUE)) {
Type typedValueType = valueGroup.getType(TYPED_VALUE);
if (ParquetSchemaUtil.hasField(valueGroup, Variant.TYPED_VALUE)) {
Type typedValueType = valueGroup.getType(Variant.TYPED_VALUE);

if (typedValueType.isPrimitive()) {
R typedResult =
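
For orientation, a sketch of a variant group that satisfies these checks, built with Parquet's schema Types builder (the group name and repetition are illustrative, not taken from the commit):

import org.apache.iceberg.variants.Variant;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

public class VariantGroupSketch {
  // Minimal unshredded shape: a required BINARY metadata field plus an
  // optional BINARY value field. A shredded column would carry a
  // typed_value field alongside (or instead of) value.
  static final GroupType VARIANT_GROUP =
      Types.requiredGroup()
          .required(PrimitiveTypeName.BINARY).named(Variant.METADATA)
          .optional(PrimitiveTypeName.BINARY).named(Variant.VALUE)
          .named("var");
}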
@@ -27,6 +27,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Streams;
import org.apache.iceberg.variants.PhysicalType;
import org.apache.iceberg.variants.Variant;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation;
@@ -129,10 +130,12 @@ public VariantValueReader primitive(PrimitiveType primitive) {
public VariantValueReader value(
GroupType group, ParquetValueReader<?> valueReader, ParquetValueReader<?> typedReader) {
int valueDL =
valueReader != null ? schema.getMaxDefinitionLevel(path(VALUE)) - 1 : Integer.MAX_VALUE;
valueReader != null
? schema.getMaxDefinitionLevel(path(Variant.VALUE)) - 1
: Integer.MAX_VALUE;
int typedDL =
typedReader != null
? schema.getMaxDefinitionLevel(path(TYPED_VALUE)) - 1
? schema.getMaxDefinitionLevel(path(Variant.TYPED_VALUE)) - 1
: Integer.MAX_VALUE;
return ParquetVariantReaders.shredded(valueDL, valueReader, typedDL, typedReader);
}
@@ -143,11 +146,13 @@ public VariantValueReader object(
ParquetValueReader<?> valueReader,
List<ParquetValueReader<?>> fieldResults) {
int valueDL =
valueReader != null ? schema.getMaxDefinitionLevel(path(VALUE)) - 1 : Integer.MAX_VALUE;
int fieldsDL = schema.getMaxDefinitionLevel(path(TYPED_VALUE)) - 1;
valueReader != null
? schema.getMaxDefinitionLevel(path(Variant.VALUE)) - 1
: Integer.MAX_VALUE;
int fieldsDL = schema.getMaxDefinitionLevel(path(Variant.TYPED_VALUE)) - 1;

List<String> shreddedFieldNames =
group.getType(TYPED_VALUE).asGroupType().getFields().stream()
group.getType(Variant.TYPED_VALUE).asGroupType().getFields().stream()
.map(Type::getName)
.collect(Collectors.toList());
List<VariantValueReader> fieldReaders =
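
The getMaxDefinitionLevel(...) - 1 pattern subtracts one from the optional field's own maximum definition level, i.e. the level at which its enclosing group is fully defined. A hedged illustration with Parquet's MessageType (the schema shape and names are illustrative):

import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

public class DefinitionLevelSketch {
  public static void main(String[] args) {
    // Illustrative column: an optional variant group holding an
    // optional binary value field.
    MessageType schema =
        Types.buildMessage()
            .optionalGroup()
                .required(PrimitiveTypeName.BINARY).named("metadata")
                .optional(PrimitiveTypeName.BINARY).named("value")
                .named("var")
            .named("table");

    // Optional "var" and optional "value" each add one definition
    // level, so the max level of table.var.value is 2; subtracting 1
    // yields the level at which the enclosing group is defined.
    System.out.println(schema.getMaxDefinitionLevel("var", "value")); // 2
  }
}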
