Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix:fix url params unsafe #201

Merged
merged 4 commits into from
Sep 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 52 additions & 27 deletions common/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type baseUrl struct {
Port string
//url.Values is not safe map, add to avoid concurrent map read and map write error
paramsLock sync.RWMutex
Params url.Values
params url.Values
PrimitiveURL string
ctx context.Context
}
Expand Down Expand Up @@ -108,13 +108,13 @@ func WithMethods(methods []string) option {

func WithParams(params url.Values) option {
return func(url *URL) {
url.Params = params
url.params = params
}
}

func WithParamsValue(key, val string) option {
return func(url *URL) {
url.Params.Set(key, val)
url.SetParam(key, val)
}
}

Expand Down Expand Up @@ -189,7 +189,7 @@ func NewURL(ctx context.Context, urlString string, opts ...option) (URL, error)
return s, perrors.Errorf("url.Parse(url string{%s}), error{%v}", rawUrlString, err)
}

s.Params, err = url.ParseQuery(serviceUrl.RawQuery)
s.params, err = url.ParseQuery(serviceUrl.RawQuery)
if err != nil {
return s, perrors.Errorf("url.ParseQuery(raw url string{%s}), error{%v}", serviceUrl.RawQuery, err)
}
Expand Down Expand Up @@ -237,7 +237,9 @@ func (c URL) String() string {
buildString := fmt.Sprintf(
"%s://%s:%s@%s:%s%s?",
c.Protocol, c.Username, c.Password, c.Ip, c.Port, c.Path)
buildString += c.Params.Encode()
c.paramsLock.RLock()
buildString += c.params.Encode()
c.paramsLock.RUnlock()
return buildString
}

Expand Down Expand Up @@ -291,20 +293,38 @@ func (c URL) Service() string {

func (c *URL) AddParam(key string, value string) {
c.paramsLock.Lock()
c.Params.Add(key, value)
c.params.Add(key, value)
c.paramsLock.Unlock()
}

func (c *URL) SetParam(key string, value string) {
c.paramsLock.Lock()
c.params.Set(key, value)
c.paramsLock.Unlock()
}

func (c *URL) RangeParams(f func(key, value string) bool) {
c.paramsLock.RLock()
hxmhlt marked this conversation as resolved.
Show resolved Hide resolved
defer c.paramsLock.RUnlock()
for k, v := range c.params {
if !f(k, v[0]) {
break
}
}
}

func (c URL) GetParam(s string, d string) string {
var r string
c.paramsLock.RLock()
if r = c.Params.Get(s); len(r) == 0 {
if r = c.params.Get(s); len(r) == 0 {
r = d
}
c.paramsLock.RUnlock()
return r
}
func (c URL) GetParamAndDecoded(key string) (string, error) {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
ruleDec, err := base64.URLEncoding.DecodeString(c.GetParam(key, ""))
value := string(ruleDec)
return value, err
Expand All @@ -325,7 +345,7 @@ func (c URL) GetRawParam(key string) string {
case "path":
return c.Path
default:
return c.Params.Get(key)
return c.GetParam(key, "")
}
}

Expand All @@ -334,7 +354,7 @@ func (c URL) GetParamBool(s string, d bool) bool {

var r bool
var err error
if r, err = strconv.ParseBool(c.Params.Get(s)); err != nil {
if r, err = strconv.ParseBool(c.GetParam(s, "")); err != nil {
return d
}
return r
Expand All @@ -343,7 +363,8 @@ func (c URL) GetParamBool(s string, d bool) bool {
func (c URL) GetParamInt(s string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get(s)); r == 0 || err != nil {

if r, err = strconv.Atoi(c.GetParam(s, "")); r == 0 || err != nil {
return d
}
return int64(r)
Expand All @@ -352,7 +373,9 @@ func (c URL) GetParamInt(s string, d int64) int64 {
func (c URL) GetMethodParamInt(method string, key string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get("methods." + method + "." + key)); r == 0 || err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.Atoi(c.GetParam("methods."+method+"."+key, "")); r == 0 || err != nil {
return d
}
return int64(r)
Expand All @@ -369,7 +392,7 @@ func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 {

func (c URL) GetMethodParam(method string, key string, d string) string {
var r string
if r = c.Params.Get("methods." + method + "." + key); r == "" {
if r = c.GetParam("methods."+method+"."+key, ""); r == "" {
r = d
}
return r
Expand All @@ -380,9 +403,11 @@ func (c URL) ToMap() map[string]string {

paramsMap := make(map[string]string)

for k, v := range c.Params {
paramsMap[k] = v[0]
}
c.RangeParams(func(key, value string) bool {
paramsMap[key] = value
return true
})

if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
Expand Down Expand Up @@ -421,19 +446,19 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
mergedUrl := serviceUrl

//iterator the referenceUrl if serviceUrl not have the key ,merge in

for k, v := range referenceUrl.Params {
if _, ok := mergedUrl.Params[k]; !ok {
mergedUrl.Params.Set(k, v[0])
referenceUrl.RangeParams(func(key, value string) bool {
if v := mergedUrl.GetParam(key, ""); len(v) == 0 {
mergedUrl.SetParam(key, value)
}
}
return true
})
//loadBalance,cluster,retries strategy config
methodConfigMergeFcn := mergeNormalParam(mergedUrl, referenceUrl, []string{constant.LOADBALANCE_KEY, constant.CLUSTER_KEY, constant.RETRIES_KEY})

//remote timestamp
if v := serviceUrl.Params.Get(constant.TIMESTAMP_KEY); len(v) > 0 {
mergedUrl.Params.Set(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.Params.Set(constant.TIMESTAMP_KEY, referenceUrl.Params.Get(constant.TIMESTAMP_KEY))
if v := serviceUrl.GetParam(constant.TIMESTAMP_KEY, ""); len(v) > 0 {
mergedUrl.SetParam(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.SetParam(constant.TIMESTAMP_KEY, referenceUrl.GetParam(constant.TIMESTAMP_KEY, ""))
}

//finally execute methodConfigMergeFcn
Expand All @@ -449,12 +474,12 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
func mergeNormalParam(mergedUrl URL, referenceUrl *URL, paramKeys []string) []func(method string) {
var methodConfigMergeFcn = []func(method string){}
for _, paramKey := range paramKeys {
if v := referenceUrl.Params.Get(paramKey); len(v) > 0 {
mergedUrl.Params.Set(paramKey, v)
if v := referenceUrl.GetParam(paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(paramKey, v)
}
methodConfigMergeFcn = append(methodConfigMergeFcn, func(method string) {
if v := referenceUrl.Params.Get(method + "." + paramKey); len(v) > 0 {
mergedUrl.Params.Set(method+"."+paramKey, v)
if v := referenceUrl.GetParam(method+"."+paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(method+"."+paramKey, v)
}
})
}
Expand Down
18 changes: 9 additions & 9 deletions common/url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestNewURLWithOptions(t *testing.T) {
assert.Equal(t, "127.0.0.1", u.Ip)
assert.Equal(t, "8080", u.Port)
assert.Equal(t, methods, u.Methods)
assert.Equal(t, params, u.Params)
assert.Equal(t, params, u.params)
}

func TestURL(t *testing.T) {
Expand All @@ -74,7 +74,7 @@ func TestURL(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())

assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
Expand All @@ -101,7 +101,7 @@ func TestURLWithoutSchema(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())

assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
Expand All @@ -124,7 +124,7 @@ func TestURL_URLEqual(t *testing.T) {
func TestURL_GetParam(t *testing.T) {
params := url.Values{}
params.Set("key", "value")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParam("key", "default")
assert.Equal(t, "value", v)

Expand All @@ -136,7 +136,7 @@ func TestURL_GetParam(t *testing.T) {
func TestURL_GetParamInt(t *testing.T) {
params := url.Values{}
params.Set("key", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamInt("key", 1)
assert.Equal(t, int64(3), v)

Expand All @@ -148,7 +148,7 @@ func TestURL_GetParamInt(t *testing.T) {
func TestURL_GetParamBool(t *testing.T) {
params := url.Values{}
params.Set("force", "true")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamBool("force", false)
assert.Equal(t, true, v)

Expand All @@ -161,7 +161,7 @@ func TestURL_GetParamAndDecoded(t *testing.T) {
rule := "host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"
params := url.Values{}
params.Set("rule", base64.URLEncoding.EncodeToString([]byte(rule)))
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v, _ := u.GetParamAndDecoded("rule")
assert.Equal(t, rule, v)
}
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestURL_ToMap(t *testing.T) {
func TestURL_GetMethodParamInt(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParamInt("GetValue", "timeout", 1)
assert.Equal(t, int64(3), v)

Expand All @@ -208,7 +208,7 @@ func TestURL_GetMethodParamInt(t *testing.T) {
func TestURL_GetMethodParam(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3s")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParam("GetValue", "timeout", "1s")
assert.Equal(t, "3s", v)

Expand Down
2 changes: 1 addition & 1 deletion protocol/protocolwrapper/protocol_filter_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (pfw *ProtocolFilterWrapper) Destroy() {
}

func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
filtName := invoker.GetUrl().Params.Get(key)
filtName := invoker.GetUrl().GetParam(key, "")
if filtName == "" {
return invoker
}
Expand Down
9 changes: 6 additions & 3 deletions registry/consul/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ func buildService(url common.URL) (*consul.AgentServiceRegistration, error) {

// tags
tags := make([]string, 0, 8)
for k := range url.Params {
tags = append(tags, k+"="+url.Params.Get(k))
}

url.RangeParams(func(key, value string) bool {
tags = append(tags, key+"="+value)
return true
})

tags = append(tags, "dubbo")

// meta
Expand Down
2 changes: 1 addition & 1 deletion registry/directory/directory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestSubscribe_Group(t *testing.T) {

regurl, _ := common.NewURL(context.TODO(), "mock://127.0.0.1:1111")
suburl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000")
suburl.Params.Set(constant.CLUSTER_KEY, "mock")
suburl.SetParam(constant.CLUSTER_KEY, "mock")
regurl.SubURL = &suburl
mockRegistry, _ := registry.NewMockRegistry(&common.URL{})
registryDirectory, _ := NewRegistryDirectory(&regurl, mockRegistry)
Expand Down
7 changes: 4 additions & 3 deletions registry/etcdv3/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,11 @@ func (r *etcdV3Registry) registerProvider(svc common.URL) error {
}

params := url.Values{}
for k, v := range svc.Params {
params[k] = v
}

svc.RangeParams(func(key, value string) bool {
params[key] = []string{value}
return true
})
params.Add("pid", processID)
params.Add("ip", localIP)
params.Add("anyhost", "true")
Expand Down
2 changes: 1 addition & 1 deletion registry/etcdv3/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestSubscribe() {
}

//consumer register
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2 := initRegistry(t)

reg2.Register(url)
Expand Down
11 changes: 7 additions & 4 deletions registry/nacos/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,13 @@ func appendParam(target *bytes.Buffer, url common.URL, key string) {

func createRegisterParam(url common.URL, serviceName string) vo.RegisterInstanceParam {
category := getCategory(url)
params := make(map[string]string, len(url.Params)+3)
for k := range url.Params {
params[k] = url.Params.Get(k)
}
params := make(map[string]string)

url.RangeParams(func(key, value string) bool {
params[key] = value
return true
})

params[constant.NACOS_CATEGORY_KEY] = category
params[constant.NACOS_PROTOCOL_KEY] = url.Protocol
params[constant.NACOS_PATH_KEY] = url.Path
Expand Down
4 changes: 2 additions & 2 deletions registry/nacos/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestNacosRegistry_Subscribe(t *testing.T) {
return
}

regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url)
assert.Nil(t, err)
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestNacosRegistry_Subscribe_del(t *testing.T) {
return
}

regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url1)
assert.Nil(t, err)
Expand Down
9 changes: 5 additions & 4 deletions registry/zookeeper/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ type zkRegistry struct {
configListener *RegistryConfigurationListener
//for provider
zkPath map[string]int // key = protocol://ip:port/interface

}

func newZkRegistry(url *common.URL) (registry.Registry, error) {
Expand Down Expand Up @@ -271,9 +270,11 @@ func (r *zkRegistry) register(c common.URL) error {
return perrors.WithStack(err)
}
params = url.Values{}
for k, v := range c.Params {
params[k] = v
}

c.RangeParams(func(key, value string) bool {
params.Add(key, value)
return true
})

params.Add("pid", processID)
params.Add("ip", localIP)
Expand Down
Loading