diff --git a/binder.go b/binder.go index 33ffd57..0b537b6 100644 --- a/binder.go +++ b/binder.go @@ -7,9 +7,6 @@ import ( "net/http" "github.com/go-playground/form/v4" - "github.com/go-playground/locales/en" - "github.com/go-playground/locales/zh" - ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" jsoniter "github.com/json-iterator/go" ) @@ -21,11 +18,6 @@ var ( formDecoder = form.NewDecoder() ) -var ( - uni = ut.New(en.New(), en.New(), zh.New()) - validate = validator.New() -) - func BindQuery[T any](req *http.Request) (*TEntity[T], error) { data := new(T) @@ -63,7 +55,7 @@ func BindForm[T any](req *http.Request) (*TEntity[T], error) { } -func BindJSON[T any](req *http.Request) (*TEntity[T], error) { +func BindJson[T any](req *http.Request) (*TEntity[T], error) { data := new(T) buf, err := io.ReadAll(req.Body) @@ -90,6 +82,8 @@ type TEntity[T any] struct { } func (t *TEntity[T]) Validate(languages ...string) bool { + validate := findValidator(languages...) + err := validate.Struct(t.Data) if err == nil { return true @@ -97,18 +91,13 @@ func (t *TEntity[T]) Validate(languages ...string) bool { errs := err.(validator.ValidationErrors) - ut, ok := uni.FindTranslator(languages...) - if !ok { - ut = uni.GetFallback() - } - for _, err := range errs { n := err.Field() if n == "" { n = err.StructField() } - t.Errors[n] = err.Translate(ut) + t.Errors[n] = err.Translate(validate.Translator) } return false diff --git a/binder_test.go b/binder_test.go index 89060b7..3de44f4 100644 --- a/binder_test.go +++ b/binder_test.go @@ -1,9 +1,17 @@ package htmx import ( + "bytes" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" + + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + trans "github.com/go-playground/validator/v10/translations/zh" + "github.com/stretchr/testify/require" ) func TestBinder(t *testing.T) { @@ -13,7 +21,155 @@ func TestBinder(t *testing.T) { app := New(WithMux(mux)) + type Login struct { + Email string `form:"email" json:"email" validate:"required,email"` + Passwd string `json:"passwd" validate:"required"` + } + + AddValidator(ut.New(zh.New()).GetFallback(), trans.RegisterDefaultTranslations) + + app.Get("/login", func(c *Context) error { + it, err := BindQuery[Login](c.Request()) + if err != nil { + c.WriteStatus(http.StatusBadRequest) + return ErrCancelled + } + + if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "htmx@yaitoo.cn" && it.Data.Passwd == "123" { + return c.View(it) + } + c.WriteStatus(http.StatusBadRequest) + return c.View(it) + }) + + app.Post("/login", func(c *Context) error { + it, err := BindForm[Login](c.Request()) + if err != nil { + c.WriteStatus(http.StatusBadRequest) + return ErrCancelled + } + + if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "htmx@yaitoo.cn" && it.Data.Passwd == "123" { + return c.View(it) + } + c.WriteStatus(http.StatusBadRequest) + return c.View(it) + }) + + app.Put("/login", func(c *Context) error { + it, err := BindJson[Login](c.Request()) + if err != nil { + c.WriteStatus(http.StatusBadRequest) + return ErrCancelled + } + + if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "htmx@yaitoo.cn" && it.Data.Passwd == "123" { + return c.View(it) + } + c.WriteStatus(http.StatusBadRequest) + return c.View(it) + }) + app.Start() defer app.Close() + var tests = []struct { + Name string + NewRequest func(it Login) *http.Request + }{ + { + "BindQuery", + func(it Login) *http.Request { + req, _ := http.NewRequest("GET", srv.URL+"/login?email="+url.QueryEscape(it.Email)+"&Passwd="+url.QueryEscape(it.Passwd), nil) + return req + }, + }, + + { + "BindForm", + func(it Login) *http.Request { + form := url.Values{} + form.Add("email", it.Email) + form.Add("Passwd", it.Passwd) + + req, _ := http.NewRequest("POST", srv.URL+"/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + }, + + { + "BindJson", + func(it Login) *http.Request { + buf, _ := json.Marshal(it) + + req, _ := http.NewRequest("PUT", srv.URL+"/login", bytes.NewReader(buf)) + // req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + var result TEntity[Login] + + req := test.NewRequest(Login{Email: "htmx@yaitoo.cn", Passwd: "123"}) + resp, err := client.Do(req) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, resp.StatusCode) + + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, "htmx@yaitoo.cn", result.Data.Email) + require.Equal(t, "123", result.Data.Passwd) + require.Len(t, result.Errors, 0) + + req = test.NewRequest(Login{Email: "htmx@yaitoo.cn", Passwd: "abc"}) + resp, err = client.Do(req) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, "htmx@yaitoo.cn", result.Data.Email) + require.Equal(t, "abc", result.Data.Passwd) + require.Len(t, result.Errors, 0) + + req = test.NewRequest(Login{Email: "htmx"}) + req.Header.Set("accept-language", "en-US,en;q=0.9,zh;q=0.8,zh-CN;q=0.7,zh-TW;q=0.6") + // req.Header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7,zh-TW;q=0.6") + resp, err = client.Do(req) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(t, err) + resp.Body.Close() + require.Len(t, result.Errors, 2) + require.Equal(t, "Email must be a valid email address", result.Errors["Email"]) + require.Equal(t, "Passwd is a required field", result.Errors["Passwd"]) + + req = test.NewRequest(Login{Email: "htmx"}) + // req.Header.Set("accept-language", "en-US,en;q=0.9,zh;q=0.8,zh-CN;q=0.7,zh-TW;q=0.6") + req.Header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7,zh-TW;q=0.6") + resp, err = client.Do(req) + require.NoError(t, err) + + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(t, err) + resp.Body.Close() + require.Len(t, result.Errors, 2) + require.Equal(t, "Email必须是一个有效的邮箱", result.Errors["Email"]) + require.Equal(t, "Passwd为必填字段", result.Errors["Passwd"]) + }) + } + } diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..5b19d30 --- /dev/null +++ b/validate.go @@ -0,0 +1,46 @@ +package htmx + +import ( + "github.com/go-playground/locales/en" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + trans "github.com/go-playground/validator/v10/translations/en" +) + +type Validator struct { + *validator.Validate + Translator ut.Translator +} + +var ( + validators = make(map[string]*Validator) + defaultValidator *Validator +) + +func init() { + uni := ut.New(en.New()) + defaultValidator = AddValidator(uni.GetFallback(), trans.RegisterDefaultTranslations) + + trans.RegisterDefaultTranslations(defaultValidator.Validate, defaultValidator.Translator) // nolint: errcheck + + validators[defaultValidator.Translator.Locale()] = defaultValidator +} + +func AddValidator(trans ut.Translator, register func(v *validator.Validate, trans ut.Translator) (err error)) *Validator { + v := &Validator{ + Validate: validator.New(), + Translator: trans, + } + register(v.Validate, v.Translator) //nolint: errcheck + validators[trans.Locale()] = v + return v +} + +func findValidator(locales ...string) *Validator { + for _, locale := range locales { + if v, ok := validators[locale]; ok { + return v + } + } + return defaultValidator +}