Skip to content

Commit

Permalink
feat: add more test case
Browse files Browse the repository at this point in the history
  • Loading branch information
dapeng committed Dec 22, 2024
1 parent c39c87c commit cf740ee
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 37 deletions.
103 changes: 84 additions & 19 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package gone

import (
"github.com/gone-io/gone/internal/json"
"os"
"reflect"
"strconv"
"time"

"github.com/gone-io/gone/internal/json"
)

const ConfigureName = "configure"
Expand Down Expand Up @@ -52,12 +53,20 @@ func (s *ConfigProvider) Provide(tagConf string, t reflect.Type) (any, error) {
defaultValue = m["default"] // Fallback to "default" key if no value
}

var getType = t
if t.Kind() == reflect.Ptr {
getType = t.Elem()
}

// Create new value of requested type and configure it
value := reflect.New(t)
value := reflect.New(getType)
err := s.configure.Get(key, value.Interface(), defaultValue)
if err != nil {
return nil, ToError(err)
}
if t.Kind() == reflect.Ptr {
return value.Interface(), nil
}
return value.Elem().Interface(), nil
}

Expand Down Expand Up @@ -93,62 +102,118 @@ func (s *EnvConfigure) Get(key string, v any, defaultVal string) error {
return NewInnerError("Value must be a pointer", ConfigError)
}

// Set default "0" for numeric and boolean types when env is empty
if env == "" {
switch v.(type) {
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64,
*float32, *float64, *bool, *time.Duration:
env = "0"
}
}

// Type switch to handle different pointer types
switch ptr := v.(type) {
// String type
case *string:
// String type needs no conversion
*ptr = env

// Int types
case *int:
// Convert string to int
val, err := strconv.Atoi(env)
if err != nil {
return ToError(err)
}
*ptr = val
case *int64:
// Convert string to int64
val, err := strconv.ParseInt(env, 10, 64)
case *int8:
val, err := strconv.ParseInt(env, 10, 8)
if err != nil {
return ToError(err)
}
*ptr = val
case *float64:
// Convert string to float64
val, err := strconv.ParseFloat(env, 64)
*ptr = int8(val)
case *int16:
val, err := strconv.ParseInt(env, 10, 16)
if err != nil {
return ToError(err)
}
*ptr = val
case *bool:
// Convert string to bool
val, err := strconv.ParseBool(env)
*ptr = int16(val)
case *int32:
val, err := strconv.ParseInt(env, 10, 32)
if err != nil {
return ToError(err)
}
*ptr = int32(val)
case *int64:
val, err := strconv.ParseInt(env, 10, 64)
if err != nil {
return ToError(err)
}
*ptr = val

// Unsigned int types
case *uint:
// Convert string to uint
val, err := strconv.ParseUint(env, 10, 64)
if err != nil {
return ToError(err)
}
*ptr = uint(val)
case *uint8:
val, err := strconv.ParseUint(env, 10, 8)
if err != nil {
return ToError(err)
}
*ptr = uint8(val)
case *uint16:
val, err := strconv.ParseUint(env, 10, 16)
if err != nil {
return ToError(err)
}
*ptr = uint16(val)
case *uint32:
val, err := strconv.ParseUint(env, 10, 32)
if err != nil {
return ToError(err)
}
*ptr = uint32(val)
case *uint64:
// Convert string to uint64
val, err := strconv.ParseUint(env, 10, 64)
if err != nil {
return ToError(err)
}
*ptr = val

// Float types
case *float32:
val, err := strconv.ParseFloat(env, 32)
if err != nil {
return ToError(err)
}
*ptr = float32(val)
case *float64:
val, err := strconv.ParseFloat(env, 64)
if err != nil {
return ToError(err)
}
*ptr = val

// Boolean type
case *bool:
val, err := strconv.ParseBool(env)
if err != nil {
return ToError(err)
}
*ptr = val

// Time duration type
case *time.Duration:
// Convert string to time.Duration
val, err := time.ParseDuration(env)
if err != nil {
return ToError(err)
}
*ptr = val

// Struct and unsupported types
default:
// Handle struct types by JSON unmarshal
if rv.Elem().Kind() == reflect.Struct {
err := json.Unmarshal([]byte(env), v)
if err != nil {
Expand Down
19 changes: 8 additions & 11 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,16 @@ func (s *Core) Install() error {
return nil
}

func (s *Core) safeFillOne(coffin *coffin) error {
var err error
err = SafeExecute(func() {
err = s.fillOne(coffin)
func (s *Core) safeFillOne(coffin *coffin) (err error) {
return SafeExecute(func() error {
return s.fillOne(coffin)
})
return err
}

func (s *Core) safeInitOne(coffin *coffin) error {
var err error
err = SafeExecute(func() {
err = s.initOne(coffin)
return SafeExecute(func() error {
return s.initOne(coffin)
})
return err
}

func (s *Core) fillOne(coffin *coffin) error {
Expand Down Expand Up @@ -333,11 +330,11 @@ func (s *Core) InjectStruct(goner any) error {
co := &coffin{
goner: goner,
}
err := s.fillOne(co)
err := s.safeFillOne(co)
if err != nil {
return ToError(err)
}
return s.initOne(co)
return s.safeFillOne(co)
}

func (s *Core) GetGonerByName(name string) any {
Expand Down
158 changes: 158 additions & 0 deletions core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ func (m *MockInitiator) Init() error {
return m.initError
}

type MockBeforeInitNoError struct {
Flag
beforeInitCalled bool
}

func (m *MockBeforeInitNoError) BeforeInit() {
m.beforeInitCalled = true
}

type MockStructFieldInjector struct {
Flag
}

func (m *MockStructFieldInjector) GonerName() string {
return "field-injector"
}

func (m *MockStructFieldInjector) Inject(conf string, field reflect.StructField, v reflect.Value) error {
if conf == "error" {
return fmt.Errorf("injection error")
}
v.Set(reflect.ValueOf("injected value"))
return nil
}

type StructWithUnexportedField struct {
Flag
dep *MockDependency `gone:"*"`
Public *MockDependency `gone:"*"`
}

func TestNewCore(t *testing.T) {
core := NewCore()

Expand Down Expand Up @@ -188,6 +219,60 @@ func TestCore_Fill(t *testing.T) {
},
wantErr: true,
},
{
name: "BeforeInitNoError implementation",
setup: func(core *Core) {
mock := &MockBeforeInitNoError{}
_ = core.Load(mock)
},
wantErr: false,
},
{
name: "Unexported field injection",
setup: func(core *Core) {
_ = core.Load(&MockDependency{})
_ = core.Load(&StructWithUnexportedField{})
},
wantErr: false,
},
{
name: "StructFieldInjector success",
setup: func(core *Core) {
_ = core.Load(&MockStructFieldInjector{})
type TestStruct struct {
Flag
Value string `gone:"field-injector"`
}
_ = core.Load(&TestStruct{})
},
wantErr: false,
},
{
name: "StructFieldInjector error",
setup: func(core *Core) {
_ = core.Load(&MockStructFieldInjector{})
type TestStruct struct {
Flag
Value string `gone:"field-injector-error"`
}
_ = core.Load(&TestStruct{})
},
wantErr: true,
},
{
name: "Provider with invalid return type",
setup: func(core *Core) {
_ = core.Load(&MockProvider{
returnVal: "invalid",
})
type TestStruct struct {
Flag
Value *MockDependency `gone:"mock-provider"`
}
_ = core.Load(&TestStruct{})
},
wantErr: true,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -577,6 +662,31 @@ func TestCore_Provide(t *testing.T) {
want: false,
wantErr: true,
},
{
name: "Provider returns incompatible type",
setup: func(core *Core) {
_ = core.Load(&MockProvider{
returnVal: "invalid string instead of MockDependency",
})
},
typ: reflect.TypeOf(&MockDependency{}),
tagConf: "mock-provider",
want: false,
wantErr: true,
},
{
name: "Slice with mixed sources",
setup: func(core *Core) {
_ = core.Load(&MockDependency{})
_ = core.Load(&MockProvider{
returnVal: &MockDependency{},
})
},
typ: reflect.TypeOf([]*MockDependency{}),
tagConf: "",
want: true,
wantErr: false,
},
}

for _, tt := range tests {
Expand All @@ -595,3 +705,51 @@ func TestCore_Provide(t *testing.T) {
})
}
}

func TestCore_InjectStruct_EdgeCases(t *testing.T) {
tests := []struct {
name string
target interface{}
setup func(*Core)
wantErr bool
}{
{
name: "Nil pointer",
target: (*struct {
Flag
Dep *MockDependency `gone:"*"`
})(nil),
setup: func(core *Core) {},
wantErr: true,
},
{
name: "Non-struct pointer",
target: new(string),
setup: func(core *Core) {},
wantErr: true,
},
{
name: "Invalid tag configuration",
target: &struct {
Flag
Dep *MockDependency `gone:"invalid,config"`
}{},
setup: func(core *Core) {
_ = core.Load(&MockDependency{})
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
core := NewCore()
tt.setup(core)

err := core.InjectStruct(tt.target)
if (err != nil) != tt.wantErr {
t.Errorf("InjectStruct() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading

0 comments on commit cf740ee

Please sign in to comment.