Skip to content

Commit

Permalink
refactor: validate config (#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF authored Sep 26, 2023
1 parent a924461 commit 8bc1407
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 45 deletions.
6 changes: 5 additions & 1 deletion pkg/app/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,11 +884,15 @@ func (m *mockValidator) Engine() interface{} {
return nil
}

func (m *mockValidator) ValidateTag() string {
return "vt"
}

func TestSetValidator(t *testing.T) {
m := &mockValidator{}
c := NewContext(0)
c.SetValidator(m)
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"}))
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m}))
type User struct {
Age int `vt:"$>=0&&$<=130"`
}
Expand Down
55 changes: 55 additions & 0 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,61 @@ func Test_BindHeaderNormalize(t *testing.T) {
assert.DeepEqual(t, "", result3.Header)
}

type ValidateError struct {
ErrType, FailField, Msg string
}

// Error implements error interface.
func (e *ValidateError) Error() string {
if e.Msg != "" {
return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg
}
return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid"
}

func Test_ValidatorErrorFactory(t *testing.T) {
type TestBind struct {
A string `query:"a,required"`
}

r := protocol.NewRequest("GET", "/foo", nil)
r.SetRequestURI("/foo/bar?b=20")
CustomValidateErrFunc := func(failField, msg string) error {
err := ValidateError{
ErrType: "validateErr",
FailField: "[validateFailField]: " + failField,
Msg: "[validateErrMsg]: " + msg,
}

return &err
}

validateConfig := NewValidateConfig()
validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc)
validator := NewValidator(validateConfig)

var req TestBind
err := Bind(r, &req, nil)
if err == nil {
t.Fatalf("unexpected nil, expected an error")
}

type TestValidate struct {
B int `query:"b" vd:"$>100"`
}

var reqValidate TestValidate
err = Bind(r, &reqValidate, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err = validator.ValidateStruct(&reqValidate)
if err == nil {
t.Fatalf("unexpected nil, expected an error")
}
assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error())
}

func Benchmark_Binding(b *testing.B) {
type Req struct {
Version string `path:"v"`
Expand Down
29 changes: 15 additions & 14 deletions pkg/app/server/binding/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"reflect"
"time"

"github.com/bytedance/go-tagexpr/v2/validator"
exprValidator "github.com/bytedance/go-tagexpr/v2/validator"
inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
hJson "github.com/cloudwego/hertz/pkg/common/json"
"github.com/cloudwego/hertz/pkg/protocol"
Expand Down Expand Up @@ -63,10 +63,6 @@ type BindConfig struct {
// The default is false.
// It is used for BindJSON().
EnableDecoderDisallowUnknownFields bool
// ValidateTag is used to determine if a filed needs to be validated.
// NOTE:
// The default is "vd".
ValidateTag string
// TypeUnmarshalFuncs registers customized type unmarshaler.
// NOTE:
// time.Time is registered by default
Expand All @@ -82,7 +78,6 @@ func NewBindConfig() *BindConfig {
DisableStructFieldResolve: false,
EnableDecoderUseNumber: false,
EnableDecoderDisallowUnknownFields: false,
ValidateTag: "vd",
TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc),
Validator: defaultValidate,
}
Expand Down Expand Up @@ -145,7 +140,12 @@ func (config *BindConfig) UseStdJSONUnmarshaler() {
config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal)
}

type ValidateConfig struct{}
type ValidateErrFactory func(fieldSelector, msg string) error

type ValidateConfig struct {
ValidateTag string
ErrFactory ValidateErrFactory
}

func NewValidateConfig() *ValidateConfig {
return &ValidateConfig{}
Expand All @@ -157,14 +157,15 @@ func NewValidateConfig() *ValidateConfig {
// If force=true, allow to cover the existed same funcName.
// MustRegValidateFunc will remain in effect once it has been called.
func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) {
validator.MustRegFunc(funcName, fn, force...)
exprValidator.MustRegFunc(funcName, fn, force...)
}

// SetValidatorErrorFactory customizes the factory of validation error.
func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) {
if val, ok := DefaultValidator().(*defaultValidator); ok {
val.validate.SetErrorFactory(validatingErrFactory)
} else {
panic("customized validator can not use 'SetValidatorErrorFactory'")
}
func (config *ValidateConfig) SetValidatorErrorFactory(errFactory ValidateErrFactory) {
config.ErrFactory = errFactory
}

// SetValidatorTag customizes the factory of validation error.
func (config *ValidateConfig) SetValidatorTag(tag string) {
config.ValidateTag = tag
}
85 changes: 60 additions & 25 deletions pkg/app/server/binding/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import (
"reflect"
"sync"

"github.com/bytedance/go-tagexpr/v2/validator"
exprValidator "github.com/bytedance/go-tagexpr/v2/validator"
"github.com/cloudwego/hertz/internal/bytesconv"
inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
hJson "github.com/cloudwego/hertz/pkg/common/json"
Expand All @@ -81,10 +81,11 @@ import (
)

const (
queryTag = "query"
headerTag = "header"
formTag = "form"
pathTag = "path"
queryTag = "query"
headerTag = "header"
formTag = "form"
pathTag = "path"
defaultValidateTag = "vd"
)

type decoderInfo struct {
Expand Down Expand Up @@ -185,14 +186,17 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par
decoder := cached.(decoderInfo)
return decoder.decoder(req, params, rv.Elem())
}

validateTag := defaultValidateTag
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -232,13 +236,17 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}
}
return err
}
validateTag := defaultValidateTag
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -371,39 +379,66 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err
return
}

var _ StructValidator = (*defaultValidator)(nil)
var _ StructValidator = (*validator)(nil)

type validator struct {
validateTag string
validate *exprValidator.Validator
}

func NewValidator(config *ValidateConfig) StructValidator {
validateTag := defaultValidateTag
if config != nil && len(config.ValidateTag) != 0 {
validateTag = config.ValidateTag
}
vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory)
if config != nil && config.ErrFactory != nil {
vd.SetErrorFactory(config.ErrFactory)
}
return &validator{
validateTag: validateTag,
validate: vd,
}
}

// Error validate error
type validateError struct {
FailPath, Msg string
}

type defaultValidator struct {
once sync.Once
validate *validator.Validator
// Error implements error interface.
func (e *validateError) Error() string {
if e.Msg != "" {
return e.Msg
}
return "invalid parameter: " + e.FailPath
}

func NewDefaultValidator(config *ValidateConfig) StructValidator {
return &defaultValidator{}
func defaultValidateErrorFactory(failPath, msg string) error {
return &validateError{
FailPath: failPath,
Msg: msg,
}
}

// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
func (v *validator) ValidateStruct(obj interface{}) error {
if obj == nil {
return nil
}
v.lazyinit()
return v.validate.Validate(obj)
}

func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.Default()
})
}

// Engine returns the underlying validator
func (v *defaultValidator) Engine() interface{} {
v.lazyinit()
func (v *validator) Engine() interface{} {
return v.validate
}

var defaultValidate = NewDefaultValidator(nil)
func (v *validator) ValidateTag() string {
return v.validateTag
}

var defaultValidate = NewValidator(NewValidateConfig())

func DefaultValidator() StructValidator {
return defaultValidate
Expand Down
1 change: 1 addition & 0 deletions pkg/app/server/binding/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ package binding
type StructValidator interface {
ValidateStruct(interface{}) error
Engine() interface{}
ValidateTag() string
}
29 changes: 29 additions & 0 deletions pkg/app/server/binding/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,32 @@ func Test_ValidateStruct(t *testing.T) {
t.Fatalf("expected an error, but got nil")
}
}

func Test_ValidateTag(t *testing.T) {
type User struct {
Age int `query:"age" vt:"$>=0&&$<=130"`
}

user := &User{
Age: 135,
}
validateConfig := NewValidateConfig()
validateConfig.ValidateTag = "vt"
vd := NewValidator(validateConfig)
err := vd.ValidateStruct(user)
if err == nil {
t.Fatalf("expected an error, but got nil")
}

bindConfig := NewBindConfig()
bindConfig.Validator = vd
binder := NewDefaultBinder(bindConfig)
user = &User{}
req := newMockRequest().
SetRequestURI("http://foobar.com?age=135").
SetHeaders("h", "header")
err = binder.BindAndValidate(req.Req, user, nil)
if err == nil {
t.Fatalf("expected an error, but got nil")
}
}
Loading

0 comments on commit 8bc1407

Please sign in to comment.