From 3dee020719df47246f8bd26648d1125db1e5ddee Mon Sep 17 00:00:00 2001 From: Chris Baker <1675087+cgbaker@users.noreply.github.com> Date: Mon, 30 Nov 2020 23:04:01 +0000 Subject: [PATCH] more tests for nested variables, fix issue where they were overwriting --- go.mod | 1 + helper/kvflag.go | 47 ++++++++++++---------------- helper/kvflag_test.go | 71 +++++++++++++++++++++++++++++++++---------- 3 files changed, 76 insertions(+), 43 deletions(-) diff --git a/go.mod b/go.mod index b99822f88..55b2b73e8 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/hashicorp/consul/api v1.7.0 github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-hclog v0.14.1 // indirect + github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/hcl/v2 v2.7.1 github.com/hashicorp/nomad v0.12.5-0.20201123213618-289d91df2e1c diff --git a/helper/kvflag.go b/helper/kvflag.go index b2f36acc1..7e960314b 100644 --- a/helper/kvflag.go +++ b/helper/kvflag.go @@ -16,10 +16,11 @@ func (v *Flag) String() string { // Set takes a flag variable argument and pulls the correct key and value to // create or add to a map. func (v *Flag) Set(raw string) error { - idx := strings.Index(raw, "=") - if idx == -1 { + split := strings.SplitN(raw, "=", 2) + if len(split) != 2 { return fmt.Errorf("no '=' value in arg: %s", raw) } + keyRaw, value := split[0], split[1] if *v == nil { *v = make(map[string]interface{}) @@ -27,33 +28,25 @@ func (v *Flag) Set(raw string) error { // Split the variable key based on the nested delimiter to get a list of // nested keys. - keys := strings.Split(raw[0:idx], ".") + keys := strings.Split(keyRaw, ".") - // If we only have a single key, then we are not dealing with a nested set - // meaning we can update the variable mapping and exit. - if len(keys) == 1 { - (*v)[keys[0]] = raw[idx+1:] - return nil + lastKeyIdx := len(keys) - 1 + // Find the nested map where this value belongs + // create missing maps as we go + target := *v + for i := 0; i < lastKeyIdx; i++ { + raw, ok := target[keys[i]] + if !ok { + raw = make(map[string]interface{}) + target[keys[i]] = raw + } + var newTarget Flag + if newTarget, ok = raw.(map[string]interface{}); !ok { + return fmt.Errorf("simple value already exists at key %q", strings.Join(keys[:i+1], ".")) + } + target = newTarget } - - // Identify the index max of the list for easy use. - nestedLen := len(keys) - 1 - - // The end map is the only thing we concretely know which contains our - // final key:value pair. - endEntry := map[string]interface{}{keys[nestedLen]: raw[idx+1:]} - - // Track the root of the nested map structure so we can continue to iterate - // the nested keys below. - root := endEntry - - // Iterate the nested keys backwards. Set a new root map containing the - // previous root as its value. Do not iterate backwards fully to the end, - // instead save the first key for the entry into Flag. - for i := nestedLen - 1; i > 0; i-- { - root = map[string]interface{}{keys[i]: root} - } - (*v)[keys[0]] = root + target[keys[lastKeyIdx]] = value return nil } diff --git a/helper/kvflag_test.go b/helper/kvflag_test.go index be9c5ce51..2a7c101df 100644 --- a/helper/kvflag_test.go +++ b/helper/kvflag_test.go @@ -3,51 +3,90 @@ package helper import ( "reflect" "testing" + + "github.com/hashicorp/go-multierror" + "github.com/stretchr/testify/require" ) func TestHelper_Set(t *testing.T) { cases := []struct { - Input string + Label string + Inputs []string Output map[string]interface{} Error bool }{ { - "key=value", + "simple value", + []string{"key=value"}, map[string]interface{}{"key": "value"}, false, }, { - "nested.key=value", + "nested replaces simple", + []string{"key=1", "key.nested=2"}, + nil, + true, + }, + { + "simple replaces nested", + []string{"key.nested=2", "key=1"}, + map[string]interface{}{"key": "1"}, + false, + }, + { + "nested siblings", + []string{"nested.a=1", "nested.b=2"}, + map[string]interface{}{"nested": map[string]interface{}{"a": "1", "b": "2"}}, + false, + }, + { + "nested singleton", + []string{"nested.key=value"}, map[string]interface{}{"nested": map[string]interface{}{"key": "value"}}, false, }, { - "key=", + "nested with parent", + []string{"root=a", "nested.key=value"}, + map[string]interface{}{"root": "a", "nested": map[string]interface{}{"key": "value"}}, + false, + }, + { + "empty value", + []string{"key="}, map[string]interface{}{"key": ""}, false, }, { - "key=foo=bar", + "value contains equal sign", + []string{"key=foo=bar"}, map[string]interface{}{"key": "foo=bar"}, false, }, { - "key", + "missing equal sign", + []string{"key"}, nil, true, }, } for _, tc := range cases { - f := new(Flag) - err := f.Set(tc.Input) - if (err != nil) != tc.Error { - t.Fatalf("bad error. Input: %#v", tc.Input) - } - - actual := map[string]interface{}(*f) - if !reflect.DeepEqual(actual, tc.Output) { - t.Fatalf("bad: %#v", actual) - } + t.Run(tc.Label, func(t *testing.T) { + f := new(Flag) + mErr := multierror.Error{} + for _, input := range tc.Inputs { + err := f.Set(input) + if err != nil { + mErr.Errors = append(mErr.Errors, err) + } + } + if tc.Error { + require.Error(t, mErr.ErrorOrNil()) + } else { + actual := map[string]interface{}(*f) + require.True(t, reflect.DeepEqual(actual, tc.Output)) + } + }) } }