From 8bc14077cf3f382c4d3b6b47543cb72c24610903 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:44:46 +0800 Subject: [PATCH] refactor: validate config (#955) --- pkg/app/context_test.go | 6 +- pkg/app/server/binding/binder_test.go | 55 +++++++++++++++ pkg/app/server/binding/config.go | 29 ++++---- pkg/app/server/binding/default.go | 85 +++++++++++++++++------- pkg/app/server/binding/validator.go | 1 + pkg/app/server/binding/validator_test.go | 29 ++++++++ pkg/app/server/hertz_test.go | 79 +++++++++++++++++++++- pkg/app/server/option.go | 7 ++ pkg/common/config/option.go | 1 + pkg/common/config/option_test.go | 1 + pkg/route/engine.go | 14 +++- pkg/route/engine_test.go | 52 ++++++++++++++- 12 files changed, 314 insertions(+), 45 deletions(-) diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 22e7f8608..c065d482c 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -884,11 +884,15 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vt" +} + func TestSetValidator(t *testing.T) { m := &mockValidator{} c := NewContext(0) c.SetValidator(m) - c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"})) + c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m})) type User struct { Age int `vt:"$>=0&&$<=130"` } diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index d106ed7ad..c971776e3 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1436,6 +1436,61 @@ func Test_BindHeaderNormalize(t *testing.T) { assert.DeepEqual(t, "", result3.Header) } +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func Test_ValidatorErrorFactory(t *testing.T) { + type TestBind struct { + A string `query:"a,required"` + } + + r := protocol.NewRequest("GET", "/foo", nil) + r.SetRequestURI("/foo/bar?b=20") + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + + validateConfig := NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + validator := NewValidator(validateConfig) + + var req TestBind + err := Bind(r, &req, nil) + if err == nil { + t.Fatalf("unexpected nil, expected an error") + } + + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + + var reqValidate TestValidate + err = Bind(r, &reqValidate, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + err = validator.ValidateStruct(&reqValidate) + if err == nil { + t.Fatalf("unexpected nil, expected an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index c122c54c6..81cf30e56 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -22,7 +22,7 @@ import ( "reflect" "time" - "github.com/bytedance/go-tagexpr/v2/validator" + exprValidator "github.com/bytedance/go-tagexpr/v2/validator" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" @@ -63,10 +63,6 @@ type BindConfig struct { // The default is false. // It is used for BindJSON(). EnableDecoderDisallowUnknownFields bool - // ValidateTag is used to determine if a filed needs to be validated. - // NOTE: - // The default is "vd". - ValidateTag string // TypeUnmarshalFuncs registers customized type unmarshaler. // NOTE: // time.Time is registered by default @@ -82,7 +78,6 @@ func NewBindConfig() *BindConfig { DisableStructFieldResolve: false, EnableDecoderUseNumber: false, EnableDecoderDisallowUnknownFields: false, - ValidateTag: "vd", TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc), Validator: defaultValidate, } @@ -145,7 +140,12 @@ func (config *BindConfig) UseStdJSONUnmarshaler() { config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) } -type ValidateConfig struct{} +type ValidateErrFactory func(fieldSelector, msg string) error + +type ValidateConfig struct { + ValidateTag string + ErrFactory ValidateErrFactory +} func NewValidateConfig() *ValidateConfig { return &ValidateConfig{} @@ -157,14 +157,15 @@ func NewValidateConfig() *ValidateConfig { // If force=true, allow to cover the existed same funcName. // MustRegValidateFunc will remain in effect once it has been called. func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { - validator.MustRegFunc(funcName, fn, force...) + exprValidator.MustRegFunc(funcName, fn, force...) } // SetValidatorErrorFactory customizes the factory of validation error. -func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { - if val, ok := DefaultValidator().(*defaultValidator); ok { - val.validate.SetErrorFactory(validatingErrFactory) - } else { - panic("customized validator can not use 'SetValidatorErrorFactory'") - } +func (config *ValidateConfig) SetValidatorErrorFactory(errFactory ValidateErrFactory) { + config.ErrFactory = errFactory +} + +// SetValidatorTag customizes the factory of validation error. +func (config *ValidateConfig) SetValidatorTag(tag string) { + config.ValidateTag = tag } diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 28bbc5311..0634f26cf 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -69,7 +69,7 @@ import ( "reflect" "sync" - "github.com/bytedance/go-tagexpr/v2/validator" + exprValidator "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" @@ -81,10 +81,11 @@ import ( ) const ( - queryTag = "query" - headerTag = "header" - formTag = "form" - pathTag = "path" + queryTag = "query" + headerTag = "header" + formTag = "form" + pathTag = "path" + defaultValidateTag = "vd" ) type decoderInfo struct { @@ -185,14 +186,17 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par decoder := cached.(decoderInfo) return decoder.decoder(req, params, rv.Elem()) } - + validateTag := defaultValidateTag + if len(b.config.Validator.ValidateTag()) != 0 { + validateTag = b.config.Validator.ValidateTag() + } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, DisableDefaultTag: b.config.DisableDefaultTag, DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, - ValidateTag: b.config.ValidateTag, + ValidateTag: validateTag, TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, } decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) @@ -232,13 +236,17 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} } return err } + validateTag := defaultValidateTag + if len(b.config.Validator.ValidateTag()) != 0 { + validateTag = b.config.Validator.ValidateTag() + } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, DisableDefaultTag: b.config.DisableDefaultTag, DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, - ValidateTag: b.config.ValidateTag, + ValidateTag: validateTag, TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, } decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) @@ -371,39 +379,66 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err return } -var _ StructValidator = (*defaultValidator)(nil) +var _ StructValidator = (*validator)(nil) + +type validator struct { + validateTag string + validate *exprValidator.Validator +} + +func NewValidator(config *ValidateConfig) StructValidator { + validateTag := defaultValidateTag + if config != nil && len(config.ValidateTag) != 0 { + validateTag = config.ValidateTag + } + vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory) + if config != nil && config.ErrFactory != nil { + vd.SetErrorFactory(config.ErrFactory) + } + return &validator{ + validateTag: validateTag, + validate: vd, + } +} + +// Error validate error +type validateError struct { + FailPath, Msg string +} -type defaultValidator struct { - once sync.Once - validate *validator.Validator +// Error implements error interface. +func (e *validateError) Error() string { + if e.Msg != "" { + return e.Msg + } + return "invalid parameter: " + e.FailPath } -func NewDefaultValidator(config *ValidateConfig) StructValidator { - return &defaultValidator{} +func defaultValidateErrorFactory(failPath, msg string) error { + return &validateError{ + FailPath: failPath, + Msg: msg, + } } // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. -func (v *defaultValidator) ValidateStruct(obj interface{}) error { +func (v *validator) ValidateStruct(obj interface{}) error { if obj == nil { return nil } - v.lazyinit() return v.validate.Validate(obj) } -func (v *defaultValidator) lazyinit() { - v.once.Do(func() { - v.validate = validator.Default() - }) -} - // Engine returns the underlying validator -func (v *defaultValidator) Engine() interface{} { - v.lazyinit() +func (v *validator) Engine() interface{} { return v.validate } -var defaultValidate = NewDefaultValidator(nil) +func (v *validator) ValidateTag() string { + return v.validateTag +} + +var defaultValidate = NewValidator(NewValidateConfig()) func DefaultValidator() StructValidator { return defaultValidate diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go index 0939b7aef..14d618364 100644 --- a/pkg/app/server/binding/validator.go +++ b/pkg/app/server/binding/validator.go @@ -43,4 +43,5 @@ package binding type StructValidator interface { ValidateStruct(interface{}) error Engine() interface{} + ValidateTag() string } diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 2f85716b5..5564282ef 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -33,3 +33,32 @@ func Test_ValidateStruct(t *testing.T) { t.Fatalf("expected an error, but got nil") } } + +func Test_ValidateTag(t *testing.T) { + type User struct { + Age int `query:"age" vt:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + validateConfig := NewValidateConfig() + validateConfig.ValidateTag = "vt" + vd := NewValidator(validateConfig) + err := vd.ValidateStruct(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + + bindConfig := NewBindConfig() + bindConfig.Validator = vd + binder := NewDefaultBinder(bindConfig) + user = &User{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?age=135"). + SetHeaders("h", "header") + err = binder.BindAndValidate(req.Req, user, nil) + if err == nil { + t.Fatalf("expected an error, but got nil") + } +} diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index c70ff7340..5baa3ecf5 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -929,7 +929,7 @@ func TestCustomBinder(t *testing.T) { time.Sleep(100 * time.Millisecond) } -func TestValidateConfig(t *testing.T) { +func TestValidateConfigRegValidateFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } @@ -966,6 +966,10 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vd" +} + func TestCustomValidator(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` @@ -989,3 +993,76 @@ func TestCustomValidator(t *testing.T) { assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } + +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func TestValidateConfigSetSetErrorFactory(t *testing.T) { + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + validateConfig := binding.NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + h := New( + WithHostPorts("localhost:9666"), + WithValidateConfig(validateConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req TestValidate + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9666/bind?b=1") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +func TestValidateConfigAndBindConfig(t *testing.T) { + type Req struct { + A int `query:"a" vt:"$>=0&&$<=130"` + } + validateConfig := binding.NewValidateConfig() + validateConfig.ValidateTag = "vt" + h := New( + WithHostPorts("localhost:9876"), + WithValidateConfig(validateConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + t.Log(err) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9876/bind?a=135") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index d94a7b0cc..fcf380485 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -355,6 +355,13 @@ func WithBindConfig(bc *binding.BindConfig) config.Option { }} } +// WithValidateConfig sets validate config. +func WithValidateConfig(vc *binding.ValidateConfig) config.Option { + return config.Option{F: func(o *config.Options) { + o.ValidateConfig = vc + }} +} + // WithCustomBinder sets customized Binder. func WithCustomBinder(b binding.Binder) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index d8e6de2d0..048fb366f 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -73,6 +73,7 @@ type Options struct { TraceLevel interface{} ListenConfig *net.ListenConfig BindConfig interface{} + ValidateConfig interface{} CustomBinder interface{} CustomValidator interface{} diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 6ee0fee95..67fcab796 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -54,6 +54,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) assert.Nil(t, options.BindConfig) + assert.Nil(t, options.ValidateConfig) assert.Nil(t, options.CustomBinder) assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) diff --git a/pkg/route/engine.go b/pkg/route/engine.go index ff8cf5a7e..168c79b48 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -558,13 +558,21 @@ func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) func (engine *Engine) initBinderAndValidator(opt *config.Options) { // init validator - engine.validator = binding.DefaultValidator() if opt.CustomValidator != nil { customValidator, ok := opt.CustomValidator.(binding.StructValidator) if !ok { - panic("customized validator can not implement binding.StructValidator") + panic("customized validator does not implement binding.StructValidator") } engine.validator = customValidator + } else { + engine.validator = binding.NewValidator(binding.NewValidateConfig()) + if opt.ValidateConfig != nil { + vConf, ok := opt.ValidateConfig.(*binding.ValidateConfig) + if !ok { + panic("opt.ValidateConfig is not the '*binding.ValidateConfig' type") + } + engine.validator = binding.NewValidator(vConf) + } } if opt.CustomBinder != nil { @@ -582,7 +590,7 @@ func (engine *Engine) initBinderAndValidator(opt *config.Options) { if opt.BindConfig != nil { bConf, ok := opt.BindConfig.(*binding.BindConfig) if !ok { - panic("bind config error") + panic("opt.BindConfig is not the '*binding.BindConfig' type") } if bConf.Validator == nil { bConf.Validator = engine.validator diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 65f4fb16a..37a154bc2 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -675,6 +675,10 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vd" +} + type mockNonValidator struct{} func (m *mockNonValidator) ValidateStruct(interface{}) error { @@ -696,6 +700,10 @@ func TestInitBinderAndValidator(t *testing.T) { validator := &mockValidator{} opt.CustomValidator = validator NewEngine(opt) + validateConfig := binding.NewValidateConfig() + opt.ValidateConfig = validateConfig + opt.CustomValidator = nil + NewEngine(opt) } func TestInitBinderAndValidatorForPanic(t *testing.T) { @@ -748,6 +756,48 @@ func TestBindConfig(t *testing.T) { performRequest(e, "GET", "/bind?a=") } +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func TestValidateConfigSetErrorFactory(t *testing.T) { + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + opt := config.NewOptions([]config.Option{}) + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + + validateConfig := binding.NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + opt.ValidateConfig = validateConfig + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req TestValidate + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) + }) + performRequest(e, "GET", "/bind?b=1") +} + func TestCustomBinder(t *testing.T) { type Req struct { A int `query:"a"` @@ -766,7 +816,7 @@ func TestCustomBinder(t *testing.T) { performRequest(e, "GET", "/bind?a=2") } -func TestValidateConfig(t *testing.T) { +func TestValidateRegValidateFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` }