diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index 87c576a9fe47..4648a46dd1ce 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -5,7 +5,7 @@ import ( "net/http" "reflect" "strings" - "sync/atomic" + "sync" "github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/core/validation" @@ -33,7 +33,11 @@ var ( pathKey, mapping.WithStringValues(), mapping.WithOpaqueKeys()) - validator atomic.Value + + // panic: sync/atomic: store of inconsistently typed value into Value + // don't use atomic.Value to store the validator, different concrete types still panic + validator Validator + validatorLock sync.RWMutex ) // Validator defines the interface for validating the request. @@ -65,8 +69,8 @@ func Parse(r *http.Request, v any) error { if valid, ok := v.(validation.Validator); ok { return valid.Validate() - } else if val := validator.Load(); val != nil { - return val.(Validator).Validate(r, v) + } else if val := getValidator(); val != nil { + return val.Validate(r, v) } return nil @@ -135,7 +139,15 @@ func ParsePath(r *http.Request, v any) error { // The validator is used to validate the request, only called in Parse, // not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath. func SetValidator(val Validator) { - validator.Store(val) + validatorLock.Lock() + defer validatorLock.Unlock() + validator = val +} + +func getValidator() Validator { + validatorLock.RLock() + defer validatorLock.RUnlock() + return validator } func withJsonBody(r *http.Request) bool { diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index a544c2d94024..c81a5707f575 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -734,6 +734,22 @@ func TestParseJsonStringRequest(t *testing.T) { }) } +type valid1 struct{} + +func (v valid1) Validate(*http.Request, any) error { return nil } + +type valid2 struct{} + +func (v valid2) Validate(*http.Request, any) error { return nil } + +func TestSetValidatorTwice(t *testing.T) { + // panic: sync/atomic: store of inconsistently typed value into Value + assert.NotPanics(t, func() { + SetValidator(valid1{}) + SetValidator(valid2{}) + }) +} + func BenchmarkParseRaw(b *testing.B) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody) if err != nil {