diff --git a/cluster/cluster_impl/base_cluster_invoker.go b/cluster/cluster_impl/base_cluster_invoker.go index d93e9a6a98..644f67c524 100644 --- a/cluster/cluster_impl/base_cluster_invoker.go +++ b/cluster/cluster_impl/base_cluster_invoker.go @@ -35,6 +35,7 @@ type baseClusterInvoker struct { directory cluster.Directory availablecheck bool destroyed *atomic.Bool + stickyInvoker protocol.Invoker } func newBaseClusterInvoker(directory cluster.Directory) baseClusterInvoker { @@ -56,7 +57,9 @@ func (invoker *baseClusterInvoker) Destroy() { } func (invoker *baseClusterInvoker) IsAvailable() bool { - //TODO:sticky connection + if invoker.stickyInvoker != nil { + return invoker.stickyInvoker.IsAvailable() + } return invoker.directory.IsAvailable() } @@ -83,15 +86,42 @@ func (invoker *baseClusterInvoker) checkWhetherDestroyed() error { } func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker { - //todo:sticky connect + + var selectedInvoker protocol.Invoker + url := invokers[0].GetUrl() + sticky := url.GetParamBool(constant.STICKY_KEY, false) + //Get the service method sticky config if have + sticky = url.GetMethodParamBool(invocation.MethodName(), constant.STICKY_KEY, sticky) + + if invoker.stickyInvoker != nil && !isInvoked(invoker.stickyInvoker, invokers) { + invoker.stickyInvoker = nil + } + + if sticky && invoker.stickyInvoker != nil && (invoked == nil || !isInvoked(invoker.stickyInvoker, invoked)) { + if invoker.availablecheck && invoker.stickyInvoker.IsAvailable() { + return invoker.stickyInvoker + } + } + + selectedInvoker = invoker.doSelectInvoker(lb, invocation, invokers, invoked) + + if sticky { + invoker.stickyInvoker = selectedInvoker + } + return selectedInvoker + +} + +func (invoker *baseClusterInvoker) doSelectInvoker(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker { if len(invokers) == 1 { return invokers[0] } + selectedInvoker := lb.Select(invokers, invocation) //judge to if the selectedInvoker is invoked - if !selectedInvoker.IsAvailable() || !invoker.availablecheck || isInvoked(selectedInvoker, invoked) { + if (!selectedInvoker.IsAvailable() && invoker.availablecheck) || isInvoked(selectedInvoker, invoked) { // do reselect var reslectInvokers []protocol.Invoker @@ -106,13 +136,12 @@ func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation p } if len(reslectInvokers) > 0 { - return lb.Select(reslectInvokers, invocation) + selectedInvoker = lb.Select(reslectInvokers, invocation) } else { return nil } } return selectedInvoker - } func isInvoked(selectedInvoker protocol.Invoker, invoked []protocol.Invoker) bool { diff --git a/cluster/cluster_impl/base_cluster_invoker_test.go b/cluster/cluster_impl/base_cluster_invoker_test.go new file mode 100644 index 0000000000..d06d3cc23e --- /dev/null +++ b/cluster/cluster_impl/base_cluster_invoker_test.go @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cluster_impl + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/cluster/loadbalance" + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" +) + +func Test_StickyNormal(t *testing.T) { + invokers := []protocol.Invoker{} + for i := 0; i < 10; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url.SetParam("sticky", "true") + invokers = append(invokers, NewMockInvoker(url, 1)) + } + base := &baseClusterInvoker{} + base.availablecheck = true + invoked := []protocol.Invoker{} + result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + assert.Equal(t, result, result1) +} +func Test_StickyNormalWhenError(t *testing.T) { + invokers := []protocol.Invoker{} + for i := 0; i < 10; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url.SetParam("sticky", "true") + invokers = append(invokers, NewMockInvoker(url, 1)) + } + base := &baseClusterInvoker{} + base.availablecheck = true + + invoked := []protocol.Invoker{} + result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + invoked = append(invoked, result) + result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + assert.NotEqual(t, result, result1) +} diff --git a/cluster/cluster_impl/failback_cluster_test.go b/cluster/cluster_impl/failback_cluster_test.go index c94347a125..1d2266cabe 100644 --- a/cluster/cluster_impl/failback_cluster_test.go +++ b/cluster/cluster_impl/failback_cluster_test.go @@ -67,7 +67,7 @@ func Test_FailbackSuceess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailback(t, invoker).(*failbackClusterInvoker) - invoker.EXPECT().GetUrl().Return(failbackUrl).Times(1) + invoker.EXPECT().GetUrl().Return(failbackUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) diff --git a/cluster/cluster_impl/failfast_cluster_test.go b/cluster/cluster_impl/failfast_cluster_test.go index 7a19e80ccd..1a4342e6c2 100644 --- a/cluster/cluster_impl/failfast_cluster_test.go +++ b/cluster/cluster_impl/failfast_cluster_test.go @@ -64,7 +64,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailfast(t, invoker) - invoker.EXPECT().GetUrl().Return(failfastUrl) + invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} @@ -84,7 +84,7 @@ func Test_FailfastInvokeFail(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailfast(t, invoker) - invoker.EXPECT().GetUrl().Return(failfastUrl) + invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes() mockResult := &protocol.RPCResult{Err: perrors.New("error")} diff --git a/cluster/cluster_impl/failsafe_cluster_test.go b/cluster/cluster_impl/failsafe_cluster_test.go index 9ee9d9fee3..7888b97c3a 100644 --- a/cluster/cluster_impl/failsafe_cluster_test.go +++ b/cluster/cluster_impl/failsafe_cluster_test.go @@ -64,7 +64,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := register_failsafe(t, invoker) - invoker.EXPECT().GetUrl().Return(failsafeUrl) + invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} @@ -83,7 +83,7 @@ func Test_FailSafeInvokeFail(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := register_failsafe(t, invoker) - invoker.EXPECT().GetUrl().Return(failsafeUrl) + invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes() mockResult := &protocol.RPCResult{Err: perrors.New("error")} diff --git a/common/constant/key.go b/common/constant/key.go index 17368b45ae..7538a2995a 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -55,6 +55,7 @@ const ( WEIGHT_KEY = "weight" WARMUP_KEY = "warmup" RETRIES_KEY = "retries" + STICKY_KEY = "sticky" BEAN_NAME = "bean.name" FAIL_BACK_TASKS_KEY = "failbacktasks" FORKS_KEY = "forks" diff --git a/common/url.go b/common/url.go index 6e7a843c8f..c010298bf5 100644 --- a/common/url.go +++ b/common/url.go @@ -447,6 +447,11 @@ func (c URL) GetMethodParam(method string, key string, d string) string { return r } +func (c URL) GetMethodParamBool(method string, key string, d bool) bool { + r := c.GetParamBool("methods."+method+"."+key, d) + return r +} + func (c *URL) RemoveParams(set *gxset.HashSet) { c.paramsLock.Lock() defer c.paramsLock.Unlock() diff --git a/common/url_test.go b/common/url_test.go index 41fd374a4d..4d60d7f13f 100644 --- a/common/url_test.go +++ b/common/url_test.go @@ -217,6 +217,18 @@ func TestURL_GetMethodParam(t *testing.T) { assert.Equal(t, "1s", v) } +func TestURL_GetMethodParamBool(t *testing.T) { + params := url.Values{} + params.Set("methods.GetValue.async", "true") + u := URL{baseUrl: baseUrl{params: params}} + v := u.GetMethodParamBool("GetValue", "async", false) + assert.Equal(t, true, v) + + u = URL{} + v = u.GetMethodParamBool("GetValue2", "async", false) + assert.Equal(t, false, v) +} + func TestMergeUrl(t *testing.T) { referenceUrlParams := url.Values{} referenceUrlParams.Set(constant.CLUSTER_KEY, "random") diff --git a/config/method_config.go b/config/method_config.go index e3f0b1b01b..876abeeae0 100644 --- a/config/method_config.go +++ b/config/method_config.go @@ -36,6 +36,7 @@ type MethodConfig struct { TpsLimitStrategy string `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"` ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"` ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"` + Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"` } func (c *MethodConfig) Prefix() string { diff --git a/config/reference_config.go b/config/reference_config.go index 6b34f55359..4e0c56c0bc 100644 --- a/config/reference_config.go +++ b/config/reference_config.go @@ -60,6 +60,7 @@ type ReferenceConfig struct { invoker protocol.Invoker urls []*common.URL Generic bool `yaml:"generic" json:"generic,omitempty" property:"generic"` + Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"` } func (c *ReferenceConfig) Prefix() string { @@ -175,6 +176,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values { urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) //getty invoke async or sync urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.Async)) + urlMap.Set(constant.STICKY_KEY, strconv.FormatBool(refconfig.Sticky)) //application info urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name) @@ -195,6 +197,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values { for _, v := range refconfig.Methods { urlMap.Set("methods."+v.Name+"."+constant.LOADBALANCE_KEY, v.Loadbalance) urlMap.Set("methods."+v.Name+"."+constant.RETRIES_KEY, v.Retries) + urlMap.Set("methods."+v.Name+"."+constant.STICKY_KEY, strconv.FormatBool(v.Sticky)) } return urlMap diff --git a/config/reference_config_test.go b/config/reference_config_test.go index a7af925cab..e689c471ed 100644 --- a/config/reference_config_test.go +++ b/config/reference_config_test.go @@ -86,6 +86,7 @@ func doInitConsumer() { "serviceid": "soa.mock", "forks": "5", }, + Sticky: false, Registry: "shanghai_reg1,shanghai_reg2,hangzhou_reg1,hangzhou_reg2", InterfaceName: "com.MockService", Protocol: "mock", @@ -104,6 +105,7 @@ func doInitConsumer() { Name: "GetUser1", Retries: "2", Loadbalance: "random", + Sticky: true, }, }, }, @@ -291,6 +293,24 @@ func Test_Forking(t *testing.T) { consumerConfig = nil } +func Test_Sticky(t *testing.T) { + doInitConsumer() + extension.SetProtocol("dubbo", GetProtocol) + extension.SetProtocol("registry", GetProtocol) + m := consumerConfig.References["MockService"] + m.Url = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000" + + reference := consumerConfig.References["MockService"] + reference.Refer() + referenceSticky := reference.invoker.GetUrl().GetParam(constant.STICKY_KEY, "false") + assert.Equal(t, "false", referenceSticky) + + method0StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[0].Name, constant.STICKY_KEY, "false") + assert.Equal(t, "false", method0StickKey) + method1StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[1].Name, constant.STICKY_KEY, "false") + assert.Equal(t, "true", method1StickKey) +} + func GetProtocol() protocol.Protocol { if regProtocol != nil { return regProtocol