Skip to content

Commit

Permalink
Support object data types for header params
Browse files Browse the repository at this point in the history
Add initial struct test for header names and validation.
  • Loading branch information
tanenbaum committed Jan 21, 2024
1 parent 76695ca commit 37c4347
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 23 deletions.
10 changes: 9 additions & 1 deletion field_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,21 @@ func (ps *tagBaseFieldParser) FieldName() (string, error) {
}
}

func (ps *tagBaseFieldParser) FormName() string {
func (ps *tagBaseFieldParser) firstTagValue(tag string) string {
if ps.field.Tag != nil {
return strings.TrimRight(strings.TrimSpace(strings.Split(ps.tag.Get(formTag), ",")[0]), "[]")
}
return ""
}

func (ps *tagBaseFieldParser) FormName() string {
return ps.firstTagValue(formTag)
}

func (ps *tagBaseFieldParser) HeaderName() string {
return ps.firstTagValue(headerTag)
}

func toSnakeCase(in string) string {
var (
runes = []rune(in)
Expand Down
25 changes: 8 additions & 17 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
param := createParameter(paramType, description, name, objectType, refType, required, enums, operation.parser.collectionFormatInQuery)

switch paramType {
case "path", "header":
switch objectType {
case ARRAY:
if !IsPrimitiveType(refType) {
return fmt.Errorf("%s is not supported array type for %s", refType, paramType)
}
case OBJECT:
return fmt.Errorf("%s is not supported type for %s", refType, paramType)
}
case "query", "formData":
case "path", "header", "query", "formData":
switch objectType {
case ARRAY:
if !IsPrimitiveType(refType) && !(refType == "file" && paramType == "formData") {
Expand Down Expand Up @@ -324,11 +315,9 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
}
}

var formName = name
if item.Schema.Extensions != nil {
if nameVal, ok := item.Schema.Extensions[formTag]; ok {
formName = nameVal.(string)
}
// load overridden type specific name from extensions if exists
if nameVal, ok := item.Schema.Extensions[paramType]; ok {
name = nameVal.(string)
}

switch {
Expand All @@ -346,10 +335,10 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if !IsSimplePrimitiveType(itemSchema.Type[0]) {
continue
}
param = createParameter(paramType, prop.Description, formName, prop.Type[0], itemSchema.Type[0], findInSlice(schema.Required, name), itemSchema.Enum, operation.parser.collectionFormatInQuery)
param = createParameter(paramType, prop.Description, name, prop.Type[0], itemSchema.Type[0], findInSlice(schema.Required, item.Name), itemSchema.Enum, operation.parser.collectionFormatInQuery)

case IsSimplePrimitiveType(prop.Type[0]):
param = createParameter(paramType, prop.Description, formName, PRIMITIVE, prop.Type[0], findInSlice(schema.Required, name), nil, operation.parser.collectionFormatInQuery)
param = createParameter(paramType, prop.Description, name, PRIMITIVE, prop.Type[0], findInSlice(schema.Required, item.Name), nil, operation.parser.collectionFormatInQuery)
default:
operation.parser.debug.Printf("skip field [%s] in %s is not supported type for %s", name, refType, paramType)
continue
Expand Down Expand Up @@ -406,6 +395,8 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
const (
formTag = "form"
jsonTag = "json"
uriTag = "uri"
headerTag = "header"
bindingTag = "binding"
defaultTag = "default"
enumsTag = "enums"
Expand Down
73 changes: 72 additions & 1 deletion operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,11 +1177,17 @@ func TestOperation_ParseParamComment(t *testing.T) {
t.Parallel()
for _, paramType := range []string{"header", "path", "query", "formData"} {
t.Run(paramType, func(t *testing.T) {
// unknown object returns error
assert.Error(t, NewOperation(nil).ParseComment(`@Param some_object `+paramType+` main.Object true "Some Object"`, nil))

// verify objects are supported here
o := NewOperation(nil)
o.parser.addTestType("main.TestObject")
err := o.ParseComment(`@Param some_object `+paramType+` main.TestObject true "Some Object"`, nil)
assert.NoError(t, err)
})
}
})

}

// Test ParseParamComment Query Params
Expand Down Expand Up @@ -2067,6 +2073,71 @@ func TestParseParamCommentByExtensions(t *testing.T) {
assert.Equal(t, expected, string(b))
}

func TestParseParamStructCodeExample(t *testing.T) {
t.Parallel()

fset := token.NewFileSet()
ast, err := goparser.ParseFile(fset, "operation_test.go", `package swag
import structs "github.com/swaggo/swag/testdata/param_structs"
`, goparser.ParseComments)
assert.NoError(t, err)

parser := New()
err = parser.parseFile("github.com/swaggo/swag/testdata/param_structs", "testdata/param_structs/structs.go", nil, ParseModels)
assert.NoError(t, err)
_, err = parser.packages.ParseTypes()
assert.NoError(t, err)

validateParameters := func(operation *Operation, params ...spec.Parameter) {
assert.Equal(t, len(params), len(operation.Parameters))

for _, param := range params {
found := false
for _, p := range operation.Parameters {
if p.Name == param.Name {
assert.Equal(t, param.ParamProps, p.ParamProps)
assert.Equal(t, param.CommonValidations, p.CommonValidations)
found = true
break
}
}
assert.True(t, found, "found parameter %s", param.Name)
}
}

// values used in validation checks
max := float64(10)
min := float64(0)

t.Run("Header struct", func(t *testing.T) {
operation := NewOperation(parser)
comment := `@Param auth header structs.AuthHeader true "auth header"`
err = operation.ParseComment(comment, ast)
assert.NoError(t, err)

validateParameters(operation,
spec.Parameter{
ParamProps: spec.ParamProps{
Name: "X-Auth-Token",
Description: "Token is the auth token",
In: "header",
Required: true,
},
}, spec.Parameter{
ParamProps: spec.ParamProps{
Name: "anotherHeader",
Description: "AnotherHeader is another header",
In: "header",
Required: false,
},
CommonValidations: spec.CommonValidations{
Maximum: &max,
Minimum: &min,
},
})
})
}

func TestParseIdComment(t *testing.T) {
t.Parallel()

Expand Down
12 changes: 8 additions & 4 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ type FieldParser interface {
ShouldSkip() bool
FieldName() (string, error)
FormName() string
HeaderName() string
CustomSchema() (*spec.Schema, error)
ComplementSchema(schema *spec.Schema) error
IsRequired() (bool, error)
Expand Down Expand Up @@ -1506,11 +1507,14 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
tagRequired = append(tagRequired, fieldName)
}

if schema.Extensions == nil {
schema.Extensions = make(spec.Extensions)
}
if formName := ps.FormName(); len(formName) > 0 {
if schema.Extensions == nil {
schema.Extensions = make(spec.Extensions)
}
schema.Extensions[formTag] = formName
schema.Extensions["formData"] = formName
}
if headerName := ps.HeaderName(); len(headerName) > 0 {
schema.Extensions["header"] = headerName
}

return map[string]spec.Schema{fieldName: *schema}, tagRequired, nil
Expand Down
8 changes: 8 additions & 0 deletions testdata/param_structs/structs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package structs

type AuthHeader struct {
// Token is the auth token
Token string `header:"X-Auth-Token" binding:"required"`
// AnotherHeader is another header
AnotherHeader int `validate:"gte=0,lte=10"`
}

0 comments on commit 37c4347

Please sign in to comment.