diff --git a/.github/actions/setup-go/action.yml b/.github/actions/setup-go/action.yml index e01ef74..03ce258 100644 --- a/.github/actions/setup-go/action.yml +++ b/.github/actions/setup-go/action.yml @@ -5,7 +5,7 @@ description: | inputs: go-version: description: Used Go version - default: '1.19' + default: '1.20' runs: using: "composite" diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 815fd76..79c4266 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -23,4 +23,4 @@ jobs: - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 with: - version: v1.50.1 + version: v1.51.1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e1a415..13abe67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ How to release a new version: ## [Unreleased] +## [0.6.0] - 2023-03-03 +### Added +- package `http/signature` to simplify defining http handler functions +- package `http/param` to simplify parsing http path and query parameters + ## [0.5.0] - 2022-01-20 ### Added - `ErrorResponseOptions` contains public error message. @@ -40,7 +45,8 @@ How to release a new version: ### Added - Added Changelog. -[Unreleased]: https://github.com/strvcom/strv-backend-go-net/compare/v0.5.0...HEAD +[Unreleased]: https://github.com/strvcom/strv-backend-go-net/compare/v0.6.0...HEAD +[0.6.0]: https://github.com/strvcom/strv-backend-go-net/compare/v0.5.0...v0.6.0 [0.5.0]: https://github.com/strvcom/strv-backend-go-net/compare/v0.4.0...v0.5.0 [0.4.0]: https://github.com/strvcom/strv-backend-go-net/compare/v0.3.0...v0.4.0 [0.3.0]: https://github.com/strvcom/strv-backend-go-net/compare/v0.2.0...v0.3.0 diff --git a/go.mod b/go.mod index 580f56c..b978113 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ module go.strv.io/net -go 1.19 +go 1.20 require ( + github.com/go-chi/chi/v5 v5.0.8 github.com/google/uuid v1.3.0 github.com/stretchr/testify v1.8.0 go.strv.io/time v0.2.0 diff --git a/go.sum b/go.sum index fe40e00..db22115 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= +github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= diff --git a/http/encode.go b/http/encode.go index 5610160..161ef06 100644 --- a/http/encode.go +++ b/http/encode.go @@ -20,7 +20,7 @@ func WithEncodeFunc(fn EncodeFunc) ResponseOption { } } -// DecodeJSON decodes data using JSON marshalling into the type of parameter v. +// DecodeJSON decodes data using JSON marshaling into the type of parameter v. func DecodeJSON(data any, v any) error { b, err := json.Marshal(data) if err != nil { diff --git a/http/param/README.md b/http/param/README.md new file mode 100644 index 0000000..1044f72 --- /dev/null +++ b/http/param/README.md @@ -0,0 +1,16 @@ +Package for parsing path and query parameters from http request into struct, similar to parsing body as json to struct. + +``` +type MyInputStruct struct { + UserID int `param:"path=id"` + SomeFlag *bool `param:"query=flag"` +} +``` + +Then a request like `http://somewhere.com/users/9?flag=true` can be parsed as follows. +In this example, using chi to access path parameters that has a `{id}` wildcard in configured chi router + +``` + parsedInput := MyInputStruct{} + param.DefaultParser().PathParamFunc(chi.URLParam).Parse(request, &parsedInput) +``` diff --git a/http/param/param.go b/http/param/param.go new file mode 100644 index 0000000..7d162cd --- /dev/null +++ b/http/param/param.go @@ -0,0 +1,234 @@ +package param + +import ( + "encoding" + "fmt" + "net/http" + "reflect" + "strconv" + "strings" +) + +// TagResolver is a function that decides from a field type what key of http parameter should be searched. +// Second return value should return whether the key should be searched in http parameter at all. +type TagResolver func(fieldTag reflect.StructTag) (string, bool) + +// FixedTagNameParamTagResolver returns a TagResolver, that matches struct params by specific tag. +// Example: FixedTagNameParamTagResolver("mytag") matches a field tagged with `mytag:"param_name"` +func FixedTagNameParamTagResolver(tagName string) TagResolver { + return func(fieldTag reflect.StructTag) (string, bool) { + taggedParamName := fieldTag.Get(tagName) + return taggedParamName, taggedParamName != "" + } +} + +// TagWithModifierTagResolver returns a TagResolver, that matches struct params by specific tag and +// by a value before a '=' separator. +// Example: FixedTagNameParamTagResolver("mytag", "mymodifier") matches a field tagged with `mytag:"mymodifier=param_name"` +func TagWithModifierTagResolver(tagName string, tagModifier string) TagResolver { + return func(fieldTag reflect.StructTag) (string, bool) { + tagValue := fieldTag.Get(tagName) + if tagValue == "" { + return "", false + } + splits := strings.Split(tagValue, "=") + //nolint:gomnd // 2 not really that magic number - one value before '=', one after + if len(splits) != 2 { + return "", false + } + if splits[0] == tagModifier { + return splits[1], true + } + return "", false + } +} + +// PathParamFunc is a function that returns value of specified http path parameter +type PathParamFunc func(r *http.Request, key string) string + +// Parser can Parse query and path parameters from http.Request into a struct. +// Fields struct have to be tagged such that either QueryParamTagResolver or PathParamTagResolver returns +// valid parameter name from the provided tag. +// +// PathParamFunc is for getting path parameter from http.Request, as each http router handles it in different way (if at all). +// For example for chi, use WithPathParamFunc(chi.URLParam) to be able to use tags for path parameters. +type Parser struct { + QueryParamTagResolver TagResolver + PathParamTagResolver TagResolver + PathParamFunc PathParamFunc +} + +// DefaultParser returns query and path parameter Parser with intended struct tags +// `param:"query=param_name"` for query parameters and `param:"path=param_name"` for path parameters +func DefaultParser() Parser { + return Parser{ + QueryParamTagResolver: TagWithModifierTagResolver("param", "query"), + PathParamTagResolver: TagWithModifierTagResolver("param", "path"), + PathParamFunc: nil, // keep nil, as there is no sensible default of how to get value of path parameter + } +} + +// WithPathParamFunc returns a copy of Parser with set function for getting path parameters from http.Request. +// For more see Parser description. +func (p Parser) WithPathParamFunc(f PathParamFunc) Parser { + p.PathParamFunc = f + return p +} + +// Parse accepts the request and a pointer to struct that is tagged with appropriate tags set in Parser. +// All such tagged fields are assigned the respective parameter from the actual request. +// +// Fields are assigned their zero value if the field was tagged but request did not contain such parameter. +// +// Supported tagged field types are: +// - primitive types - bool, all ints, all uints, both floats, and string +// - pointer to any supported type +// - slice of non-slice supported type (only for query parameters) +// - any type that implements encoding.TextUnmarshaler +// +// For query parameters, the tagged type can be a slice. This means that a query like /endpoint?key=val1&key=val2 +// is allowed, and in such case the slice field will be assigned []T{"val1", "val2"} . +// Otherwise, only single query parameter is allowed in request. +func (p Parser) Parse(r *http.Request, dest any) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Pointer { + return fmt.Errorf("cannot set non-pointer value of type %s", v.Type().Name()) + } + v = v.Elem() + + if v.Kind() != reflect.Struct { + return fmt.Errorf("can only parse into struct, but got %s", v.Type().Name()) + } + + for i := 0; i < v.NumField(); i++ { + typeField := v.Type().Field(i) + if !typeField.IsExported() { + continue + } + valueField := v.Field(i) + // Zero the value, even if it would not be set by following path or query parameter. + // This will cause potential partial result from previous parser (e.g. json.Unmarshal) to be discarded on + // fields that are tagged for path or query parameter. + valueField.Set(reflect.Zero(typeField.Type)) + tag := typeField.Tag + err := p.parseQueryParam(r, tag, valueField) + if err != nil { + return err + } + err = p.parsePathParam(r, tag, valueField) + if err != nil { + return err + } + } + return nil +} + +func (p Parser) parsePathParam(r *http.Request, tag reflect.StructTag, v reflect.Value) error { + paramName, ok := p.PathParamTagResolver(tag) + if !ok { + return nil + } + if p.PathParamFunc == nil { + return fmt.Errorf("struct's field was tagged for parsing the path parameter (%s) but PathParamFunc to get value of path parameter is not defined", paramName) + } + paramValue := p.PathParamFunc(r, paramName) + if paramValue != "" { + err := unmarshalValue(paramValue, v) + if err != nil { + return fmt.Errorf("unmarshaling path parameter %s: %w", paramName, err) + } + } + return nil +} + +func (p Parser) parseQueryParam(r *http.Request, tag reflect.StructTag, v reflect.Value) error { + paramName, ok := p.QueryParamTagResolver(tag) + if !ok { + return nil + } + query := r.URL.Query() + if values, ok := query[paramName]; ok && len(values) > 0 { + err := unmarshalValueOrSlice(values, v) + if err != nil { + return fmt.Errorf("unmarshaling query parameter %s: %w", paramName, err) + } + } + return nil +} + +func unmarshalValueOrSlice(texts []string, dest reflect.Value) error { + if unmarshaler, ok := dest.Addr().Interface().(encoding.TextUnmarshaler); ok { + if len(texts) != 1 { + return fmt.Errorf("too many parameters unmarshaling to %s, expected up to 1 value", dest.Type().Name()) + } + return unmarshaler.UnmarshalText([]byte(texts[0])) + } + t := dest.Type() + if t.Kind() == reflect.Pointer { + ptrValue := reflect.New(t.Elem()) + dest.Set(ptrValue) + return unmarshalValueOrSlice(texts, dest.Elem()) + } + if t.Kind() == reflect.Slice { + sliceValue := reflect.MakeSlice(t, len(texts), len(texts)) + for i, text := range texts { + if err := unmarshalValue(text, sliceValue.Index(i)); err != nil { + return fmt.Errorf("unmarshaling %dth element: %w", i, err) + } + } + dest.Set(sliceValue) + return nil + } + if len(texts) != 1 { + return fmt.Errorf("too many parameters unmarshaling to %s, expected up to 1 value", dest.Type().Name()) + } + return unmarshalPrimitiveValue(texts[0], dest) +} + +func unmarshalValue(text string, dest reflect.Value) error { + if unmarshaler, ok := dest.Addr().Interface().(encoding.TextUnmarshaler); ok { + return unmarshaler.UnmarshalText([]byte(text)) + } + t := dest.Type() + if t.Kind() == reflect.Pointer { + ptrValue := reflect.New(t.Elem()) + dest.Set(ptrValue) + return unmarshalValue(text, dest.Elem()) + } + return unmarshalPrimitiveValue(text, dest) +} + +func unmarshalPrimitiveValue(text string, dest reflect.Value) error { + //nolint:exhaustive + switch dest.Kind() { + case reflect.Bool: + v, err := strconv.ParseBool(text) + if err != nil { + return fmt.Errorf("parsing into field of type %s: %w", dest.Type().Name(), err) + } + dest.SetBool(v) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v, err := strconv.ParseInt(text, 10, dest.Type().Bits()) + if err != nil { + return fmt.Errorf("parsing into field of type %s: %w", dest.Type().Name(), err) + } + dest.SetInt(v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v, err := strconv.ParseUint(text, 10, dest.Type().Bits()) + if err != nil { + return fmt.Errorf("parsing into field of type %s: %w", dest.Type().Name(), err) + } + dest.SetUint(v) + case reflect.Float32, reflect.Float64: + v, err := strconv.ParseFloat(text, dest.Type().Bits()) + if err != nil { + return fmt.Errorf("parsing into field of type %s: %w", dest.Type().Name(), err) + } + dest.SetFloat(v) + case reflect.String: + dest.SetString(text) + default: + return fmt.Errorf("unsupported field type %s", dest.Type().Name()) + } + return nil +} diff --git a/http/param/param_test.go b/http/param/param_test.go new file mode 100644 index 0000000..89c3a47 --- /dev/null +++ b/http/param/param_test.go @@ -0,0 +1,546 @@ +package param + +import ( + "net/http" + "net/http/httptest" + "reflect" + "strconv" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type myString string + +type myComplicatedType struct { + Value string +} + +func (m *myComplicatedType) UnmarshalText(text []byte) error { + // differ from simple assignment to underlying (string) type to be sure this was called + m.Value = "my" + string(text) + return nil +} + +type structWithSlice struct { + SlicePrimitiveField []string `param:"query=a"` + SliceCustomField []myString `param:"query=b"` + SliceCustomUnmarshalerField []myComplicatedType `param:"query=c"` + OtherField string `param:"query=d"` +} + +func TestParser_Parse_QueryParam_Slice(t *testing.T) { + testCases := []struct { + name string + query string + expected structWithSlice + }{ + { + name: "multiple items", + query: "https://test.com/hello?a=vala1&a=vala2&b=valb1&b=valb2&c=valc1&c=valc2&d=vald", + expected: structWithSlice{ + SlicePrimitiveField: []string{"vala1", "vala2"}, + SliceCustomField: []myString{"valb1", "valb2"}, + SliceCustomUnmarshalerField: []myComplicatedType{{"myvalc1"}, {"myvalc2"}}, + OtherField: "vald", + }, + }, + { + name: "single item", + query: "https://test.com/hello?a=vala1&b=valb1&c=valc1&d=vald", + expected: structWithSlice{ + SlicePrimitiveField: []string{"vala1"}, + SliceCustomField: []myString{"valb1"}, + SliceCustomUnmarshalerField: []myComplicatedType{{"myvalc1"}}, + OtherField: "vald", + }, + }, + { + name: "no items", + query: "https://test.com/hello?something_else=hmm", + expected: structWithSlice{ + SlicePrimitiveField: nil, + SliceCustomField: nil, + SliceCustomUnmarshalerField: nil, + OtherField: "", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parser := DefaultParser() + result := structWithSlice{ + SlicePrimitiveField: []string{"existing data should be overwritten in all cases"}, + OtherField: "in all tagged fields", + } + req := httptest.NewRequest(http.MethodGet, tc.query, nil) + err := parser.Parse(req, &result) + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } +} + +type structWithPrimitiveTypes struct { + Bool bool `param:"query=b"` + Int int `param:"query=i0"` + Int8 int8 `param:"query=i1"` + Int16 int16 `param:"query=i2"` + Int32 int32 `param:"query=i3"` + Int64 int64 `param:"query=i4"` + Uint uint `param:"query=u0"` + Uint8 uint8 `param:"query=u1"` + Uint16 uint16 `param:"query=u2"` + Uint32 uint32 `param:"query=u3"` + Uint64 uint64 `param:"query=u4"` + Float32 float32 `param:"query=f1"` + Float64 float64 `param:"query=f2"` + String string `param:"query=s"` + // nolint:unused + ignoredUnexported string `param:"query=ignored"` +} + +func TestParser_Parse_QueryParam_PrimitiveTypes(t *testing.T) { + query := "https://test.com/hello?b=true&i0=-32768&i1=-127&i2=-32768&i3=-2147483648&i4=-9223372036854775808&u0=65535&u1=255&u2=65535&u3=4294967295&u4=18446744073709551615&f1=3e38&f2=1e308&s=hello%20world%5C\"&ignored=hello" + expected := structWithPrimitiveTypes{ + Bool: true, + // chosen edge of range numbers most that are most likely to cause problems + Int: -32768, // assumes it's at least 16 bits :) + Int8: -127, + Int16: -32768, + Int32: -2147483648, + Int64: -9223372036854775808, + Uint: 65535, + Uint8: 255, + Uint16: 65535, + Uint32: 4294967295, + Uint64: 18446744073709551615, + Float32: 3e38, + Float64: 1e308, + String: "hello world\\\"", + } + + parser := DefaultParser() + result := structWithPrimitiveTypes{} + req := httptest.NewRequest(http.MethodGet, query, nil) + err := parser.Parse(req, &result) + assert.NoError(t, err) + assert.Equal(t, expected, result) +} + +type structWithPointers struct { + BoolPtr *bool `param:"query=b"` + IntPtr *int `param:"query=i"` + StrPtr *string `param:"query=s"` + Str2Ptr **string `param:"query=sp"` + UnmarshalerPtr *myComplicatedType `param:"query=c"` +} + +func TestParser_Parse_QueryParam_Pointers(t *testing.T) { + testCases := []struct { + name string + query string + expected structWithPointers + }{ + { + name: "filled", + query: "https://test.com/hello?b=true&i=42&s=somestring&sp=pointers&c=wow", + expected: structWithPointers{ + BoolPtr: ptr(true), + IntPtr: ptr(42), + StrPtr: ptr("somestring"), + Str2Ptr: ptr(ptr("pointers")), + UnmarshalerPtr: &myComplicatedType{"mywow"}, + }, + }, + { + name: "no params", + query: "https://test.com/hello", + expected: structWithPointers{ + BoolPtr: nil, + IntPtr: nil, + StrPtr: nil, + Str2Ptr: nil, + UnmarshalerPtr: nil, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parser := DefaultParser() + result := structWithPointers{} + req := httptest.NewRequest(http.MethodGet, tc.query, nil) + err := parser.Parse(req, &result) + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } +} + +type valueReceiverUnmarshaler struct{} + +var valueReceiverResult string + +func (s valueReceiverUnmarshaler) UnmarshalText(bytes []byte) error { + valueReceiverResult = string(bytes) + return nil +} + +type StructWithValueReceiverUnmarshal struct { + Data valueReceiverUnmarshaler `param:"query=s"` +} + +func TestParser_Parse_QueryParam_ValueReceiverUnmarshaler(t *testing.T) { + query := "https://test.com/hello?s=changed" + valueReceiverResult = "orig" + parser := DefaultParser() + theStruct := StructWithValueReceiverUnmarshal{ + valueReceiverUnmarshaler{}, + } + req := httptest.NewRequest(http.MethodGet, query, nil) + err := parser.Parse(req, &theStruct) + assert.NoError(t, err) + assert.Equal(t, "changed", valueReceiverResult) +} + +func TestParser_Parse_QueryParam_MultipleToNonSlice(t *testing.T) { + testCases := []struct { + name string + query string + resultStruct any + }{ + { + name: "primitive type", + query: "https://test.com/hello?b=true&b=true", + resultStruct: &structWithPrimitiveTypes{}, + }, + { + name: "text unmarshaler", + query: "https://test.com/hello?c=yes&c=no", + resultStruct: &structWithPointers{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parser := DefaultParser() + req := httptest.NewRequest(http.MethodGet, tc.query, nil) + err := parser.Parse(req, tc.resultStruct) + assert.Error(t, err) + }) + } +} + +func TestParser_Parse_QueryParam_InvalidType(t *testing.T) { + var str string + testCases := []struct { + name string + query string + resultStruct any + }{ + { + name: "not a pointer", + query: "https://test.com/hello?b=true", + resultStruct: structWithPrimitiveTypes{}, + }, + { + name: "pointer to not struct", + query: "https://test.com/hello", + resultStruct: &str, + }, + { + name: "map", + query: "https://test.com/hello?map=something", + resultStruct: &struct { + Map map[string]any `param:"query=map"` + }{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parser := DefaultParser() + req := httptest.NewRequest(http.MethodGet, tc.query, nil) + err := parser.Parse(req, tc.resultStruct) + assert.Error(t, err) + }) + } +} + +func TestParser_Parse_QueryParam_CannotBeParsed(t *testing.T) { + testCases := []struct { + name string + query string + resultStruct any + errorTarget error + }{ + { + name: "invalid bool", + query: "https://test.com/hello?b=frue", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid int", + query: "https://test.com/hello?i0=18446744073709551615", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid int8", + query: "https://test.com/hello?i1=128", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid int16", + query: "https://test.com/hello?i2=32768", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid int32", + query: "https://test.com/hello?i3=2147483648", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid int64", + query: "https://test.com/hello?i4=18446744073709551615", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid uint", + query: "https://test.com/hello?u0=-1", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid uint8", + query: "https://test.com/hello?u1=-1", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid uint16", + query: "https://test.com/hello?u2=-1", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid uint32", + query: "https://test.com/hello?u3=-1", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid uint64", + query: "https://test.com/hello?u4=-1", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrSyntax, + }, + { + name: "invalid float32", + query: "https://test.com/hello?f1=4e38", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid float64", + query: "https://test.com/hello?f2=1e309", + resultStruct: &structWithPrimitiveTypes{}, + errorTarget: strconv.ErrRange, + }, + { + name: "invalid int8 in slice", + query: "https://test.com/hello?x=127&x=128", + resultStruct: &struct { + Slice []int8 `param:"query=x"` + }{}, + errorTarget: strconv.ErrRange, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parser := DefaultParser() + req := httptest.NewRequest(http.MethodGet, tc.query, nil) + err := parser.Parse(req, tc.resultStruct) + assert.ErrorIs(t, err, tc.errorTarget) + }) + } +} + +type maybeShinyObject struct { + IsShiny bool + Object string +} + +func (m *maybeShinyObject) UnmarshalText(text []byte) error { + if strings.HasPrefix(string(text), "shiny-") { + m.IsShiny = true + m.Object = string(text[6:]) + return nil + } + m.Object = string(text) + return nil +} + +type structWithPathParams struct { + Subject string `param:"path=subject"` + Amount *int `param:"path=amount"` + Object *maybeShinyObject `param:"path=object"` + Nothing string `param:"path=nothing"` +} + +func TestParser_Parse_PathParam(t *testing.T) { + r := chi.NewRouter() + p := DefaultParser().WithPathParamFunc(chi.URLParam) + result := structWithPathParams{Nothing: "should be replaced"} + expected := structWithPathParams{ + Subject: "world", + Amount: ptr(69), + Object: &maybeShinyObject{ + IsShiny: true, + Object: "apples", + }, + Nothing: "", + } + var parseError error + r.Get("/hello/{subject}/i/have/{amount}/{object}", func(w http.ResponseWriter, r *http.Request) { + parseError = p.Parse(r, &result) + }) + + req := httptest.NewRequest(http.MethodGet, "https://test.com/hello/world/i/have/69/shiny-apples", nil) + r.ServeHTTP(httptest.NewRecorder(), req) + + assert.NoError(t, parseError) + assert.Equal(t, expected, result) +} + +type simpleStringPathParamStruct struct { + Param int `param:"path=param"` +} + +func TestParser_Parse_PathParam_ParseError(t *testing.T) { + r := chi.NewRouter() + p := DefaultParser().WithPathParamFunc(chi.URLParam) + var parseError error + r.Get("/hello/{param}", func(w http.ResponseWriter, r *http.Request) { + parseError = p.Parse(r, &simpleStringPathParamStruct{}) + }) + + req := httptest.NewRequest(http.MethodGet, "https://test.com/hello/not-a-number", nil) + r.ServeHTTP(httptest.NewRecorder(), req) + + assert.Error(t, parseError) +} + +func TestParser_Parse_PathParam_FuncNotDefinedError(t *testing.T) { + p := DefaultParser() + req := httptest.NewRequest(http.MethodGet, "https://test.com/hello/not-a-number", nil) + + err := p.Parse(req, &simpleStringPathParamStruct{}) + + assert.Error(t, err) +} + +type variousTagsStruct struct { + A string `key:"location=val"` + B string `key:"location=val=excessive"` + C string `key:"no-equal-sign"` + D string `another:"location=val"` + E string `key:"another=val"` +} + +func TestTagWithModifierTagResolver(t *testing.T) { + const correctKey = "key" + const correctLocation = "location" + + testCases := []struct { + fieldName string + expectedParam string + expectedOk bool + }{ + { + fieldName: "A", + expectedParam: "val", + expectedOk: true, + }, + { + fieldName: "B", + expectedParam: "", + expectedOk: false, + }, + { + fieldName: "C", + expectedParam: "", + expectedOk: false, + }, + { + fieldName: "D", + expectedParam: "", + expectedOk: false, + }, + { + fieldName: "E", + expectedParam: "", + expectedOk: false, + }, + } + for _, tc := range testCases { + t.Run(tc.fieldName, func(t *testing.T) { + tagResolver := TagWithModifierTagResolver(correctKey, correctLocation) + structField, found := reflect.TypeOf(variousTagsStruct{}).FieldByName(tc.fieldName) + require.True(t, found) + + paramName, ok := tagResolver(structField.Tag) + + assert.Equal(t, tc.expectedParam, paramName) + assert.Equal(t, tc.expectedOk, ok) + }) + } +} + +func TestFixedTagNameParamTagResolver(t *testing.T) { + const correctKey = "key" + + testCases := []struct { + fieldName string + expectedParam string + expectedOk bool + }{ + { + fieldName: "A", + expectedParam: "location=val", + expectedOk: true, + }, + { + fieldName: "D", + expectedParam: "", + expectedOk: false, + }, + } + for _, tc := range testCases { + t.Run(tc.fieldName, func(t *testing.T) { + tagResolver := FixedTagNameParamTagResolver(correctKey) + structField, found := reflect.TypeOf(variousTagsStruct{}).FieldByName(tc.fieldName) + require.True(t, found) + + paramName, ok := tagResolver(structField.Tag) + + assert.Equal(t, tc.expectedParam, paramName) + assert.Equal(t, tc.expectedOk, ok) + }) + } +} + +func ptr[T any](x T) *T { + return &x +} diff --git a/http/signature/README.md b/http/signature/README.md new file mode 100644 index 0000000..6db3d21 --- /dev/null +++ b/http/signature/README.md @@ -0,0 +1,51 @@ +This package is intended to reduce duplication of common steps in http handlers implemented as `http.HandlerFunc`. +Those common steps are parsing an input from request, unmarshaling the result into `http.ResponseWriter` +and handling errors that can occur in any of those steps or inside the handler logic itself. + +It does this by allowing handlers to be defined with a new function signature, that can include an input type +as parameter, a response type in return values, and always has `error` as last return value. + +The handlers with enhanced signature can be than wrapped using function like `signature.WrapHandler` so it can be +used as a `http.HandlerFunc` type + +Example: +``` +main() { + w := signature.DefaultWrapper() + + r := chi.NewRouter() + r.Get("/endpoint1", signature.WrapHandler(w, handleEndpoint1)) +} + +func handleEndpoint1(w http.ResponseWriter, r *http.Request, input MyInputStruct) (MyResponseStruct, error){ + // access the input variable of type MyInputStruct, do handler logic, return MyResponseStruct or error + return theLogic(input) +} +``` + +Instead of using the repetetive http.HandlerFunc: +``` +main() { + r := chi.NewRouter() + r.Get("/endpoint1", handleEndpoint1) +} + +func handleEndpoint1(w http.ResponseWriter, r *http.Request) { + if err := parseIntoMyStruct(&MyInputStruct{}); err != nil { + writeError(w, err) + return + } + + // in this case, the actual logic is still just one line + result, err := theLogic(input) + if err != nil { + writeError(w, err) + return + } + + if err := writeResult(result); err != nil { + writeError(w, err) + return + } +} +``` diff --git a/http/signature/signature.go b/http/signature/signature.go new file mode 100644 index 0000000..e31cd9e --- /dev/null +++ b/http/signature/signature.go @@ -0,0 +1,236 @@ +package signature + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + httpx "go.strv.io/net/http" +) + +var ( + // ErrInputGet is passed to ErrorHandlerFunc when WrapHandler (or derived) fails in the first step (parsing input) + ErrInputGet = errors.New("parsing input") + // ErrInnerHandler is passed to ErrorHandlerFunc when WrapHandler (or derived) fails in the second step (inner handler) + ErrInnerHandler = errors.New("inner handler") + // ErrResponseMarshal is passed to ErrorHandlerFunc when WrapHandler (or derived) fails in the third step (marshaling response object) + ErrResponseMarshal = errors.New("marshaling response") +) + +// InputGetterFunc is a function that is used in WrapHandler and WrapHandlerInput to parse request into declared input type. +// Before calling the inner handler, the InputGetterFunc is called to fill the struct that is then passed to the inner handler. +// If inner handler does not declare an input type (i.e. WrapHandlerResponse and WrapHandlerError), this function is not called at all. +type InputGetterFunc func(r *http.Request, dest any) error + +// ResponseMarshalerFunc is a function that is used in WrapHandler and related functions to marshal declared response type. +// After the inner handler succeeds, the ResponseMarshalerFunc receives http.ResponseWriter and http.Request of handled request, +// and a type that an inner handler function declared as its first (non-error) return value. +// If the inner handler does not declare such return value (i.e. for WrapHandlerInput and WrapHandlerError), +// the ResponseMarshalerFunc receives http.NoBody as the src parameter. +type ResponseMarshalerFunc func(w http.ResponseWriter, r *http.Request, src any) error + +// ErrorHandlerFunc is a function that is used in WrapHandler and related functions if any of the steps fail. +// The passed err is wrapped in one of ErrInputGet, ErrInnerHandler or ErrResponseMarshal to distinguish the +// step that failed. +// +// Note that if the error occurs on unmarshaling response with still valid http.ResponseWriter, +// and that step already wrote into the writer, the unmarshaled response (including e.g. http headers) +// may be inconsistent if error handler also writes. +type ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) + +// Wrapper needs to be passed to WrapHandler and related functions. It contains the common handling of parsing http.Request +// to needed type, marshaling the response of needed type, and handling the errors that occur in any of those steps or +// in the inner handler (with modified signature) +type Wrapper struct { + inputGetter InputGetterFunc + responseMarshaler ResponseMarshalerFunc + errorHandler ErrorHandlerFunc +} + +// DefaultWrapper Creates a Wrapper with default functions for each needed step. +// +// Input is parsed only from http.Request body, using JSON unmarshal. +// A custom InputGetterFunc is needed to parse also the query and path parameters, but param package can be used to do most. +// +// Response is marshaled using a WriteResponse wrapper in parent package, which uses JSON marshal. +// +// Error handler also uses a WriteErrorResponse of parent package. +// It is recommended to replace this to implement any custom error handling (matching any application errors). +// Default handler only returns http code 400 on unmarshal error and 500 otherwise. +func DefaultWrapper() Wrapper { + return Wrapper{ + inputGetter: UnmarshalRequestBody, + responseMarshaler: DefaultResponseMarshal, + errorHandler: InputGetErrorHandle, + } +} + +// WithInputGetter returns a copy of Wrapper with new InputGetterFunc +func (w Wrapper) WithInputGetter(f InputGetterFunc) Wrapper { + w.inputGetter = f + return w +} + +// WithResponseMarshaler returns a copy of Wrapper with new ResponseMarshalerFunc +func (w Wrapper) WithResponseMarshaler(f ResponseMarshalerFunc) Wrapper { + w.responseMarshaler = f + return w +} + +// WithErrorHandler returns a copy of Wrapper with new ErrorHandlerFunc +func (w Wrapper) WithErrorHandler(f ErrorHandlerFunc) Wrapper { + w.errorHandler = f + return w +} + +func inputErrorWithType(target any, innerError error) error { + return fmt.Errorf("%w into type %T: %w", ErrInputGet, target, innerError) +} + +func responseErrorWithType(src any, innerError error) error { + if src == nil { + return fmt.Errorf("%w without response object: %w", ErrResponseMarshal, innerError) + } + return fmt.Errorf("%w from type %T: %w", ErrResponseMarshal, src, innerError) +} + +func wrapInnerHandlerError(innerError error) error { + return fmt.Errorf("%w: %w", ErrInnerHandler, innerError) +} + +// WrapHandler enables a handler with signature of second parameter to be used as a http.HandlerFunc. +// 1. Before calling such inner handler, the http.request is used +// to get the input parameter of type TInput for the handler, using InputGetterFunc in Wrapper. +// 2. Then the inner handler is called with such created TInput. +// 3. If the handler succeeds (returns nil error), The first return value +// (of type TResponse) is passed to ResponseMarshalerFunc of Wrapper. +// If any of the above steps returns error, the ErrorHandlerFunc is called with that error. +func WrapHandler[TInput any, TResponse any](wrapper Wrapper, handler func(http.ResponseWriter, *http.Request, TInput) (TResponse, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var input TInput + err := wrapper.inputGetter(r, &input) + if err != nil { + wrapper.errorHandler(w, r, inputErrorWithType(input, err)) + return + } + response, err := handler(w, r, input) + if err != nil { + wrapper.errorHandler(w, r, wrapInnerHandlerError(err)) + return + } + err = wrapper.responseMarshaler(w, r, response) + if err != nil { + wrapper.errorHandler(w, r, responseErrorWithType(response, err)) + return + } + } +} + +// WrapHandlerResponse enables a handler with signature of second parameter to be used as a http.HandlerFunc. +// See WrapHandler for general idea. +// Compared to WrapHandler, the first step is skipped (no parsed input for inner handler is provided) +func WrapHandlerResponse[TResponse any](wrapper Wrapper, handler func(http.ResponseWriter, *http.Request) (TResponse, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response, err := handler(w, r) + if err != nil { + wrapper.errorHandler(w, r, wrapInnerHandlerError(err)) + return + } + err = wrapper.responseMarshaler(w, r, response) + if err != nil { + wrapper.errorHandler(w, r, responseErrorWithType(response, err)) + return + } + } +} + +// WrapHandlerInput enables a handler with signature of second parameter to be used as a http.HandlerFunc. +// See WrapHandler for general idea. +// Compared to WrapHandler, in the last step, the ResponseMarshalerFunc receives http.NoBody as a response object +// (and as such, the ResponseMarshalerFunc should handle the http.NoBody value gracefully) +func WrapHandlerInput[TInput any](wrapper Wrapper, handler func(http.ResponseWriter, *http.Request, TInput) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var input TInput + err := wrapper.inputGetter(r, &input) + if err != nil { + wrapper.errorHandler(w, r, inputErrorWithType(input, err)) + return + } + err = handler(w, r, input) + if err != nil { + wrapper.errorHandler(w, r, wrapInnerHandlerError(err)) + return + } + err = wrapper.responseMarshaler(w, r, http.NoBody) + if err != nil { + wrapper.errorHandler(w, r, responseErrorWithType(nil, err)) + return + } + } +} + +// WrapHandlerError enables a handler with signature of second parameter to be used as a http.HandlerFunc. +// See WrapHandler for general idea. +// Compared to WrapHandler, the first step is skipped (no parsed input for inner handler is provided), +// and in the last step, the ResponseMarshalerFunc receives http.NoBody as a response object +// (and as such, the ResponseMarshalerFunc should handle the http.NoBody value gracefully) +func WrapHandlerError(wrapper Wrapper, handler func(http.ResponseWriter, *http.Request) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + err := handler(w, r) + if err != nil { + wrapper.errorHandler(w, r, wrapInnerHandlerError(err)) + return + } + err = wrapper.responseMarshaler(w, r, http.NoBody) + if err != nil { + wrapper.errorHandler(w, r, responseErrorWithType(nil, err)) + return + } + } +} + +// UnmarshalRequestBody decodes a body into a struct. +// This function expects the request body to be a JSON object and target to be a pointer to expected struct. +// If the request body is invalid, it returns an error. +func UnmarshalRequestBody(r *http.Request, target any) error { + if err := json.NewDecoder(r.Body).Decode(target); err != nil { + return err + } + return nil +} + +// FixedResponseCodeMarshal returns a ResponseMarshalerFunc that always writes provided http status code on success. +func FixedResponseCodeMarshal(statusCode int) ResponseMarshalerFunc { + return func(w http.ResponseWriter, _ *http.Request, obj any) error { + return httpx.WriteResponse(w, obj, statusCode) + } +} + +// DefaultResponseMarshal is a ResponseMarshalerFunc that writes 200 OK http status code with JSON marshaled object. +// 204 No Content http status code is returned if no response object is provided (i.e. when using WrapHandlerInput or WrapHandlerError) +func DefaultResponseMarshal(w http.ResponseWriter, _ *http.Request, src any) error { + if src == http.NoBody { + return httpx.WriteResponse(w, src, http.StatusNoContent) + } + return httpx.WriteResponse(w, src, http.StatusOK) +} + +// AlwaysInternalErrorHandle is a function usable as ErrorHandlerFunc. +// It writes 500 http status code on error. +// Error message not returned in response and is lost. +func AlwaysInternalErrorHandle(w http.ResponseWriter, _ *http.Request, _ error) { + _ = httpx.WriteErrorResponse(w, http.StatusInternalServerError) +} + +// InputGetErrorHandle is a function usable as ErrorHandlerFunc. +// It writes a 400 Bad Request http status code to http.ResponseWriter if the error is from parsing input. +// Otherwise, writes 500 Internal Server Error http status code on error. +// In either case, error message is not returned in response and is lost +func InputGetErrorHandle(w http.ResponseWriter, r *http.Request, err error) { + if errors.Is(err, ErrInputGet) { + _ = httpx.WriteErrorResponse(w, http.StatusBadRequest) + return + } + AlwaysInternalErrorHandle(w, r, err) +} diff --git a/http/signature/signature_test.go b/http/signature/signature_test.go new file mode 100644 index 0000000..8ff42bd --- /dev/null +++ b/http/signature/signature_test.go @@ -0,0 +1,339 @@ +package signature_test + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + httpparam "go.strv.io/net/http/param" + "go.strv.io/net/http/signature" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" +) + +type User struct { + UserName string `json:"user_name"` + Group int `json:"group"` +} + +type ListUsersInput struct { + Group int `param:"path=group"` + Page int `param:"query=page"` + PerPage int `param:"query=per_page"` +} + +type CreateUserInput struct { + UserName string `json:"user_name"` + Group int `json:"group"` +} + +func hasStructJSONTag(obj any) bool { + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return false + } + t := v.Type() + for i := 0; i < t.NumField(); i++ { + _, exists := t.Field(i).Tag.Lookup("json") + if exists { + return true + } + } + return false +} + +func parseInputFunc(r *http.Request, dest any) error { + // Don't call json unmarshal if dest has no json tag, which means request body may be empty, + // as only expected input are path and query parameters. + // causes error if json.Unmarshal is called on empty body (EOF) + if hasStructJSONTag(dest) { + if err := signature.UnmarshalRequestBody(r, dest); err != nil { + return err + } + } + + // After UnmarshalRequestBody, as it possibly fills all fields, even those tagged as query or path parameter. + // This way, such filled fields will be reassigned in httpparam.Parser. + return httpparam.DefaultParser().WithPathParamFunc(chi.URLParam).Parse(r, dest) +} + +func TestWrapper(t *testing.T) { + testCases := []struct { + method string + url string + inputBody string + expectedBody string + expectedStatus int + }{ + { + method: http.MethodGet, + url: "https://test.com/healthcheck", + inputBody: "", + expectedBody: "", + expectedStatus: http.StatusNoContent, + }, + { + method: http.MethodGet, + url: "https://test.com/dependency-check", + inputBody: "", + expectedBody: `{"payment-provider":"ready","company-registry":"unreachable"}`, + expectedStatus: http.StatusOK, + }, + { + method: http.MethodGet, + url: "https://test.com/group/55/users", + inputBody: "", + expectedBody: `[{"user_name":"Testowic","group":55}]`, + expectedStatus: http.StatusOK, + }, + { + method: http.MethodGet, + url: "https://test.com/users", + inputBody: "", + expectedBody: `[{"user_name":"Testowic","group":0}]`, + expectedStatus: http.StatusOK, + }, + { + method: http.MethodPost, + url: "https://test.com/users", + inputBody: `{"user_name":"NewUser","group":5}`, + expectedBody: "", + expectedStatus: http.StatusCreated, + }, + } + for _, tc := range testCases { + t.Run(tc.method+" "+tc.url, func(t *testing.T) { + w := signature.DefaultWrapper(). + WithInputGetter(parseInputFunc) + + r := chi.NewRouter() + + r.Get("/healthcheck", signature.WrapHandlerError(w, healthcheckHandler)) + r.Get("/dependency-check", signature.WrapHandlerResponse(w, dependencyCheckHandler)) + r.Get("/group/{group}/users", signature.WrapHandler(w, listUsersHandler)) + r.Route("/users", func(r chi.Router) { + r.Get("/", signature.WrapHandler(w, listUsersHandler)) + r.Post("/", signature.WrapHandlerInput( + w.WithResponseMarshaler(signature.FixedResponseCodeMarshal(http.StatusCreated)), + createUserHandler, + )) + }) + + var body io.Reader + if tc.inputBody != "" { + body = strings.NewReader(tc.inputBody) + } + req := httptest.NewRequest(tc.method, tc.url, body) + rec := httptest.NewRecorder() + + r.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + if tc.expectedBody != "" { + assert.JSONEq(t, tc.expectedBody, rec.Body.String()) + } else { + assert.Nil(t, rec.Body.Bytes()) + } + }) + } +} + +func listUsersHandler(_ http.ResponseWriter, _ *http.Request, input ListUsersInput) ([]User, error) { + return []User{{ + UserName: "Testowic", + Group: input.Group, + }}, nil +} + +func createUserHandler(_ http.ResponseWriter, _ *http.Request, _ CreateUserInput) error { + return nil +} + +func healthcheckHandler(_ http.ResponseWriter, _ *http.Request) error { + return nil +} + +func dependencyCheckHandler(_ http.ResponseWriter, _ *http.Request) (map[string]DependencyStatus, error) { + return map[string]DependencyStatus{ + "payment-provider": DependencyStatusReady, + "company-registry": DependencyStatusUnreachable, + }, nil +} + +type DependencyStatus int + +const ( + DependencyStatusReady DependencyStatus = iota + 1 + DependencyStatusUnreachable +) + +func (d DependencyStatus) MarshalText() ([]byte, error) { + var name string + switch d { + case DependencyStatusReady: + name = "ready" + case DependencyStatusUnreachable: + name = "unreachable" + default: + return nil, fmt.Errorf("invalid DependencyStatus value (%d) when marshaling", d) + } + return []byte(name), nil +} + +func TestWrapper_Error(t *testing.T) { + w := signature.DefaultWrapper() + + var interceptedError error + w = w.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { + interceptedError = err + signature.InputGetErrorHandle(w, r, err) + }) + + testCases := []struct { + name string + inputBody string + expectedBody string + expectedStatus int + handler http.Handler + targetErr error + isABug bool + }{ + { + name: "parsing body returns 400", + inputBody: `{"incomplete_json":`, + expectedStatus: http.StatusBadRequest, + handler: signature.WrapHandler(w, buggyHandler), + targetErr: signature.ErrInputGet, + isABug: false, + }, + { + name: "internal handler error returns 500", + inputBody: `{"bug":true}`, + expectedStatus: http.StatusInternalServerError, + handler: signature.WrapHandler(w, buggyHandler), + targetErr: signature.ErrInnerHandler, + isABug: true, + }, + { + name: "marshaling error returns 500? Well actually 200", + inputBody: `{"bug":false}`, + // Header was already written at the time of marshal error, there is no way to change it. + // The only way would be to unmarshal the object into buffer to see if it returns error. + // + // This is behaviour of httptest.ResponseRecorder.WriteHeader(), but the http.response.WriteHeader() + // behaves the same way. I guess it's not good to have TextMarshalers that can error on valid ResponseWriter + expectedStatus: http.StatusOK, + handler: signature.WrapHandler(w, buggyHandler), + targetErr: signature.ErrResponseMarshal, + isABug: true, + }, + { + name: "parsing body returns 400 (only input)", + inputBody: `{"incomplete_json":`, + expectedStatus: http.StatusBadRequest, + handler: signature.WrapHandlerInput(w, buggyHandlerInput), + targetErr: signature.ErrInputGet, + isABug: false, + }, + { + name: "internal handler error returns 500 (only input)", + inputBody: `{"bug":true}`, + expectedStatus: http.StatusInternalServerError, + handler: signature.WrapHandlerInput(w, buggyHandlerInput), + targetErr: signature.ErrInnerHandler, + isABug: true, + }, + { + name: "internal handler error returns 500 (only response)", + inputBody: "", + expectedStatus: http.StatusInternalServerError, + handler: signature.WrapHandlerResponse(w, buggyHandlerResponse), + targetErr: signature.ErrInnerHandler, + isABug: true, + }, + { + name: "marshaling error returns 200 (only response)", + inputBody: "", + expectedStatus: http.StatusOK, // same as above problem + handler: signature.WrapHandlerResponse(w, buggyHandlerBuggyResponse), + targetErr: signature.ErrResponseMarshal, + isABug: true, + }, + { + name: "internal handler error returns 500 (only error)", + inputBody: "", + expectedStatus: http.StatusInternalServerError, + handler: signature.WrapHandlerError(w, buggyHandlerError), + targetErr: signature.ErrInnerHandler, + isABug: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var body io.Reader + if tc.inputBody != "" { + body = strings.NewReader(tc.inputBody) + } + req := httptest.NewRequest(http.MethodGet, "https://test.com/error", body) + rec := httptest.NewRecorder() + + tc.handler.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + assert.ErrorIs(t, interceptedError, tc.targetErr) + if tc.isABug { + assert.ErrorIs(t, interceptedError, errBug) + } + assert.JSONEq(t, `{"errorCode":"ERR_UNKNOWN"}`, rec.Body.String()) + }) + } +} + +var errBug = errors.New("that's a bug, and should propagate properly") + +type willBug bool + +func (w willBug) MarshalText() (text []byte, err error) { + return nil, errBug +} + +type buggyInputOrOutput struct { + WillBug *willBug `json:"bug"` +} + +func buggyHandler(_ http.ResponseWriter, _ *http.Request, input buggyInputOrOutput) (*buggyInputOrOutput, error) { + if input.WillBug != nil && *input.WillBug { + return nil, errBug + } + x := willBug(true) + return &buggyInputOrOutput{WillBug: &x}, nil +} + +func buggyHandlerResponse(_ http.ResponseWriter, _ *http.Request) (*buggyInputOrOutput, error) { + return nil, errBug +} + +func buggyHandlerBuggyResponse(_ http.ResponseWriter, _ *http.Request) (*buggyInputOrOutput, error) { + x := willBug(true) + return &buggyInputOrOutput{WillBug: &x}, nil +} + +func buggyHandlerInput(_ http.ResponseWriter, _ *http.Request, input buggyInputOrOutput) error { + if input.WillBug != nil && *input.WillBug { + return errBug + } + return nil +} + +func buggyHandlerError(_ http.ResponseWriter, _ *http.Request) error { + return errBug +}