From c17c6aa72a014ddb9a8f7b8a9191e04e4fa52c3b Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Sat, 16 Dec 2023 18:44:09 +0100 Subject: [PATCH] cleanups and fixes --- codec_default_internal_test.go | 40 ++++++- schema.go | 25 +++-- schema_compatibility.go | 23 ++-- schema_compatibility_test.go | 44 +++++--- schema_internal_test.go | 198 +++++++++++++++++++++++++++++++++ 5 files changed, 289 insertions(+), 41 deletions(-) diff --git a/codec_default_internal_test.go b/codec_default_internal_test.go index 59c82130..9adf1b1a 100644 --- a/codec_default_internal_test.go +++ b/codec_default_internal_test.go @@ -57,6 +57,45 @@ func TestDecoder_InvalidDefault(t *testing.T) { require.Error(t, err) } +func TestDecoder_DrainField(t *testing.T) { + defer ConfigTeardown() + + // write schema + // `{ + // // "type": "record", + // // "name": "test", + // // "fields" : [ + // // {"name": "a", "type": "string"} + // // ] + // // }` + + // {"a": "foo"} + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "float", "default": 10.45} + ] + }`) + + schema.(*RecordSchema).Fields()[0].action = FieldDrain + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault + + type TestRecord struct { + A string `avro:"a"` + B float32 `avro:"b"` + } + + var got TestRecord + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 10.45, A: ""}, got) +} + func TestDecoder_DefaultBool(t *testing.T) { defer ConfigTeardown() @@ -714,5 +753,4 @@ func TestDecoder_DefaultFixed(t *testing.T) { assert.Equal(t, big.NewRat(1734, 5), &got.B) assert.Equal(t, "foo", got.A) }) - } diff --git a/schema.go b/schema.go index e2658d43..50779d84 100644 --- a/schema.go +++ b/schema.go @@ -521,11 +521,11 @@ func (s *PrimitiveSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) // CacheFingerprint returns a special fingerprint of the schema for caching purposes. func (s *PrimitiveSchema) CacheFingerprint() [32]byte { - data := []any{s.Fingerprint()} - if s.actual != "" { - data = append(data, s.actual) + if s.actual == "" { + return s.Fingerprint() } - return s.cacheFingerprinter.fingerprint(data) + + return s.cacheFingerprinter.fingerprint([]any{s.Fingerprint(), s.actual}) } // RecordSchema is an Avro record type schema. @@ -654,12 +654,17 @@ func (s *RecordSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { // CacheFingerprint returns a special fingerprint of the schema for caching purposes. func (s *RecordSchema) CacheFingerprint() [32]byte { - data := []any{s.Fingerprint()} + data := make([]any, 0) for _, field := range s.fields { if field.Default() != nil { data = append(data, field.Default()) } } + if len(data) == 0 { + return s.Fingerprint() + } + + data = append(data, s.Fingerprint()) return s.cacheFingerprinter.fingerprint(data) } @@ -679,7 +684,7 @@ type Field struct { action Action // encodedDef mainly used when decoding data that lack the field for schema evolution purposes. // Its value remains empty unless the field's encodeDefault function is called. - encodedDef []byte + encodedDef atomic.Value } type noDef struct{} @@ -764,8 +769,8 @@ func (f *Field) Default() any { } func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) { - if f.encodedDef != nil { - return f.encodedDef, nil + if v := f.encodedDef.Load(); v != nil { + return v.([]byte), nil } if !f.hasDef { return nil, fmt.Errorf("avro: '%s' field must have a non-empty default value", f.name) @@ -777,9 +782,9 @@ func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) if err != nil { return nil, err } - f.encodedDef = b + f.encodedDef.Store(b) - return f.encodedDef, nil + return b, nil } // Doc returns the documentation of a field. diff --git a/schema_compatibility.go b/schema_compatibility.go index 0985d276..57221753 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -180,11 +180,6 @@ func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error if c.contains(reader.Aliases(), writer.FullName()) { return nil } - // for _, alias := range reader.Aliases() { - // if alias == writer.FullName() { - // return nil - // } - // } return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName()) } @@ -248,9 +243,6 @@ type getFieldOptions struct { func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) { opt := getFieldOptions{} for _, fn := range optFns { - if fn == nil { - continue - } fn(&opt) } for _, field := range a { @@ -277,13 +269,6 @@ func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*get // // It fails if the writer and reader schemas are not compatible. func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { - if reader.Type() == Ref { - reader = reader.(*RefSchema).Schema() - } - if writer.Type() == Ref { - writer = writer.(*RefSchema).Schema() - } - if err := c.compatible(reader, writer); err != nil { return nil, err } @@ -291,7 +276,15 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { return c.resolve(reader, writer) } +// resolve requires the reader's schema to be already compatible with the writer's. func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { + if reader.Type() == Ref { + reader = reader.(*RefSchema).Schema() + } + if writer.Type() == Ref { + writer = writer.(*RefSchema).Schema() + } + if writer.Type() != reader.Type() { if reader.Type() == Union { for _, schema := range reader.(*UnionSchema).Types() { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 94f39fe2..4bbb119f 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -671,35 +671,49 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "type": "record", "name": "parent", "namespace": "org.hamba.avro", - "fields": [{ - "name": "a", - "type": "int" - }, + "fields": [ { - "name": "b", + "name": "a", "type": { "type": "record", - "name": "test", + "name": "embed", + "namespace": "org.hamba.avro", "fields": [{ "name": "a", "type": "long" }] - }, - "default": {"a": 10} + } }, { - "name": "c", - "type": "test", + "name": "b", + "type": "embed", "default": {"a": 20} } ] }`, - writer: `{"type":"record", "name":"parent", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, - value: map[string]any{"a": 10}, + writer: `{ + "type": "record", + "name": "parent", + "namespace": "org.hamba.avro", + "fields": [ + { + "name": "a", + "type": { + "type": "record", + "name": "embed", + "namespace": "org.hamba.avro", + "fields": [{ + "name": "a", + "type": "long" + }] + } + } + ] + }`, + value: map[string]any{"a": map[string]any{"a": int64(10)}}, want: map[string]any{ - "a": 10, - "b": map[string]any{"a": int64(10)}, - "c": map[string]any{"a": int64(20)}, + "a": map[string]any{"a": int64(10)}, + "b": map[string]any{"a": int64(20)}, }, }, } diff --git a/schema_internal_test.go b/schema_internal_test.go index 7872eaf7..30708790 100644 --- a/schema_internal_test.go +++ b/schema_internal_test.go @@ -1,6 +1,7 @@ package avro import ( + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -383,3 +384,200 @@ func TestSchema_FingerprintUsingCaches(t *testing.T) { assert.Equal(t, want, value) assert.Equal(t, want, got) } + +func TestSchema_IsPromotable(t *testing.T) { + tests := []struct { + typ Type + wantOk bool + }{ + { + typ: Int, + wantOk: true, + }, + { + typ: Long, + wantOk: true, + }, + { + typ: Float, + wantOk: true, + }, + { + typ: String, + wantOk: true, + }, + { + typ: Bytes, + wantOk: true, + }, + { + typ: Double, + wantOk: false, + }, + { + typ: Boolean, + wantOk: false, + }, + { + typ: Null, + wantOk: false, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + ok := isPromotable(test.typ) + assert.Equal(t, test.wantOk, ok) + }) + } +} + +func TestSchema_IsNative(t *testing.T) { + tests := []struct { + typ Type + wantOk bool + }{ + { + typ: Null, + wantOk: true, + }, + { + typ: Boolean, + wantOk: true, + }, + { + typ: Int, + wantOk: true, + }, + { + typ: Long, + wantOk: true, + }, + + { + typ: Float, + wantOk: true, + }, + { + typ: Double, + wantOk: true, + }, + + { + typ: Bytes, + wantOk: true, + }, + { + typ: String, + wantOk: true, + }, + { + typ: Record, + wantOk: false, + }, + { + typ: Array, + wantOk: false, + }, + { + typ: Map, + wantOk: false, + }, + { + typ: Fixed, + wantOk: false, + }, + { + typ: Enum, + wantOk: false, + }, + { + typ: Union, + wantOk: false, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + ok := isNative(test.typ) + assert.Equal(t, test.wantOk, ok) + }) + } +} + +func TestSchema_FieldEncodeDefault(t *testing.T) { + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string", "default": "bar"}, + {"name": "b", "type": "boolean"} + ] + }`).(*RecordSchema) + + fooEncoder := func(a any) ([]byte, error) { + return []byte("foo"), nil + } + barEncoder := func(a any) ([]byte, error) { + return []byte("bar"), nil + } + + assert.Equal(t, nil, schema.fields[0].encodedDef.Load()) + + _, err := schema.fields[0].encodeDefault(nil) + assert.Error(t, err) + + _, err = schema.fields[1].encodeDefault(fooEncoder) + assert.Error(t, err) + + def, err := schema.fields[0].encodeDefault(fooEncoder) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), def) + + def, err = schema.fields[0].encodeDefault(barEncoder) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), def) +} + +func TestSchema_CacheFingerprint(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + cacheFingerprint := cacheFingerprinter{} + assert.Panics(t, func() { + cacheFingerprint.fingerprint([]any{func() {}}) + }) + }) + + t.Run("promoted", func(t *testing.T) { + schema := NewPrimitiveSchema(Long, nil) + assert.Equal(t, schema.Fingerprint(), schema.CacheFingerprint()) + + schema = NewPrimitiveSchema(Long, nil) + schema.actual = Int + assert.NotEqual(t, schema.Fingerprint(), schema.CacheFingerprint()) + }) + + t.Run("record", func(t *testing.T) { + schema1 := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "boolean"} + ] + }`).(*RecordSchema) + + schema2 := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string", "default": "bar"}, + {"name": "b", "type": "boolean", "default": false} + ] + }`).(*RecordSchema) + + assert.Equal(t, schema1.Fingerprint(), schema1.CacheFingerprint()) + assert.NotEqual(t, schema1.CacheFingerprint(), schema2.CacheFingerprint()) + }) +}