Skip to content

Commit

Permalink
fix(tests): added unit test for Binder
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi committed Dec 25, 2024
1 parent 9c78e4b commit a069a0e
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 15 deletions.
19 changes: 4 additions & 15 deletions binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import (
"net/http"

"github.com/go-playground/form/v4"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
jsoniter "github.com/json-iterator/go"
)
Expand All @@ -21,11 +18,6 @@ var (
formDecoder = form.NewDecoder()
)

var (
uni = ut.New(en.New(), en.New(), zh.New())
validate = validator.New()
)

func BindQuery[T any](req *http.Request) (*TEntity[T], error) {

data := new(T)
Expand Down Expand Up @@ -63,7 +55,7 @@ func BindForm[T any](req *http.Request) (*TEntity[T], error) {

}

func BindJSON[T any](req *http.Request) (*TEntity[T], error) {
func BindJson[T any](req *http.Request) (*TEntity[T], error) {
data := new(T)

buf, err := io.ReadAll(req.Body)
Expand All @@ -90,25 +82,22 @@ type TEntity[T any] struct {
}

func (t *TEntity[T]) Validate(languages ...string) bool {
validate := findValidator(languages...)

err := validate.Struct(t.Data)
if err == nil {
return true
}

errs := err.(validator.ValidationErrors)

ut, ok := uni.FindTranslator(languages...)
if !ok {
ut = uni.GetFallback()
}

for _, err := range errs {
n := err.Field()
if n == "" {
n = err.StructField()
}

t.Errors[n] = err.Translate(ut)
t.Errors[n] = err.Translate(validate.Translator)
}

return false
Expand Down
156 changes: 156 additions & 0 deletions binder_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package htmx

import (
"bytes"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
trans "github.com/go-playground/validator/v10/translations/zh"
"github.com/stretchr/testify/require"
)

func TestBinder(t *testing.T) {
Expand All @@ -13,7 +21,155 @@ func TestBinder(t *testing.T) {

app := New(WithMux(mux))

type Login struct {
Email string `form:"email" json:"email" validate:"required,email"`
Passwd string `json:"passwd" validate:"required"`
}

AddValidator(ut.New(zh.New()).GetFallback(), trans.RegisterDefaultTranslations)

app.Get("/login", func(c *Context) error {
it, err := BindQuery[Login](c.Request())
if err != nil {
c.WriteStatus(http.StatusBadRequest)
return ErrCancelled
}

if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "[email protected]" && it.Data.Passwd == "123" {
return c.View(it)
}
c.WriteStatus(http.StatusBadRequest)
return c.View(it)
})

app.Post("/login", func(c *Context) error {
it, err := BindForm[Login](c.Request())
if err != nil {
c.WriteStatus(http.StatusBadRequest)
return ErrCancelled
}

if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "[email protected]" && it.Data.Passwd == "123" {
return c.View(it)
}
c.WriteStatus(http.StatusBadRequest)
return c.View(it)
})

app.Put("/login", func(c *Context) error {
it, err := BindJson[Login](c.Request())
if err != nil {
c.WriteStatus(http.StatusBadRequest)
return ErrCancelled
}

if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "[email protected]" && it.Data.Passwd == "123" {
return c.View(it)
}
c.WriteStatus(http.StatusBadRequest)
return c.View(it)
})

app.Start()
defer app.Close()

var tests = []struct {
Name string
NewRequest func(it Login) *http.Request
}{
{
"BindQuery",
func(it Login) *http.Request {
req, _ := http.NewRequest("GET", srv.URL+"/login?email="+url.QueryEscape(it.Email)+"&Passwd="+url.QueryEscape(it.Passwd), nil)
return req
},
},

{
"BindForm",
func(it Login) *http.Request {
form := url.Values{}
form.Add("email", it.Email)
form.Add("Passwd", it.Passwd)

req, _ := http.NewRequest("POST", srv.URL+"/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
},

{
"BindJson",
func(it Login) *http.Request {
buf, _ := json.Marshal(it)

req, _ := http.NewRequest("PUT", srv.URL+"/login", bytes.NewReader(buf))
// req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var result TEntity[Login]

req := test.NewRequest(Login{Email: "[email protected]", Passwd: "123"})
resp, err := client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, resp.StatusCode)

err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, "[email protected]", result.Data.Email)
require.Equal(t, "123", result.Data.Passwd)
require.Len(t, result.Errors, 0)

req = test.NewRequest(Login{Email: "[email protected]", Passwd: "abc"})
resp, err = client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusBadRequest, resp.StatusCode)

err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, "[email protected]", result.Data.Email)
require.Equal(t, "abc", result.Data.Passwd)
require.Len(t, result.Errors, 0)

req = test.NewRequest(Login{Email: "htmx"})
req.Header.Set("accept-language", "en-US,en;q=0.9,zh;q=0.8,zh-CN;q=0.7,zh-TW;q=0.6")
// req.Header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7,zh-TW;q=0.6")
resp, err = client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusBadRequest, resp.StatusCode)

err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(t, err)
resp.Body.Close()
require.Len(t, result.Errors, 2)
require.Equal(t, "Email must be a valid email address", result.Errors["Email"])
require.Equal(t, "Passwd is a required field", result.Errors["Passwd"])

req = test.NewRequest(Login{Email: "htmx"})
// req.Header.Set("accept-language", "en-US,en;q=0.9,zh;q=0.8,zh-CN;q=0.7,zh-TW;q=0.6")
req.Header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7,zh-TW;q=0.6")
resp, err = client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusBadRequest, resp.StatusCode)

err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(t, err)
resp.Body.Close()
require.Len(t, result.Errors, 2)
require.Equal(t, "Email必须是一个有效的邮箱", result.Errors["Email"])
require.Equal(t, "Passwd为必填字段", result.Errors["Passwd"])
})
}

}
46 changes: 46 additions & 0 deletions validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package htmx

import (
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
trans "github.com/go-playground/validator/v10/translations/en"
)

type Validator struct {
*validator.Validate
Translator ut.Translator
}

var (
validators = make(map[string]*Validator)
defaultValidator *Validator
)

func init() {
uni := ut.New(en.New())
defaultValidator = AddValidator(uni.GetFallback(), trans.RegisterDefaultTranslations)

trans.RegisterDefaultTranslations(defaultValidator.Validate, defaultValidator.Translator) // nolint: errcheck

validators[defaultValidator.Translator.Locale()] = defaultValidator
}

func AddValidator(trans ut.Translator, register func(v *validator.Validate, trans ut.Translator) (err error)) *Validator {
v := &Validator{
Validate: validator.New(),
Translator: trans,
}
register(v.Validate, v.Translator) //nolint: errcheck
validators[trans.Locale()] = v
return v
}

func findValidator(locales ...string) *Validator {
for _, locale := range locales {
if v, ok := validators[locale]; ok {
return v
}
}
return defaultValidator
}

0 comments on commit a069a0e

Please sign in to comment.