Skip to content

Commit

Permalink
Fix:fix url params unsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
hxmhlt committed Sep 11, 2019
1 parent cf198d0 commit 1213578
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 56 deletions.
84 changes: 57 additions & 27 deletions common/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type baseUrl struct {
Port string
//url.Values is not safe map, add to avoid concurrent map read and map write error
paramsLock sync.RWMutex
Params url.Values
params url.Values
PrimitiveURL string
ctx context.Context
}
Expand Down Expand Up @@ -108,13 +108,13 @@ func WithMethods(methods []string) option {

func WithParams(params url.Values) option {
return func(url *URL) {
url.Params = params
url.params = params
}
}

func WithParamsValue(key, val string) option {
return func(url *URL) {
url.Params.Set(key, val)
url.SetParam(key, val)
}
}

Expand Down Expand Up @@ -189,7 +189,7 @@ func NewURL(ctx context.Context, urlString string, opts ...option) (URL, error)
return s, perrors.Errorf("url.Parse(url string{%s}), error{%v}", rawUrlString, err)
}

s.Params, err = url.ParseQuery(serviceUrl.RawQuery)
s.params, err = url.ParseQuery(serviceUrl.RawQuery)
if err != nil {
return s, perrors.Errorf("url.ParseQuery(raw url string{%s}), error{%v}", serviceUrl.RawQuery, err)
}
Expand Down Expand Up @@ -237,7 +237,9 @@ func (c URL) String() string {
buildString := fmt.Sprintf(
"%s://%s:%s@%s:%s%s?",
c.Protocol, c.Username, c.Password, c.Ip, c.Port, c.Path)
buildString += c.Params.Encode()
c.paramsLock.RLock()
buildString += c.params.Encode()
c.paramsLock.RUnlock()
return buildString
}

Expand Down Expand Up @@ -291,26 +293,46 @@ func (c URL) Service() string {

func (c *URL) AddParam(key string, value string) {
c.paramsLock.Lock()
c.Params.Add(key, value)
c.params.Add(key, value)
c.paramsLock.Unlock()
}

func (c *URL) SetParam(key string, value string) {
c.paramsLock.Lock()
c.params.Set(key, value)
c.paramsLock.Unlock()
}

func (c *URL) RangeParams(f func(key, value string) bool) {
c.paramsLock.RLock()
for k, v := range c.params {
if !f(k, v[0]) {
break
}
}
c.paramsLock.RUnlock()
}

func (c URL) GetParam(s string, d string) string {
var r string
c.paramsLock.RLock()
if r = c.Params.Get(s); len(r) == 0 {
if r = c.params.Get(s); len(r) == 0 {
r = d
}
c.paramsLock.RUnlock()
return r
}
func (c URL) GetParamAndDecoded(key string) (string, error) {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
ruleDec, err := base64.URLEncoding.DecodeString(c.GetParam(key, ""))
value := string(ruleDec)
return value, err
}

func (c URL) GetRawParam(key string) string {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
switch key {
case "protocol":
return c.Protocol
Expand All @@ -325,7 +347,7 @@ func (c URL) GetRawParam(key string) string {
case "path":
return c.Path
default:
return c.Params.Get(key)
return c.params.Get(key)
}
}

Expand All @@ -334,7 +356,9 @@ func (c URL) GetParamBool(s string, d bool) bool {

var r bool
var err error
if r, err = strconv.ParseBool(c.Params.Get(s)); err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.ParseBool(c.params.Get(s)); err != nil {
return d
}
return r
Expand All @@ -343,7 +367,9 @@ func (c URL) GetParamBool(s string, d bool) bool {
func (c URL) GetParamInt(s string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get(s)); r == 0 || err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.Atoi(c.params.Get(s)); r == 0 || err != nil {
return d
}
return int64(r)
Expand All @@ -352,7 +378,9 @@ func (c URL) GetParamInt(s string, d int64) int64 {
func (c URL) GetMethodParamInt(method string, key string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get("methods." + method + "." + key)); r == 0 || err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.Atoi(c.params.Get("methods." + method + "." + key)); r == 0 || err != nil {
return d
}
return int64(r)
Expand All @@ -369,7 +397,7 @@ func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 {

func (c URL) GetMethodParam(method string, key string, d string) string {
var r string
if r = c.Params.Get("methods." + method + "." + key); r == "" {
if r = c.params.Get("methods." + method + "." + key); r == "" {
r = d
}
return r
Expand All @@ -380,9 +408,11 @@ func (c URL) ToMap() map[string]string {

paramsMap := make(map[string]string)

for k, v := range c.Params {
paramsMap[k] = v[0]
}
c.RangeParams(func(key, value string) bool {
paramsMap[key] = value
return true
})

if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
Expand Down Expand Up @@ -421,19 +451,19 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
mergedUrl := serviceUrl

//iterator the referenceUrl if serviceUrl not have the key ,merge in

for k, v := range referenceUrl.Params {
if _, ok := mergedUrl.Params[k]; !ok {
mergedUrl.Params.Set(k, v[0])
referenceUrl.RangeParams(func(key, value string) bool {
if v := mergedUrl.GetParam(key, ""); len(v) == 0 {
mergedUrl.SetParam(key, value)
}
}
return true
})
//loadBalance,cluster,retries strategy config
methodConfigMergeFcn := mergeNormalParam(mergedUrl, referenceUrl, []string{constant.LOADBALANCE_KEY, constant.CLUSTER_KEY, constant.RETRIES_KEY})

//remote timestamp
if v := serviceUrl.Params.Get(constant.TIMESTAMP_KEY); len(v) > 0 {
mergedUrl.Params.Set(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.Params.Set(constant.TIMESTAMP_KEY, referenceUrl.Params.Get(constant.TIMESTAMP_KEY))
if v := serviceUrl.GetParam(constant.TIMESTAMP_KEY, ""); len(v) > 0 {
mergedUrl.SetParam(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.SetParam(constant.TIMESTAMP_KEY, referenceUrl.GetParam(constant.TIMESTAMP_KEY, ""))
}

//finally execute methodConfigMergeFcn
Expand All @@ -449,12 +479,12 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
func mergeNormalParam(mergedUrl URL, referenceUrl *URL, paramKeys []string) []func(method string) {
var methodConfigMergeFcn = []func(method string){}
for _, paramKey := range paramKeys {
if v := referenceUrl.Params.Get(paramKey); len(v) > 0 {
mergedUrl.Params.Set(paramKey, v)
if v := referenceUrl.GetParam(paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(paramKey, v)
}
methodConfigMergeFcn = append(methodConfigMergeFcn, func(method string) {
if v := referenceUrl.Params.Get(method + "." + paramKey); len(v) > 0 {
mergedUrl.Params.Set(method+"."+paramKey, v)
if v := referenceUrl.GetParam(method+"."+paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(method+"."+paramKey, v)
}
})
}
Expand Down
18 changes: 9 additions & 9 deletions common/url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestNewURLWithOptions(t *testing.T) {
assert.Equal(t, "127.0.0.1", u.Ip)
assert.Equal(t, "8080", u.Port)
assert.Equal(t, methods, u.Methods)
assert.Equal(t, params, u.Params)
assert.Equal(t, params, u.params)
}

func TestURL(t *testing.T) {
Expand All @@ -74,7 +74,7 @@ func TestURL(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())

assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
Expand All @@ -101,7 +101,7 @@ func TestURLWithoutSchema(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())

assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
Expand All @@ -124,7 +124,7 @@ func TestURL_URLEqual(t *testing.T) {
func TestURL_GetParam(t *testing.T) {
params := url.Values{}
params.Set("key", "value")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParam("key", "default")
assert.Equal(t, "value", v)

Expand All @@ -136,7 +136,7 @@ func TestURL_GetParam(t *testing.T) {
func TestURL_GetParamInt(t *testing.T) {
params := url.Values{}
params.Set("key", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamInt("key", 1)
assert.Equal(t, int64(3), v)

Expand All @@ -148,7 +148,7 @@ func TestURL_GetParamInt(t *testing.T) {
func TestURL_GetParamBool(t *testing.T) {
params := url.Values{}
params.Set("force", "true")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamBool("force", false)
assert.Equal(t, true, v)

Expand All @@ -161,7 +161,7 @@ func TestURL_GetParamAndDecoded(t *testing.T) {
rule := "host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"
params := url.Values{}
params.Set("rule", base64.URLEncoding.EncodeToString([]byte(rule)))
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v, _ := u.GetParamAndDecoded("rule")
assert.Equal(t, rule, v)
}
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestURL_ToMap(t *testing.T) {
func TestURL_GetMethodParamInt(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParamInt("GetValue", "timeout", 1)
assert.Equal(t, int64(3), v)

Expand All @@ -208,7 +208,7 @@ func TestURL_GetMethodParamInt(t *testing.T) {
func TestURL_GetMethodParam(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3s")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParam("GetValue", "timeout", "1s")
assert.Equal(t, "3s", v)

Expand Down
2 changes: 1 addition & 1 deletion protocol/protocolwrapper/protocol_filter_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (pfw *ProtocolFilterWrapper) Destroy() {
}

func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
filtName := invoker.GetUrl().Params.Get(key)
filtName := invoker.GetUrl().GetParam(key, "")
if filtName == "" {
return invoker
}
Expand Down
9 changes: 6 additions & 3 deletions registry/consul/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ func buildService(url common.URL) (*consul.AgentServiceRegistration, error) {

// tags
tags := make([]string, 0, 8)
for k := range url.Params {
tags = append(tags, k+"="+url.Params.Get(k))
}

url.RangeParams(func(key, value string) bool {
tags = append(tags, key+"="+value)
return true
})

tags = append(tags, "dubbo")

// meta
Expand Down
2 changes: 1 addition & 1 deletion registry/directory/directory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestSubscribe_Group(t *testing.T) {

regurl, _ := common.NewURL(context.TODO(), "mock://127.0.0.1:1111")
suburl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000")
suburl.Params.Set(constant.CLUSTER_KEY, "mock")
suburl.SetParam(constant.CLUSTER_KEY, "mock")
regurl.SubURL = &suburl
mockRegistry, _ := registry.NewMockRegistry(&common.URL{})
registryDirectory, _ := NewRegistryDirectory(&regurl, mockRegistry)
Expand Down
7 changes: 4 additions & 3 deletions registry/etcdv3/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,11 @@ func (r *etcdV3Registry) registerProvider(svc common.URL) error {
}

params := url.Values{}
for k, v := range svc.Params {
params[k] = v
}

svc.RangeParams(func(key, value string) bool {
params[key] = []string{value}
return true
})
params.Add("pid", processID)
params.Add("ip", localIP)
params.Add("anyhost", "true")
Expand Down
2 changes: 1 addition & 1 deletion registry/etcdv3/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestSubscribe() {
}

//consumer register
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2 := initRegistry(t)

reg2.Register(url)
Expand Down
11 changes: 7 additions & 4 deletions registry/nacos/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,13 @@ func appendParam(target *bytes.Buffer, url common.URL, key string) {

func createRegisterParam(url common.URL, serviceName string) vo.RegisterInstanceParam {
category := getCategory(url)
params := make(map[string]string, len(url.Params)+3)
for k := range url.Params {
params[k] = url.Params.Get(k)
}
params := make(map[string]string)

url.RangeParams(func(key, value string) bool {
params[key] = value
return true
})

params[constant.NACOS_CATEGORY_KEY] = category
params[constant.NACOS_PROTOCOL_KEY] = url.Protocol
params[constant.NACOS_PATH_KEY] = url.Path
Expand Down
4 changes: 2 additions & 2 deletions registry/nacos/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestNacosRegistry_Subscribe(t *testing.T) {
return
}

regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url)
assert.Nil(t, err)
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestNacosRegistry_Subscribe_del(t *testing.T) {
return
}

regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url1)
assert.Nil(t, err)
Expand Down
Loading

0 comments on commit 1213578

Please sign in to comment.