From 2b3323f0bfe7bba1cf41a1c07c5eab680451be40 Mon Sep 17 00:00:00 2001 From: dapeng Date: Wed, 5 Jun 2024 21:46:04 +0800 Subject: [PATCH] refactor: http inject and test --- cemetery.go | 63 ++++++++ example/web/go.mod | 4 +- example/web/main.go | 15 +- goner/gin/http-injector.go | 247 +++++++++++++++++--------------- goner/gin/http-injector_test.go | 21 +-- goner/gin/interface.go | 9 +- goner/gin/proxy.go | 108 +++++++++++--- goner/gin/proxy2_test.go | 78 ---------- goner/gin/proxy_test.go | 193 ++++++++++--------------- help.go | 24 +++- interface.go | 2 + 11 files changed, 413 insertions(+), 351 deletions(-) delete mode 100644 goner/gin/proxy2_test.go diff --git a/cemetery.go b/cemetery.go index e339468..b8fd056 100644 --- a/cemetery.go +++ b/cemetery.go @@ -460,3 +460,66 @@ func (c *cemetery) GetTomById(id GonerId) Tomb { func (c *cemetery) GetTomByType(t reflect.Type) (tombs []Tomb) { return Tombs(c.tombs).GetTomByType(t) } + +func (c *cemetery) InjectFuncParameters(fn any, injectBefore func(pt reflect.Type, i int) any, injectAfter func(pt reflect.Type, i int, obj any)) (args []any, err error) { + ft := reflect.TypeOf(fn) + if ft.Kind() != reflect.Func { + return nil, NewInnerError("fn must be a function", NotCompatible) + } + + in := ft.NumIn() + + getOnlyOne := func(pt reflect.Type, i int) Goner { + tombs := c.GetTomByType(pt) + if len(tombs) > 0 { + var container Tomb + + for _, t := range tombs { + if t.IsDefault() { + container = t + break + } + } + if container == nil { + container = tombs[0] + if len(tombs) > 1 { + c.Warnf(fmt.Sprintf("injected function %s %d parameter more than one goner was found and no default, used the first!", ft.Name(), i)) + } + } + return container.GetGoner() + } + return nil + } + + for i := 0; i < in; i++ { + pt := ft.In(0) + x := injectBefore(pt, i) + if x != nil { + args = append(args, x) + continue + } + + x = getOnlyOne(pt, i+1) + if x != nil { + args = append(args, x) + continue + } + + if pt.Kind() != reflect.Struct { + err = NewInnerError(fmt.Sprintf("injected function %s %d parameter must be a struct", ft.Name(), i), NotCompatible) + return + } + + parameter := reflect.New(pt) + goner := parameter.Interface() + _, err = c.ReviveOne(goner) + if err != nil { + return + } + x = parameter.Elem().Interface() + + args = append(args, x) + injectAfter(pt, i, x) + } + return +} diff --git a/example/web/go.mod b/example/web/go.mod index 3edb3c2..f0a22ee 100644 --- a/example/web/go.mod +++ b/example/web/go.mod @@ -18,7 +18,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.17.0 // indirect - github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/gomodule/redigo v1.8.9 // indirect @@ -58,6 +58,8 @@ require ( golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.20.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/grpc v1.63.2 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 // indirect diff --git a/example/web/main.go b/example/web/main.go index f0403d8..d6ed851 100644 --- a/example/web/main.go +++ b/example/web/main.go @@ -13,12 +13,21 @@ type controller struct { // Mount use for mounting the router of gin framework func (ctr *controller) Mount() gin.MountError { - ctr.router.GET("/ping", func(c *gin.Context) (any, error) { - return "hello", nil - }) + //ctr.router.GET("/ping", func(c *gin.Context) (any, error) { + // return "hello", nil + //}) + ctr.router.GET("/hello", ctr.hello) return nil } +func (ctr *controller) hello(in struct { + name string `gone:"http,query"` +}) (any, error) { + defer gone.TimeStat("hello")() + + return "hello, " + in.name, nil +} + func NewController() gone.Goner { return &controller{} } diff --git a/goner/gin/http-injector.go b/goner/gin/http-injector.go index d475700..da697ec 100644 --- a/goner/gin/http-injector.go +++ b/goner/gin/http-injector.go @@ -9,20 +9,19 @@ import ( "reflect" "strconv" "strings" + "unsafe" ) func NewHttInjector() (gone.Goner, gone.GonerId, gone.GonerOption) { return &httpInjector{ - bindFuncs: make([]BindFunc, 0), + bindFuncs: make([]BindFieldFunc, 0), }, gone.IdHttpInjector, gone.IsDefault(true) } -type BindFunc func(context *gin.Context) error - type httpInjector struct { gone.Flag - bindFuncs []BindFunc + bindFuncs []BindFieldFunc isInjectedBody bool } @@ -36,23 +35,41 @@ func parseConfKeyValue(conf string) (key, value string) { } func (s *httpInjector) StartCollectBindFuncs() { - s.bindFuncs = make([]BindFunc, 0) + s.bindFuncs = make([]BindFieldFunc, 0) s.isInjectedBody = false } -func (s *httpInjector) CollectBindFuncs() []BindFunc { +func (s *httpInjector) CollectBindFuncs() []BindFieldFunc { return s.bindFuncs } +func (s *httpInjector) BindFuncs() BindStructFunc { + funcs := s.CollectBindFuncs() + return func(context *gin.Context, arg any, T reflect.Type) (reflect.Value, error) { + v := reflect.ValueOf(&arg).Elem() + v = reflect.NewAt(T, unsafe.Pointer(v.UnsafeAddr())).Elem() + + for _, fn := range funcs { + err := fn(context, v) + if err != nil { + return v, err + } + } + return v, nil + } +} + func (s *httpInjector) Suck(conf string, v reflect.Value, field reflect.StructField) error { kind, key := parseConfKeyValue(conf) if key == "" { key = field.Name } - fn, err := s.inject(kind, key, v, field.Name) + fn, err := s.inject(kind, key, field) if err != nil { return err } + + //index := field.Index() s.bindFuncs = append(s.bindFuncs, fn) return nil } @@ -80,20 +97,16 @@ func cannotInjectBodyMoreThanOnce(fieldName string) error { return NewInnerError(fmt.Sprintf("cannot inject %s,http body inject only support inject once; ref doc: https://goner.fun/en/references/http-inject.md", fieldName), gone.InjectError) } -func (s *httpInjector) inject(kind string, key string, v reflect.Value, fieldName string) (fn BindFunc, err error) { +func (s *httpInjector) inject(kind string, key string, field reflect.StructField) (fn BindFieldFunc, err error) { if kind == "" { - return s.injectByType(v, fieldName) + return s.injectByType(field) } - return s.injectByKind(kind, key, v, fieldName) + return s.injectByKind(kind, key, field) } -var ctxPtr *gin.Context -var ctxPointType = reflect.TypeOf(ctxPtr) -var ctxType = ctxPointType.Elem() - -var goneContextPtr *gone.Context -var goneContextPointType = reflect.TypeOf(goneContextPtr) -var goneContextType = goneContextPointType.Elem() +//var ctxPtr *gin.Context +//var ctxPointType = reflect.TypeOf(ctxPtr) +//var ctxType = ctxPointType.Elem() var requestPtr *http.Request var requestType = reflect.TypeOf(requestPtr) @@ -109,82 +122,93 @@ var headerType = reflect.TypeOf(header) var writerPtr *gin.ResponseWriter var writerType = reflect.TypeOf(writerPtr).Elem() -func (s *httpInjector) injectByType(v reflect.Value, fieldName string) (fn BindFunc, err error) { - t := v.Type() +func (s *httpInjector) injectByType(field reflect.StructField) (fn BindFieldFunc, err error) { + t := field.Type switch t { case ctxPointType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx)) return nil }, nil case ctxType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx).Elem()) return nil }, nil case goneContextPointType: - return func(context *gin.Context) error { - v.Set(reflect.ValueOf(&gone.Context{Context: context})) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + v.Set(reflect.ValueOf(&gone.Context{Context: ctx})) return nil }, nil case goneContextType: - return func(context *gin.Context) error { - v.Set(reflect.ValueOf(gone.Context{Context: context})) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + v.Set(reflect.ValueOf(gone.Context{Context: ctx})) return nil }, nil case requestType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Request)) return nil }, nil case requestPointType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Request).Elem()) return nil }, nil case urlType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Request.URL)) return nil }, nil case urlPointType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Request.URL).Elem()) return nil }, nil case headerType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Request.Header)) return nil }, nil case writerType: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) v.Set(reflect.ValueOf(ctx.Writer)) return nil }, nil default: - return nil, unsupportedAttributeType(fieldName) + return nil, unsupportedAttributeType(field.Name) } } -func (s *httpInjector) injectBody(v reflect.Value, fieldName string) (fn BindFunc, err error) { +func (s *httpInjector) injectBody(kind, key string, field reflect.StructField) (fn BindFieldFunc, err error) { if s.isInjectedBody { - return nil, cannotInjectBodyMoreThanOnce(fieldName) + return nil, cannotInjectBodyMoreThanOnce(field.Name) } - t := v.Type() + t := field.Type switch t.Kind() { case reflect.Struct, reflect.Map, reflect.Slice: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) body := reflect.New(t).Interface() if err := ctx.ShouldBind(body); err != nil { @@ -194,68 +218,42 @@ func (s *httpInjector) injectBody(v reflect.Value, fieldName string) (fn BindFun return nil }, nil case reflect.Pointer: - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } if err := ctx.ShouldBind(v.Interface()); err != nil { return NewParameterError(err.Error()) } return nil }, nil default: - return nil, unsupportedAttributeType(fieldName) + return nil, unsupportedAttributeType(field.Name) } - - //if !(t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct) { - // return nil, unsupportedAttributeType(fieldName) - //} - // - //if t.Kind() == reflect.Pointer { - // if v.IsNil() { - // v.Set(reflect.New(v.Type().Elem())) - // } - // - // return func(ctx *gin.Context) error { - // if err := ctx.ShouldBind(v.Interface()); err != nil { - // return NewParameterError(err.Error()) - // } - // return nil - // }, nil - //} else { - // return func(ctx *gin.Context) error { - // body := reflect.New(t).Interface() - // - // if err := ctx.ShouldBind(body); err != nil { - // return NewParameterError(err.Error()) - // } - // v.Set(reflect.ValueOf(body).Elem()) - // return nil - // }, nil - //} } -func (s *httpInjector) injectByKind(kind, key string, v reflect.Value, fieldName string) (fn BindFunc, err error) { +func (s *httpInjector) injectByKind(kind, key string, field reflect.StructField) (fn BindFieldFunc, err error) { switch kind { case keyHeader, keyParam, keyCookie: - return s.parseStringValueAndInject(v, fieldName, kind, key) + return s.parseStringValueAndInject(kind, key, field) case keyQuery: - return s.injectQuery(v, fieldName, key) + return s.injectQuery(kind, key, field) case keyBody: - return s.injectBody(v, fieldName) + return s.injectBody(kind, key, field) default: - return nil, unsupportedKindConfigure(fieldName) + return nil, unsupportedKindConfigure(field.Name) } } var queryMapType = reflect.TypeOf(map[string]string{}) -func (s *httpInjector) injectQuery(v reflect.Value, fieldName string, key string) (fn BindFunc, err error) { - t := v.Type() +func (s *httpInjector) injectQuery(kind, key string, field reflect.StructField) (fn BindFieldFunc, err error) { + t := field.Type switch t.Kind() { case reflect.Struct: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) body := reflect.New(t).Interface() if err := ctx.ShouldBindQuery(body); err != nil { return NewParameterError(err.Error()) @@ -265,11 +263,11 @@ func (s *httpInjector) injectQuery(v reflect.Value, fieldName string, key string }, nil case reflect.Pointer: - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } if err := ctx.ShouldBindQuery(v.Interface()); err != nil { return NewParameterError(err.Error()) } @@ -278,37 +276,41 @@ func (s *httpInjector) injectQuery(v reflect.Value, fieldName string, key string case reflect.Map: if t == queryMapType { - return func(ctx *gin.Context) error { - dicts := ctx.QueryMap(key) - v.Set(reflect.ValueOf(dicts)) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + dict := ctx.QueryMap(key) + v.Set(reflect.ValueOf(dict)) return nil }, nil } - return nil, unsupportedAttributeType(fieldName) + return nil, unsupportedAttributeType(field.Name) case reflect.Slice: - return s.injectQueryArray(key, v, fieldName) + return s.injectQueryArray(kind, key, field) default: - return s.parseStringValueAndInject(v, fieldName, keyQuery, key) + return s.parseStringValueAndInject(kind, key, field) } } -func (s *httpInjector) injectQueryArray(key string, v reflect.Value, fieldName string) (fn BindFunc, err error) { - el := v.Type().Elem() +func (s *httpInjector) injectQueryArray(k, key string, field reflect.StructField) (fn BindFieldFunc, err error) { + el := field.Type.Elem() + kind := el.Kind() bits := bitSize(kind) switch kind { case reflect.String: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) values := ctx.QueryArray(key) v.Set(reflect.ValueOf(values)) return nil }, nil case reflect.Bool: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) values := ctx.QueryArray(key) for _, value := range values { v.Set(reflect.Append(v, reflect.ValueOf(stringToBool(value)))) @@ -317,7 +319,8 @@ func (s *httpInjector) injectQueryArray(key string, v reflect.Value, fieldName s }, nil case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) values := ctx.QueryArray(key) for _, value := range values { def, err := strconv.ParseInt(value, 10, bits) @@ -330,7 +333,8 @@ func (s *httpInjector) injectQueryArray(key string, v reflect.Value, fieldName s }, nil case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) values := ctx.QueryArray(key) for _, value := range values { def, err := strconv.ParseUint(value, 10, bits) @@ -343,7 +347,8 @@ func (s *httpInjector) injectQueryArray(key string, v reflect.Value, fieldName s }, nil case reflect.Float64, reflect.Float32: - return func(ctx *gin.Context) error { + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) values := ctx.QueryArray(key) for _, value := range values { def, err := strconv.ParseFloat(value, bits) @@ -355,13 +360,13 @@ func (s *httpInjector) injectQueryArray(key string, v reflect.Value, fieldName s return nil }, nil default: - return nil, unsupportedAttributeType(fieldName) + return nil, unsupportedAttributeType(field.Name) } } -func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName string, kind string, key string) (fn BindFunc, err error) { +func (s *httpInjector) parseStringValueAndInject(kind, key string, field reflect.StructField) (fn BindFieldFunc, err error) { var parser func(context *gin.Context) (string, error) - t := v.Type() + t := field.Type switch kind { case keyHeader: @@ -385,15 +390,16 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri return context.Query(key), nil } default: - return nil, unsupportedKindConfigure(fieldName) + return nil, unsupportedKindConfigure(field.Name) } bits := bitSize(t.Kind()) switch t.Kind() { case reflect.String: - return func(context *gin.Context) error { - value, err := parser(context) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + value, err := parser(ctx) if err != nil { return err } @@ -402,8 +408,9 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri }, nil case reflect.Bool: - return func(context *gin.Context) error { - value, err := parser(context) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + value, err := parser(ctx) if err != nil { return err } @@ -411,8 +418,9 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri return nil }, nil case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - return func(context *gin.Context) error { - value, err := parser(context) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + value, err := parser(ctx) if err != nil { return err } @@ -427,8 +435,9 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri }, nil case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return func(context *gin.Context) error { - value, err := parser(context) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + value, err := parser(ctx) if err != nil { return err } @@ -443,8 +452,9 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri }, nil case reflect.Float64, reflect.Float32: - return func(context *gin.Context) error { - value, err := parser(context) + return func(ctx *gin.Context, structVale reflect.Value) error { + v := fieldByIndexFromStructValue(structVale, field.Index, field.IsExported(), field.Type) + value, err := parser(ctx) if err != nil { return err } @@ -459,7 +469,7 @@ func (s *httpInjector) parseStringValueAndInject(v reflect.Value, fieldName stri }, nil default: - return nil, unsupportedAttributeType(fieldName) + return nil, unsupportedAttributeType(field.Name) } } @@ -482,11 +492,12 @@ func stringToBool(value string) bool { return value != "" && value != "0" && value != "false" } -// -//var m = map[reflect.Kind]func(number string) reflect.Value{ -// reflect.Int: reflect.ValueOf(int(0)).Convert, -//} -// -//func xxx() func() reflect.Value { -// -//} +func fieldByIndexFromStructValue(structValue reflect.Value, index []int, isExported bool, fieldType reflect.Type) reflect.Value { + v := structValue.FieldByIndex(index) + + if !isExported { + //黑魔法:让非导出字段可以访问 + v = reflect.NewAt(fieldType, unsafe.Pointer(v.UnsafeAddr())).Elem() + } + return v +} diff --git a/goner/gin/http-injector_test.go b/goner/gin/http-injector_test.go index 83d91bb..3860782 100644 --- a/goner/gin/http-injector_test.go +++ b/goner/gin/http-injector_test.go @@ -3,6 +3,7 @@ package gin import ( "bytes" "encoding/json" + "fmt" "github.com/gin-gonic/gin" "github.com/golang/mock/gomock" "github.com/gone-io/gone" @@ -146,7 +147,7 @@ func Test_httpInjector_inject(t *testing.T) { kind string key string - wantFn BindFunc + wantFn BindFieldFunc wantErr assert.ErrorAssertionFunc ctx *gin.Context bindErr func(t assert.TestingT, err error) @@ -1032,20 +1033,22 @@ func Test_httpInjector_inject(t *testing.T) { }, } for _, tt := range tests { - if tt.before != nil { - tt.before() - } - elem := reflect.ValueOf(req).Elem() t.Run(tt.name, func(t *testing.T) { + if tt.before != nil { + tt.before() + } + elemT := reflect.TypeOf(req).Elem() + ttx, b := elemT.FieldByName(tt.fieldName) + assert.True(t, b) + fmt.Printf("%v", ttx) + fn, err := injector.inject( tt.kind, tt.key, - elem.FieldByName(tt.fieldName), - tt.fieldName, + ttx, ) - if tt.wantErr(t, err) { - tt.bindErr(t, fn(tt.ctx)) + tt.bindErr(t, fn(tt.ctx, reflect.ValueOf(req).Elem())) } }) } diff --git a/goner/gin/interface.go b/goner/gin/interface.go index 100809d..5483468 100644 --- a/goner/gin/interface.go +++ b/goner/gin/interface.go @@ -3,6 +3,7 @@ package gin import ( "github.com/gin-gonic/gin" "github.com/gone-io/gone" + "reflect" ) //go:generate sh -c "mockgen -package=gin github.com/gin-gonic/gin ResponseWriter > gin_ResponseWriter_mock_test.go" @@ -120,11 +121,11 @@ type Responser interface { // allowing the same interface to have the ability to return different business codes and business data in special cases type BusinessError = gone.BusinessError -//type keepContext interface { -//SetContext(context *Context) (any, error) -//} +type BindFieldFunc func(context *gin.Context, structVale reflect.Value) error +type BindStructFunc func(*gin.Context, any, reflect.Type) (reflect.Value, error) type HttInjector interface { StartCollectBindFuncs() - CollectBindFuncs() []BindFunc + CollectBindFuncs() []BindFieldFunc + BindFuncs() BindStructFunc } diff --git a/goner/gin/proxy.go b/goner/gin/proxy.go index 20b87d7..91ee3f2 100644 --- a/goner/gin/proxy.go +++ b/goner/gin/proxy.go @@ -3,6 +3,7 @@ package gin import ( "github.com/gin-gonic/gin" "github.com/gone-io/gone" + "reflect" ) // NewGinProxy 新建代理器 @@ -36,42 +37,117 @@ func (p *proxy) ProxyForMiddleware(handlers ...HandlerFunc) (arr []gin.HandlerFu return arr } +var ctxPtr *gin.Context +var ctxPointType = reflect.TypeOf(ctxPtr) +var ctxType = ctxPointType.Elem() + +var goneContextPtr *gone.Context +var goneContextPointType = reflect.TypeOf(goneContextPtr) +var goneContextType = goneContextPointType.Elem() + +type placeholder struct { + Type reflect.Type +} + +type BindStructFuncAndType struct { + Fn BindStructFunc + Type reflect.Type +} + func (p *proxy) proxyOne(x HandlerFunc, last bool) gin.HandlerFunc { + funcName := gone.GetFuncName(x) switch x.(type) { case func(*Context) (any, error): + f := x.(func(*Context) (any, error)) return func(context *gin.Context) { - data, err := x.(func(*Context) (any, error))(&Context{Context: context}) - p.responser.ProcessResults(context, context.Writer, last, gone.GetFuncName(x), data, err) + data, err := f(&Context{Context: context}) + p.responser.ProcessResults(context, context.Writer, last, funcName, data, err) } case func(*Context) error: + f := x.(func(*Context) error) return func(context *gin.Context) { - err := x.(func(*Context) error)(&Context{Context: context}) - p.responser.ProcessResults(context, context.Writer, last, gone.GetFuncName(x), err) + err := f(&Context{Context: context}) + p.responser.ProcessResults(context, context.Writer, last, funcName, err) } case func(*Context): + f := x.(func(*Context)) return func(context *gin.Context) { - x.(func(*Context))(&Context{Context: context}) + f(&Context{Context: context}) } default: - p.injector.StartCollectBindFuncs() - fn, err := gone.InjectWrapFn(p.cemetery, x) - if err != nil { - panic(err) - } - funcs := p.injector.CollectBindFuncs() + return p.buildProxyFn(x, funcName, last) + } +} - return func(context *gin.Context) { - for _, f := range funcs { - err := f(context) +func (p *proxy) buildProxyFn(x HandlerFunc, funcName string, last bool) gin.HandlerFunc { + m := make(map[int]*BindStructFuncAndType) + args, err := p.cemetery.InjectFuncParameters( + x, + func(pt reflect.Type, i int) any { + switch pt { + case ctxPointType, ctxType, goneContextPointType, goneContextType: + return &placeholder{ + Type: pt, + } + } + p.injector.StartCollectBindFuncs() + return nil + }, + func(pt reflect.Type, i int, obj any) { + m[i] = &BindStructFuncAndType{ + Fn: p.injector.BindFuncs(), + Type: pt, + } + }, + ) + + if err != nil { + panic(err) + } + + fv := reflect.ValueOf(x) + return func(context *gin.Context) { + parameters := make([]reflect.Value, 0, len(args)) + for i, arg := range args { + if holder, ok := arg.(*placeholder); ok { + switch holder.Type { + case ctxPointType: + parameters = append(parameters, reflect.ValueOf(context)) + case ctxType: + parameters = append(parameters, reflect.ValueOf(context).Elem()) + case goneContextPointType: + parameters = append(parameters, reflect.ValueOf(&Context{Context: context})) + case goneContextType: + parameters = append(parameters, reflect.ValueOf(Context{Context: context})) + } + continue + } + + if f, ok := m[i]; ok { + parameter, err := f.Fn(context, arg, f.Type) if err != nil { p.responser.Failed(context, err) return } + parameters = append(parameters, parameter) + continue } + parameters = append(parameters, reflect.ValueOf(arg)) + } - results := gone.ExecuteInjectWrapFn(fn) - p.responser.ProcessResults(context, context.Writer, last, gone.GetFuncName(x), results...) + //call the func x + values := fv.Call(parameters) + + var results []any + for i := 0; i < len(values); i++ { + arg := values[i] + if arg.Type() == ctxPointType && !arg.IsNil() { + results = append(results, nil) + } else { + results = append(results, arg.Interface()) + } } + p.responser.ProcessResults(context, context.Writer, last, funcName, results...) } } diff --git a/goner/gin/proxy2_test.go b/goner/gin/proxy2_test.go deleted file mode 100644 index 4933e96..0000000 --- a/goner/gin/proxy2_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package gin - -import ( - "github.com/gin-gonic/gin" - "github.com/golang/mock/gomock" - "github.com/gone-io/gone" - "github.com/gone-io/gone/goner/config" - "github.com/gone-io/gone/goner/logrus" - "github.com/gone-io/gone/goner/tracer" - "github.com/stretchr/testify/assert" - "net/http" - "net/url" - "testing" -) - -func Test_proxy_proxyOne1(t *testing.T) { - gone. - Prepare(func(cemetery gone.Cemetery) error { - _ = config.Priest(cemetery) - _ = logrus.Priest(cemetery) - _ = tracer.Priest(cemetery) - cemetery.Bury(NewGinProxy()) - cemetery.Bury(NewHttInjector()) - cemetery.Bury(NewGinResponser()) - return nil - }). - Test(func(in struct { - proxy HandleProxyToGin `gone:"*"` - }) { - controller := gomock.NewController(t) - defer controller.Finish() - writer := NewMockResponseWriter(controller) - writer.EXPECT().Written().AnyTimes() - writer.EXPECT().WriteHeader(gomock.Any()).AnyTimes() - writer.EXPECT().Header().Return(http.Header{}).AnyTimes() - writer.EXPECT().Write(gomock.Any()).AnyTimes() - - Url, _ := url.Parse("https://goner.fun/zh/?page=1&pageSize=10&arr=1&arr=2&arr=3") - - context := gin.Context{ - Writer: writer, - Request: &http.Request{ - URL: Url, - }, - } - - t.Run("ctx inject", func(t *testing.T) { - executedCounter := 0 - proxyFn := in.proxy.Proxy(func(in struct { - ctx gin.Context `gone:"http"` - ctxPtr *gin.Context `gone:"http"` - }) { - assert.Equal(t, in.ctxPtr, &context) - assert.NotNil(t, in.ctx.Writer, context.Writer) - executedCounter++ - return - })[0] - proxyFn(&context) - assert.Equal(t, executedCounter, 1) - }) - - t.Run("request inject", func(t *testing.T) { - executedCounter := 0 - proxyFn := in.proxy.Proxy(func(in struct { - req http.Request `gone:"http"` - reqPtr *http.Request `gone:"http"` - }) { - assert.Equal(t, in.req.URL, context.Request.URL) - assert.NotNil(t, in.reqPtr, context.Request) - executedCounter++ - return - })[0] - proxyFn(&context) - assert.Equal(t, executedCounter, 1) - }) - - }) -} diff --git a/goner/gin/proxy_test.go b/goner/gin/proxy_test.go index 77475aa..4933e96 100644 --- a/goner/gin/proxy_test.go +++ b/goner/gin/proxy_test.go @@ -1,123 +1,78 @@ package gin -// -//func Test_proxy_proxyOne(t *testing.T) { -// type fields struct { -// Flag gone.Flag -// Logger gone.Logger -// cemetery gone.Cemetery -// responser Responser -// tracer gone.Tracer -// inject func(logger gone.Logger, cemetery gone.Cemetery, responser Responser, x HandlerFunc, context *gin.Context) (results []any) -// } -// -// controller := gomock.NewController(t) -// defer controller.Finish() -// -// mockResponser := NewMockResponser(controller) -// mockResponser.EXPECT().ProcessResults(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() -// mockResponser.EXPECT().ProcessResults(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() -// mockResponser.EXPECT().ProcessResults(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() -// -// gone.Prepare().AfterStart(func(in struct { -// cemetery gone.Cemetery `gone:"gone-cemetery"` -// }) { -// -// type args struct { -// x HandlerFunc -// last bool -// } -// tests := []struct { -// name string -// fields fields -// args args -// wantFunc func(want gin.HandlerFunc) bool -// }{ -// { -// name: "func(*Context) (any, error)", -// fields: fields{ -// responser: mockResponser, -// }, -// args: args{ -// x: func(ctx *Context) (any, error) { return nil, nil }, -// last: false, -// }, -// wantFunc: func(want gin.HandlerFunc) bool { -// want(&gin.Context{}) -// return true -// }, -// }, -// { -// name: "func(*Context) error", -// fields: fields{ -// responser: mockResponser, -// }, -// args: args{ -// x: func(*Context) error { return nil }, -// last: false, -// }, -// wantFunc: func(want gin.HandlerFunc) bool { -// want(&gin.Context{}) -// return true -// }, -// }, -// { -// name: "func(*Context)", -// fields: fields{ -// responser: mockResponser, -// }, -// args: args{ -// x: func(*Context) {}, -// last: false, -// }, -// wantFunc: func(want gin.HandlerFunc) bool { -// want(&gin.Context{}) -// return true -// }, -// }, -// { -// name: "other", -// fields: fields{ -// responser: mockResponser, -// cemetery: in.cemetery, -// }, -// args: args{ -// x: func(in struct{}) {}, -// last: false, -// }, -// wantFunc: func(want gin.HandlerFunc) bool { -// want(&gin.Context{}) -// return true -// }, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// p := &proxy{ -// Flag: tt.fields.Flag, -// Logger: tt.fields.Logger, -// cemetery: tt.fields.cemetery, -// responser: tt.fields.responser, -// tracer: tt.fields.tracer, -// } -// one := p.proxyOne(tt.args.x, tt.args.last) -// -// assert.Truef(t, tt.wantFunc(one), "proxyOne(%v, %v)", tt.args.x, tt.args.last) -// }) -// } -// }).Run() -//} +import ( + "github.com/gin-gonic/gin" + "github.com/golang/mock/gomock" + "github.com/gone-io/gone" + "github.com/gone-io/gone/goner/config" + "github.com/gone-io/gone/goner/logrus" + "github.com/gone-io/gone/goner/tracer" + "github.com/stretchr/testify/assert" + "net/http" + "net/url" + "testing" +) -//func Test_proxy_ProxyForMiddleware(t *testing.T) { -// ginProxy, _, _ := NewGinProxy() -// p := ginProxy.(HandleProxyToGin) -// funcs := p.ProxyForMiddleware(func(ctx *gin.Context) {}, func() {}) -// assert.Equal(t, 2, len(funcs)) -//} +func Test_proxy_proxyOne1(t *testing.T) { + gone. + Prepare(func(cemetery gone.Cemetery) error { + _ = config.Priest(cemetery) + _ = logrus.Priest(cemetery) + _ = tracer.Priest(cemetery) + cemetery.Bury(NewGinProxy()) + cemetery.Bury(NewHttInjector()) + cemetery.Bury(NewGinResponser()) + return nil + }). + Test(func(in struct { + proxy HandleProxyToGin `gone:"*"` + }) { + controller := gomock.NewController(t) + defer controller.Finish() + writer := NewMockResponseWriter(controller) + writer.EXPECT().Written().AnyTimes() + writer.EXPECT().WriteHeader(gomock.Any()).AnyTimes() + writer.EXPECT().Header().Return(http.Header{}).AnyTimes() + writer.EXPECT().Write(gomock.Any()).AnyTimes() -//func Test_proxy_Proxy(t *testing.T) { -// ginProxy, _, _ := NewGinProxy() -// p := ginProxy.(HandleProxyToGin) -// funcs := p.Proxy(func(ctx *gin.Context) {}, func() {}) -// assert.Equal(t, 2, len(funcs)) -//} + Url, _ := url.Parse("https://goner.fun/zh/?page=1&pageSize=10&arr=1&arr=2&arr=3") + + context := gin.Context{ + Writer: writer, + Request: &http.Request{ + URL: Url, + }, + } + + t.Run("ctx inject", func(t *testing.T) { + executedCounter := 0 + proxyFn := in.proxy.Proxy(func(in struct { + ctx gin.Context `gone:"http"` + ctxPtr *gin.Context `gone:"http"` + }) { + assert.Equal(t, in.ctxPtr, &context) + assert.NotNil(t, in.ctx.Writer, context.Writer) + executedCounter++ + return + })[0] + proxyFn(&context) + assert.Equal(t, executedCounter, 1) + }) + + t.Run("request inject", func(t *testing.T) { + executedCounter := 0 + proxyFn := in.proxy.Proxy(func(in struct { + req http.Request `gone:"http"` + reqPtr *http.Request `gone:"http"` + }) { + assert.Equal(t, in.req.URL, context.Request.URL) + assert.NotNil(t, in.reqPtr, context.Request) + executedCounter++ + return + })[0] + proxyFn(&context) + assert.Equal(t, executedCounter, 1) + }) + + }) +} diff --git a/help.go b/help.go index 1795940..9e0cc52 100644 --- a/help.go +++ b/help.go @@ -6,6 +6,7 @@ import ( "reflect" "runtime" "strings" + "time" ) // PanicTrace 用于获取调用者的堆栈信息 @@ -39,15 +40,18 @@ func PanicTrace(kb int, skip int) []byte { // GetFuncName 获取某个函数的名字 func GetFuncName(f any) string { - return strings.Trim(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "-fm") + return strings.TrimSuffix(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "-fm") } // GetInterfaceType 获取接口的类型 func GetInterfaceType[T any](t *T) reflect.Type { return reflect.TypeOf(t).Elem() } - func InjectWrapFn(cemetery Cemetery, fn any) (*reflect.Value, error) { + return InjectWrapFnWithHook(cemetery, fn, nil, nil) +} + +func InjectWrapFnWithHook(cemetery Cemetery, fn any, before func([]reflect.Value), after func([]reflect.Value)) (*reflect.Value, error) { ft := reflect.TypeOf(fn) fv := reflect.ValueOf(fn) if ft.Kind() != reflect.Func { @@ -88,7 +92,14 @@ func InjectWrapFn(cemetery Cemetery, fn any) (*reflect.Value, error) { } makeFunc := reflect.MakeFunc(reflect.FuncOf(nil, outList, false), func([]reflect.Value) (results []reflect.Value) { - return fv.Call(args) + if before != nil { + before(args) + } + results = fv.Call(args) + if after != nil { + after(results) + } + return }) return &makeFunc, nil } @@ -150,3 +161,10 @@ func (c *cemetery) setFieldValue(v reflect.Value, ref any) { } return } + +func TimeStat(name string) func() { + start := time.Now() + return func() { + fmt.Printf("%s use %v\n", name, time.Since(start)) + } +} diff --git a/interface.go b/interface.go index dae0f50..e18f779 100644 --- a/interface.go +++ b/interface.go @@ -75,6 +75,8 @@ type Cemetery interface { //GetTomByType return the Tombs by the GonerType GetTomByType(reflect.Type) []Tomb + + InjectFuncParameters(fn any, injectBefore func(pt reflect.Type, i int) any, injectAfter func(pt reflect.Type, i int, obj any)) (args []any, err error) } // Priest A function which has A Cemetery parameter, and return an error. use for Burying Goner