Skip to content

Commit

Permalink
auto anchor
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy committed Dec 23, 2024
1 parent 9cbf5d4 commit ceb94a1
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
16 changes: 11 additions & 5 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Decoder struct {
reader io.Reader
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
aliasValueMap map[*ast.AliasNode]any
aliasValueMap map[string]any
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
toCommentMap CommentMap
Expand All @@ -51,7 +51,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
return &Decoder{
reader: r,
anchorNodeMap: map[string]ast.Node{},
aliasValueMap: make(map[*ast.AliasNode]any),
aliasValueMap: make(map[string]any),
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
opts: opts,
Expand Down Expand Up @@ -447,13 +447,18 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
return nil, err
}
d.anchorNodeMap[anchorName] = n.Value
d.anchorValueMap[anchorName] = reflect.ValueOf(anchorValue)
return anchorValue, nil
case *ast.AliasNode:
if v, exists := d.aliasValueMap[n]; exists {
if v, exists := d.anchorValueMap[n.Value.String()]; exists {
return v.Interface(), nil
}
text := n.String()
if v, exists := d.aliasValueMap[text]; exists {
return v, nil
}
// To handle the case where alias is processed recursively, the result of alias can be set to nil in advance.
d.aliasValueMap[n] = nil
d.aliasValueMap[text] = nil

aliasName := n.Value.GetToken().Value
node, exists := d.anchorNodeMap[aliasName]
Expand All @@ -465,7 +470,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
return nil, err
}
// once the correct alias value is obtained, overwrite with that value.
d.aliasValueMap[n] = aliasValue
d.aliasValueMap[text] = aliasValue
return aliasValue, nil
case *ast.LiteralNode:
return n.Value.GetValue(), nil
Expand Down Expand Up @@ -1985,6 +1990,7 @@ func (d *Decoder) decodeInit() error {

func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
d.decodeDepth = 0
d.aliasValueMap = make(map[string]any)
if len(d.parsedFile.Docs) <= d.streamIndex {
return io.EOF
}
Expand Down
53 changes: 53 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ type Encoder struct {
isFlowStyle bool
isJSONStyle bool
useJSONMarshaler bool
useAutoAnchor bool
anchorCallback func(*ast.AnchorNode, interface{}) error
anchorPtrToNameMap map[uintptr]string
anchorNameRefMap map[string]struct{}
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
useLiteralStyleIfMultiline bool
commentMap map[*Path][]*Comment
Expand All @@ -56,6 +58,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
opts: opts,
indent: DefaultIndentSpaces,
anchorPtrToNameMap: map[uintptr]string{},
anchorNameRefMap: make(map[string]struct{}),
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
line: 1,
column: 1,
Expand Down Expand Up @@ -111,6 +114,10 @@ func (e *Encoder) EncodeToNodeContext(ctx context.Context, v interface{}) (ast.N
return nil, err
}
}
if _, err := e.encodeValue(ctx, reflect.ValueOf(v), 1); err != nil {
return nil, err
}
e.anchorPtrToNameMap = make(map[uintptr]string)
node, err := e.encodeValue(ctx, reflect.ValueOf(v), 1)
if err != nil {
return nil, err
Expand Down Expand Up @@ -448,6 +455,7 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
case reflect.Ptr:
anchorName := e.anchorPtrToNameMap[v.Pointer()]
if anchorName != "" {
e.anchorNameRefMap[anchorName] = struct{}{}
aliasName := anchorName
alias := ast.Alias(token.New("*", "*", e.pos(column)))
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
Expand All @@ -464,6 +472,14 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
if mapSlice, ok := v.Interface().(MapSlice); ok {
return e.encodeMapSlice(ctx, mapSlice, column)
}
anchorName := e.anchorPtrToNameMap[v.Pointer()]
if anchorName != "" {
e.anchorNameRefMap[anchorName] = struct{}{}
aliasName := anchorName
alias := ast.Alias(token.New("*", "*", e.pos(column)))
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
return alias, nil
}
return e.encodeSlice(ctx, v)
case reflect.Array:
return e.encodeArray(ctx, v)
Expand All @@ -478,6 +494,13 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
}
return e.encodeStruct(ctx, v, column)
case reflect.Map:
anchorName := e.anchorPtrToNameMap[v.Pointer()]
if anchorName != "" {
aliasName := anchorName
alias := ast.Alias(token.New("*", "*", e.pos(column)))
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
return alias, nil
}
return e.encodeMap(ctx, v, column), nil
default:
return nil, fmt.Errorf("unknown value type %s", v.Type().String())
Expand Down Expand Up @@ -662,11 +685,21 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int
if e.isMapNode(value) {
value.AddColumn(e.indent)
}
if _, exists := e.anchorNameRefMap[fmt.Sprint(key)]; exists {
anchorName := fmt.Sprint(key)
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
anchorNode.Value = value
value = anchorNode
}
node.Values = append(node.Values, ast.MappingValue(
nil,
e.encodeString(fmt.Sprint(key), column),
value,
))
if ptr := e.toPointer(v); ptr != 0 {
e.anchorPtrToNameMap[ptr] = fmt.Sprint(key)
}
}
return node
}
Expand Down Expand Up @@ -868,3 +901,23 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column
}
return node, nil
}

func (e *Encoder) toPointer(v reflect.Value) uintptr {
if e.isInvalidValue(v) {
return 0
}

switch v.Type().Kind() {
case reflect.Ptr:
return v.Pointer()
case reflect.Interface:
return e.toPointer(v.Elem())
case reflect.Slice:
return v.Pointer()
case reflect.Array:
return v.Pointer()
case reflect.Map:
return v.Pointer()
}
return 0
}
7 changes: 7 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ func Flow(isFlowStyle bool) EncodeOption {
}
}

func UseAutoAnchor() EncodeOption {
return func(e *Encoder) error {
e.useAutoAnchor = true
return nil
}
}

// UseLiteralStyleIfMultiline causes encoding multiline strings with a literal syntax,
// no matter what characters they include
func UseLiteralStyleIfMultiline(useLiteralStyleIfMultiline bool) EncodeOption {
Expand Down

0 comments on commit ceb94a1

Please sign in to comment.