diff --git a/deserialize/deserialize.go b/deserialize/deserialize.go index 47467f0..87e0d1b 100644 --- a/deserialize/deserialize.go +++ b/deserialize/deserialize.go @@ -488,17 +488,19 @@ func makeStructDeserializerFromReflect(path string, typ reflect.Type, options st return nil, fmt.Errorf("struct %s contains a field \"%s\" that has both a `default` and a `orMethod` declaration. Please specify only one", path, fieldNativeName) } + willPreinitialize := initializationData.willPreinitialize || wasPreInitialized || tags.IsPreinitialized() + // By Go convention, a field with lower-case name or with a publicFieldName of "-" is private and // should not be parsed. isPublic := (*publicFieldName != "-") && fieldNativeExported - if !isPublic && !initializationData.willPreinitialize { - return nil, fmt.Errorf("struct %s contains a field \"%s\" that is not public, you should either make it public or specify an initializer with `Initializer` or `UnmarshalJSON`", path, fieldNativeName) + if !isPublic && !willPreinitialize { + return nil, fmt.Errorf("struct %s contains a field \"%s\" that is not public and not pre-initialized, you should either make it public or specify an initializer with `Initializer` or `UnmarshalJSON`", path, fieldNativeName) } fieldPath := fmt.Sprint(path, ".", *publicFieldName) var fieldContentDeserializer reflectDeserializer - fieldContentDeserializer, err = makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, initializationData.willPreinitialize) + fieldContentDeserializer, err = makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize) if err != nil { return nil, err } @@ -785,7 +787,7 @@ func makeSliceDeserializer(fieldPath string, fieldType reflect.Type, options sta subContainer := reflect.New(fieldType).Elem() // Prepare a deserializer for elements in this slice. - childPreinitialized := tags.IsPreinitialized() + childPreinitialized := wasPreinitialized || tags.IsPreinitialized() elementDeserializer, err := makeFieldDeserializerFromReflect(arrayPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized) if err != nil { return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err) @@ -801,7 +803,7 @@ func makeSliceDeserializer(fieldPath string, fieldType reflect.Type, options sta // Simply deserialize. var ok bool if input, ok = inValue.AsSlice(); !ok { - return fmt.Errorf("error while deserializing %s[]: expected an array", fieldType) + return fmt.Errorf("error while deserializing %s: expected an array", fieldType) } case isEmptyDefault: // Nothing to deserialize, but we are allowed to default to an empty array. @@ -868,7 +870,7 @@ func makePointerDeserializer(fieldPath string, fieldType reflect.Type, options s elemType := fieldType.Elem() subTags := tagsPkg.Empty() subContainer := reflect.New(fieldType).Elem() - childPreinitialized := tags.IsPreinitialized() + childPreinitialized := wasPreinitialized || tags.IsPreinitialized() elementDeserializer, err := makeFieldDeserializerFromReflect(ptrPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized) if err != nil { return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err) @@ -946,13 +948,14 @@ func makeFlatFieldDeserializer(fieldPath string, fieldType reflect.Type, options var unmarshaler *func(any) (any, error) if options.unmarshaler.ShouldUnmarshal(fieldType) { u := func(source any) (any, error) { - result := reflect.New(fieldType).Interface() - err := options.unmarshaler.Unmarshal(source, &result) + ptrResult := reflect.New(fieldType) + anyResult := ptrResult.Interface() + err := options.unmarshaler.Unmarshal(source, &anyResult) if err != nil { err = fmt.Errorf("invalid data at, expected to be able to parse a %s:\n\t * %w", typeName, err) return nil, err } - return result, nil + return ptrResult.Elem().Interface(), nil } unmarshaler = &u } @@ -992,7 +995,12 @@ func makeFlatFieldDeserializer(fieldPath string, fieldType reflect.Type, options // We have all the data we need, proceed. input = inValue.Interface() case wasPreinitialized: - input = outPtr.Interface() + if outPtr.CanInterface() { + input = outPtr.Interface() + } else { + // This is a private field that was already initialized, nothing to do here. + return nil + } case defaultValue != nil: input = defaultValue case orMethod != nil: @@ -1060,41 +1068,70 @@ func makeFlatFieldDeserializer(fieldPath string, fieldType reflect.Type, options // - `tagName` the name of tags to use for field renamings, e.g. `query`; // - `tags` the table of tags for this field. func makeFieldDeserializerFromReflect(fieldPath string, fieldType reflect.Type, options staticOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool) (reflectDeserializer, error) { - var result reflectDeserializer + var structured reflectDeserializer var err error switch fieldType.Kind() { case reflect.Pointer: - result, err = makePointerDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) + structured, err = makePointerDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) case reflect.Array: fallthrough case reflect.Slice: if options.allowNested { - result, err = makeSliceDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) + structured, err = makeSliceDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) } else { return nil, fmt.Errorf("this type of extractor does not support arrays/slices") } case reflect.Struct: if options.allowNested { - result, err = makeStructDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized) + structured, err = makeStructDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized) } else { return nil, fmt.Errorf("this type of extractor does not support nested structs") } case reflect.Map: if options.allowNested { - result, err = makeMapDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized) + structured, err = makeMapDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized) } else { return nil, fmt.Errorf("this type of extractor does not support nested maps") } default: - // If it's not a struct, an array, a slice or a pointer, well, it's probably something flat. - // - // We'll let `makeFlatFieldDeserializer` detect whether we can generate a deserializer for it. - result, err = makeFlatFieldDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) + // We'll have to try with a flat field deserializer (see below). } if err != nil { return nil, fmt.Errorf("could not generate a deserializer for %s with type %s:\n\t * %w", fieldPath, typeName(fieldType), err) } - return result, nil + + // Case 1: We already have a deserializer, but for some reason, we could end up with, say, a string + // instead of the data structure we hope for (that's what happens with `uuid.UUID`). + // + // Case 2: We don't have a deserializer yet, because the data is flat (string, int, etc.) + // + // In either case, prepare a flat deserializer. + flat, err2 := makeFlatFieldDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized) + if structured == nil { + if err2 == nil { + // Alright, we have a flat field deserializer and that's the only way we can deserialize this structure. + return flat, nil + } + // Neither structured deserializer nor flat field deserializer, we can't deserialize at all. + return nil, fmt.Errorf("could not generate a deserializer for %s with type %s:\n\t * %w", fieldPath, typeName(fieldType), err2) + } + if err2 != nil { + // We have a structured deserializer and that's the only way we can deserialize this structure. + return structured, nil + } + // We have both a flat and a structured deserializer. Need to try both! + var combined reflectDeserializer = func(slot *reflect.Value, data shared.Value) error { + err := structured(slot, data) + if err == nil { + return nil + } + err2 := flat(slot, data) + if err2 == nil { + return nil + } + return err + } + return combined, nil } // Return a (mostly) human-readable type name for a Go type. diff --git a/deserialize/deserialize_test.go b/deserialize/deserialize_test.go index b87ec06..506df51 100644 --- a/deserialize/deserialize_test.go +++ b/deserialize/deserialize_test.go @@ -2,12 +2,16 @@ package deserialize_test import ( + "encoding" "encoding/json" "errors" "fmt" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/pasqal-io/godasse/deserialize" jsonPkg "github.com/pasqal-io/godasse/deserialize/json" "github.com/pasqal-io/godasse/deserialize/kvlist" @@ -1199,3 +1203,65 @@ func TestSupportBothUnmarshalerAndDictInitializer(t *testing.T) { } // ----- + +// ----- Test that we can deserialize a struct with a field that should not be deserializable if we have some kind of pre-initializer. + +type StructThatCannotBeDeserialized struct { + private bool +} + +type StructThatCannotBeDeserialized2 struct { + private bool +} + +func (*StructThatCannotBeDeserialized2) Initialize() error { + return nil +} + +type StructWithTime struct { + Field StructThatCannotBeDeserialized `initialized:"true"` + Field2 StructThatCannotBeDeserialized2 + Field3 time.Time +} + +func TestDeserializingWithPreinitializer(t *testing.T) { + date := time.Date(2000, 01, 01, 01, 01, 01, 01, time.UTC) + sample := StructWithTime{ + Field: StructThatCannotBeDeserialized{private: false}, + Field2: StructThatCannotBeDeserialized2{private: false}, + Field3: date, + } + result, err := twoWays[StructWithTime](t, sample) + assert.NilError(t, err) + assert.DeepEqual(t, result, &sample, cmpopts.IgnoreUnexported(StructThatCannotBeDeserialized{}, StructThatCannotBeDeserialized2{})) +} + +// ------ + +// ------ Test that we can deserialize uuid + +type TextUnmarshalerUUID uuid.UUID + +func (t *TextUnmarshalerUUID) UnmarshalText(source []byte) error { + result, err := uuid.Parse(string(source)) + if err != nil { + return err //nolint:wrapcheck + } + *t = TextUnmarshalerUUID(result) + return nil +} + +var _ encoding.TextUnmarshaler = &TextUnmarshalerUUID{} + +type StructWithUUID struct { + Field TextUnmarshalerUUID +} + +func TestDeserializeUUID(t *testing.T) { + sample := StructWithUUID{ + Field: TextUnmarshalerUUID(uuid.New()), + } + result, err := twoWays[StructWithUUID](t, sample) + assert.NilError(t, err) + assert.DeepEqual(t, result, &sample) +} diff --git a/deserialize/json/json.go b/deserialize/json/json.go index ef45561..aea2797 100644 --- a/deserialize/json/json.go +++ b/deserialize/json/json.go @@ -2,6 +2,7 @@ package json import ( + "encoding" "encoding/json" "errors" "fmt" @@ -80,6 +81,7 @@ var dictionary = reflect.TypeOf(make(JSON, 0)) // The interface for `json.Unmarshaler`. var unmarshaler = reflect.TypeOf(new(json.Unmarshaler)).Elem() +var textUnmarshaler = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() // Determine whether we should call the driver to unmarshal values // of this type from []byte. @@ -93,10 +95,8 @@ func (u Driver) ShouldUnmarshal(typ reflect.Type) bool { if typ.ConvertibleTo(dictionary) { return true } - if reflect.PointerTo(typ).ConvertibleTo(unmarshaler) { - return true - } - return false + ptr := reflect.PointerTo(typ) + return ptr.ConvertibleTo(unmarshaler) || ptr.ConvertibleTo(textUnmarshaler) } // Perform unmarshaling. @@ -141,13 +141,26 @@ func (u Driver) Unmarshal(in any, out *any) (err error) { // Attempt to deserialize as a `json.Unmarshaler`. if unmarshal, ok := (*out).(json.Unmarshaler); ok { - return unmarshal.UnmarshalJSON(buf) //nolint:wrapcheck + err = unmarshal.UnmarshalJSON(buf) + } else { + err = json.Unmarshal(buf, out) + } + if err == nil { + // Basic JSON decoding worked, let's go with it. + return nil } - err = json.Unmarshal(buf, out) - if err != nil { - return fmt.Errorf("failed to unmarshal '%s': \n\t * %w", buf, err) + // But sometimes, things aren't that nice. For instance, time.Time serializes + // itself as an unencoded string, but its UnmarshalJSON expects an encoded string. + // Just in case, let's try again with UnmarshalText. + if textUnmarshaler, ok := (*out).(encoding.TextUnmarshaler); ok { + err2 := textUnmarshaler.UnmarshalText(buf) + if err2 == nil { + // Success! Let's use that result. + return nil + } + return fmt.Errorf("failed to unmarshal '%s' either from JSON or from text: \n\t * %w\n\t * and %w", buf, err, err2) } - return nil + return fmt.Errorf("failed to unmarshal '%s': \n\t * %w", buf, err) } func (u Driver) WrapValue(wrapped any) shared.Value { diff --git a/go.mod b/go.mod index ac9da4b..5fd10ee 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/pasqal-io/godasse go 1.21 -require gotest.tools/v3 v3.5.1 - -require github.com/google/go-cmp v0.5.9 // indirect +require ( + github.com/google/go-cmp v0.5.9 + github.com/google/uuid v1.6.0 + gotest.tools/v3 v3.5.1 +) diff --git a/go.sum b/go.sum index 7dd4ab5..e5ea833 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/validation/validation.go b/validation/validation.go index ba2009f..a394d5c 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -62,13 +62,15 @@ func (v Error) Error() string { for cursor != nil { switch cursor.kind { case kindField: - buf = append(buf, fmt.Sprint(cursor.entry)) + buf = append(buf, fmt.Sprint(".", cursor.entry)) case kindIndex: buf = append(buf, fmt.Sprintf("[%d]", cursor.entry)) case kindKey: buf = append(buf, fmt.Sprintf("[>> %v <<]", cursor.entry)) case kindValue: buf = append(buf, fmt.Sprintf("[%v]", cursor.entry)) + case kindRoot: + buf = append(buf, fmt.Sprint(cursor.entry)) case kindInterface: // Keep buf unchanged. case kindDereference: @@ -80,9 +82,6 @@ func (v Error) Error() string { for i := len(buf) - 1; i >= 0; i-- { serialized += buf[i] } - if serialized == "" { - serialized = "root" //nolint:ineffassign - } return fmt.Sprintf("validation error at %s:\n\t * %s", buf, v.wrapped.Error()) } @@ -97,6 +96,8 @@ func (v Error) Unwrap() error { type entryKind string const ( + kindRoot entryKind = "ROOT" + // Visiting a field. kindField entryKind = "FIELD" @@ -221,6 +222,11 @@ func validateReflect(path *path, value reflect.Value) error { return nil } func Validate[T any](value *T) error { + root := path{ + prev: nil, + kind: kindRoot, + entry: fmt.Sprintf("%T", *value), + } reflected := reflect.ValueOf(value) - return validateReflect(nil, reflected) + return validateReflect(&root, reflected) }