Skip to content

Commit

Permalink
Merge pull request #13 from pasqal-io/yoric/uint
Browse files Browse the repository at this point in the history
[FIX] Covariance problems with reflect-based API.
  • Loading branch information
David Teller authored Apr 10, 2024
2 parents 46c5a1e + 7d32324 commit 8ef2217
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 72 deletions.
175 changes: 113 additions & 62 deletions deserialize/deserialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,20 @@ type MapDeserializer[To any] interface {
// Deserialize a list of values from a list of values.
DeserializeList([]shared.Value) ([]To, error)
}
type MapReflectDeserializer interface {
// Deserialize a single value from a dict.
DeserializeDictTo(shared.Dict, *reflect.Value) error
}

// A deserializer from key, lists of values.
//
// Use this to deserialize e.g. query strings.
type KVListDeserializer[To any] interface {
DeserializeKVList(kvlist.KVList) (*To, error)
}
type KVListReflectDeserializer interface {
DeserializeKVListTo(kvlist.KVList, *reflect.Value) error
}

// Create a deserializer from Dict.
func MakeMapDeserializer[T any](options Options) (MapDeserializer[T], error) {
Expand All @@ -192,7 +199,7 @@ func MakeMapDeserializer[T any](options Options) (MapDeserializer[T], error) {
unmarshaler: options.Unmarshaler,
})
}
func MakeMapDeserializerFromReflect(options Options, typ reflect.Type) (MapDeserializer[any], error) {
func MakeMapDeserializerFromReflect(options Options, typ reflect.Type) (MapReflectDeserializer, error) {
tagName := options.MainTagName
if tagName == "" {
return nil, errors.New("missing option MainTagName")
Expand All @@ -205,28 +212,30 @@ func MakeMapDeserializerFromReflect(options Options, typ reflect.Type) (MapDeser
}

noTags := tags.Empty()
reflectDeserializer, err := makeFieldDeserializerFromReflect(options.RootPath, typ, staticOptions, &noTags, placeholder, false)
reflectDeserializer, err := makeFieldDeserializerFromReflect(options.RootPath, typ, staticOptions, &noTags, placeholder, false, false)

if err != nil {
return nil, err
}
return mapDeserializer[any]{
deserializer: func(value shared.Dict) (*any, error) {
out := reflect.New(typ).Elem()
input := value.AsValue()
err := reflectDeserializer(&out, input)
if err != nil {
return nil, err
}

result := out.Interface()
return &result, nil
},
options: staticOptions,
return mapReflectDeserializer{
reflectDeserializer: reflectDeserializer,
}, nil

}

type mapReflectDeserializer struct {
reflectDeserializer reflectDeserializer
}

func (mrd mapReflectDeserializer) DeserializeDictTo(dict shared.Dict, reflectOut *reflect.Value) error {
input := dict.AsValue()
err := mrd.reflectDeserializer(reflectOut, input)
if err != nil {
return err
}
return nil
}

// Create a deserializer from (key, value list).
func MakeKVListDeserializer[T any](options Options) (KVListDeserializer[T], error) {
tagName := options.MainTagName
Expand All @@ -242,20 +251,20 @@ func MakeKVListDeserializer[T any](options Options) (KVListDeserializer[T], erro
if err != nil {
return nil, err
}
deserializer := func(value kvlist.KVList) (*T, error) {
deserializer := func(value kvlist.KVList, out *T) error {
normalized := make(jsonPkg.JSON)
err := deListMap[T](normalized, value, innerOptions)
if err != nil {
return nil, fmt.Errorf("error attempting to deserialize from a list of entries:\n\t * %w", err)
return fmt.Errorf("error attempting to deserialize from a list of entries:\n\t * %w", err)
}
return wrapped.deserializer(normalized)
return wrapped.deserializer(normalized, out)
}
return kvListDeserializer[T]{
deserializer: deserializer,
options: innerOptions,
}, nil
}
func MakeKVDeserializerFromReflect(options Options, typ reflect.Type) (KVListDeserializer[any], error) {
func MakeKVDeserializerFromReflect(options Options, typ reflect.Type) (KVListReflectDeserializer, error) {
tagName := options.MainTagName
if tagName == "" {
return nil, errors.New("missing option MainTagName")
Expand All @@ -265,26 +274,40 @@ func MakeKVDeserializerFromReflect(options Options, typ reflect.Type) (KVListDes
allowNested: false,
unmarshaler: options.Unmarshaler,
}
var placeholder = reflect.New(typ).Elem().Interface()
wrapped, err := makeOuterStructDeserializerFromReflect[any](".", innerOptions, &placeholder, typ)
var placeholder = reflect.New(typ).Elem()
noTags := tags.Empty()
wrapped, err := makeFieldDeserializerFromReflect(".", typ, innerOptions, &noTags, placeholder, false, false)
if err != nil {
return nil, err
}

deserializer := func(value kvlist.KVList) (*any, error) {
normalized := make(jsonPkg.JSON)
err := deListMapReflect(typ, normalized, value, innerOptions)
if err != nil {
return nil, fmt.Errorf("error attempting to deserialize from a list of entries:\n\t * %w", err)
}
return wrapped.deserializer(normalized)
}
return kvListDeserializer[any]{
deserializer: deserializer,
options: innerOptions,
return kvReflectDeserializer{
reflectDeserializer: wrapped,
options: innerOptions,
typ: typ,
}, nil
}

type kvReflectDeserializer struct {
reflectDeserializer reflectDeserializer
options staticOptions
typ reflect.Type
}

func (kvrd kvReflectDeserializer) DeserializeKVListTo(value kvlist.KVList, reflectOut *reflect.Value) error {
normalized := make(jsonPkg.JSON)
err := deListMapReflect(kvrd.typ, normalized, value, kvrd.options)
if err != nil {
return err
}

err = kvrd.reflectDeserializer(reflectOut, normalized.AsValue())
if err != nil {
return err
}
return nil
}

// An error that arises because of a bug in a custom deserializer.
type CustomDeserializerError struct {
// The operation that failed, e.g. "initialize", "orMethod".
Expand Down Expand Up @@ -326,7 +349,7 @@ type staticOptions struct {

// A deserializer from (key, value) maps.
type mapDeserializer[T any] struct {
deserializer func(value shared.Dict) (*T, error)
deserializer func(value shared.Dict, out *T) error
options staticOptions
}

Expand All @@ -339,39 +362,50 @@ func (me mapDeserializer[T]) DeserializeBytes(source []byte) (*T, error) {
if !ok {
return nil, errors.New("failed to deserialize as a dictionary")
}
return me.deserializer(asDict)
return me.DeserializeDict(asDict)
}

func (me mapDeserializer[T]) DeserializeString(source string) (*T, error) {
return me.DeserializeBytes([]byte(source))
}

func (me mapDeserializer[T]) DeserializeDict(value shared.Dict) (*T, error) {
return me.deserializer(value)
out := new(T)
err := me.deserializer(value, out)
if err != nil {
return nil, err
}
return out, nil
}

func (me mapDeserializer[T]) DeserializeList(list []shared.Value) ([]T, error) {
result := []T{}
for i, entry := range list {
if dict, ok := entry.AsDict(); ok {
r, err := me.deserializer(dict)
out := new(T)
err := me.deserializer(dict, out)
if err != nil {
return []T{}, fmt.Errorf("failed to deserialize entry %d: \n\t * %w", i, err)
}
result = append(result, *r)
result = append(result, *out)
}
}
return result, nil
}

// A deserializer from (key, []string) maps.
type kvListDeserializer[T any] struct {
deserializer func(value kvlist.KVList) (*T, error)
deserializer func(value kvlist.KVList, out *T) error
options staticOptions
}

func (me kvListDeserializer[T]) DeserializeKVList(value kvlist.KVList) (*T, error) {
return me.deserializer(value)
out := new(T)
err := me.deserializer(value, out)
if err != nil {
return nil, err
}
return out, nil
}

// Convert a `map[string] []string` (as provided e.g. by the query parser) into a `Dict`
Expand Down Expand Up @@ -461,7 +495,7 @@ var errorInterface = reflect.TypeOf((*error)(nil)).Elem()

const JSON = "json"

func makeOuterStructDeserializerFromReflect[T any](path string, options staticOptions, container *T, typ reflect.Type) (*mapDeserializer[T], error) {
func makeOuterStructDeserializerFromReflect(path string, options staticOptions, container reflect.Value, typ reflect.Type) (*mapDeserializer[any], error) {
if options.unmarshaler == nil {
return nil, errors.New("please specify an unmarshaler")
}
Expand All @@ -479,18 +513,20 @@ func makeOuterStructDeserializerFromReflect[T any](path string, options staticOp

// The outer struct can't have any tags attached.
tags := tagsPkg.Empty()
reflectDeserializer, err := makeStructDeserializerFromReflect(path, typ, options, &tags, reflect.ValueOf(container), initializationMetadata.canInitializeSelf)
reflectDeserializer, err := makeStructDeserializerFromReflect(path, typ, options, &tags, container, initializationMetadata.canInitializeSelf)
if err != nil {
return nil, err
}

var result = mapDeserializer[T]{
deserializer: func(value shared.Dict) (*T, error) {
result := new(T)
var result = mapDeserializer[any]{
deserializer: func(value shared.Dict, out *any) error {
result := reflect.ValueOf(out)
if initializationMetadata.canInitializeSelf {
var resultAny any = result
initializer, ok := resultAny.(validation.Initializer)
initializer, ok := any(out).(validation.Initializer)
var err error
if !ok && out != nil {
initializer, ok = (*out).(validation.Initializer)
}
if !ok {
err = errors.New("we have already checked that the result can be converted to `Initializer` but conversion has failed")
panic(err)
Expand All @@ -499,21 +535,21 @@ func makeOuterStructDeserializerFromReflect[T any](path string, options staticOp
if err != nil {
err = fmt.Errorf("at %s, encountered an error while initializing optional fields:\n\t * %w", path, err)
slog.Error("internal error during deserialization", "error", err)
return nil, CustomDeserializerError{
return CustomDeserializerError{
Wrapped: err,
Operation: "initializer",
Structure: "outer",
}

}
}
resultSlot := reflect.ValueOf(result).Elem()
resultSlot := result.Elem()
input := value.AsValue()
err := reflectDeserializer(&resultSlot, input)
if err != nil {
return nil, err
return err
}
return result, nil
return nil
},
options: options,
}
Expand All @@ -532,7 +568,22 @@ func makeOuterStructDeserializer[T any](path string, options staticOptions) (*ma

// Pre-check if we're going to perform initialization.
typ := reflect.TypeOf(*container)
return makeOuterStructDeserializerFromReflect[T](path, options, container, typ)
deserializerAny, err := makeOuterStructDeserializerFromReflect(path, options, reflect.ValueOf(container), typ)
if err != nil {
return nil, err
}
return &mapDeserializer[T]{
deserializer: func(value shared.Dict, out *T) error {
resultAny := any(out)
err := deserializerAny.deserializer(value, &resultAny)
if err != nil {
return err
}
*out, _ = resultAny.(T)
return nil
},
options: options,
}, nil
}

// Construct a dynamically-typed deserializer for structs.
Expand Down Expand Up @@ -590,7 +641,7 @@ func makeStructDeserializerFromReflect(path string, typ reflect.Type, options st
fieldPath := fmt.Sprint(path, ".", *publicFieldName)

var fieldContentDeserializer reflectDeserializer
fieldContentDeserializer, err = makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize)
fieldContentDeserializer, err = makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -769,7 +820,7 @@ func makeMapDeserializerFromReflect(path string, typ reflect.Type, options stati
subPath := path + "[]"
subTags := tagsPkg.Empty()
subTyp := typ.Elem()
contentDeserializer, err := makeFieldDeserializerFromReflect(subPath, subTyp, options, &subTags, selfContainer, initializationMetadata.willPreinitialize)
contentDeserializer, err := makeFieldDeserializerFromReflect(subPath, subTyp, options, &subTags, selfContainer, initializationMetadata.willPreinitialize, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -878,7 +929,7 @@ func makeSliceDeserializer(fieldPath string, fieldType reflect.Type, options sta

// Prepare a deserializer for elements in this slice.
childPreinitialized := wasPreinitialized || tags.IsPreinitialized()
elementDeserializer, err := makeFieldDeserializerFromReflect(arrayPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized)
elementDeserializer, err := makeFieldDeserializerFromReflect(arrayPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized, true)
if err != nil {
return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err)
}
Expand Down Expand Up @@ -955,13 +1006,13 @@ func makeSliceDeserializer(fieldPath string, fieldType reflect.Type, options sta
// - `fieldPath` the human-readable path into the data structure, used for error-reporting;
// - `fieldType` the dynamic type for the pointer being compiled;
// - `tags` the table of tags for this field.
func makePointerDeserializer(fieldPath string, fieldType reflect.Type, options staticOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool) (reflectDeserializer, error) {
func makePointerDeserializer(fieldPath string, fieldType reflect.Type, options staticOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool, wasNested bool) (reflectDeserializer, error) {
ptrPath := fmt.Sprint(fieldPath, "*")
elemType := fieldType.Elem()
subTags := tagsPkg.Empty()
subContainer := reflect.New(fieldType).Elem()
childPreinitialized := wasPreinitialized || tags.IsPreinitialized()
elementDeserializer, err := makeFieldDeserializerFromReflect(ptrPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized)
elementDeserializer, err := makeFieldDeserializerFromReflect(ptrPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized, wasNested)
if err != nil {
return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err)
}
Expand Down Expand Up @@ -1179,29 +1230,29 @@ func makeFlatFieldDeserializer(fieldPath string, fieldType reflect.Type, options
// - `typ` the dynamic type for the field being compiled;
// - `tagName` the name of tags to use for field renamings, e.g. `query`;
// - `tags` the table of tags for this field.
func makeFieldDeserializerFromReflect(fieldPath string, fieldType reflect.Type, options staticOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool) (reflectDeserializer, error) {
func makeFieldDeserializerFromReflect(fieldPath string, fieldType reflect.Type, options staticOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool, wasNested bool) (reflectDeserializer, error) {
var structured reflectDeserializer
var err error
var nestError error
switch fieldType.Kind() {
case reflect.Pointer:
structured, err = makePointerDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized)
structured, err = makePointerDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized, wasNested)
case reflect.Array:
fallthrough
case reflect.Slice:
if options.allowNested {
if options.allowNested || !wasNested {
structured, err = makeSliceDeserializer(fieldPath, fieldType, options, tags, container, wasPreinitialized)
} else {
nestError = errors.New("this type of extractor does not support arrays/slices")
}
case reflect.Struct:
if options.allowNested {
if options.allowNested || !wasNested {
structured, err = makeStructDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized)
} else {
nestError = errors.New("this type of extractor does not support nested structs")
}
case reflect.Map:
if options.allowNested {
if options.allowNested || !wasNested {
structured, err = makeMapDeserializerFromReflect(fieldPath, fieldType, options, tags, container, wasPreinitialized)
} else {
nestError = errors.New("this type of extractor does not support nested maps")
Expand Down Expand Up @@ -1305,7 +1356,7 @@ func canInterface(typ reflect.Type, interfaceType reflect.Type) (bool, error) {
if typ.Implements(interfaceType) {
return false, fmt.Errorf("type %s implements %s - it should be implemented by pointer type *%s instead", typ, interfaceType, typ)
}
if ptrTyp.ConvertibleTo(interfaceType) {
if ptrTyp.Implements(interfaceType) {
return true, nil
}
return false, nil
Expand Down
Loading

0 comments on commit 8ef2217

Please sign in to comment.