diff --git a/cluster/cluster_impl/base_cluster_invoker.go b/cluster/cluster_impl/base_cluster_invoker.go index d93e9a6a98..05e9a09208 100644 --- a/cluster/cluster_impl/base_cluster_invoker.go +++ b/cluster/cluster_impl/base_cluster_invoker.go @@ -35,6 +35,7 @@ type baseClusterInvoker struct { directory cluster.Directory availablecheck bool destroyed *atomic.Bool + stickyInvoker protocol.Invoker } func newBaseClusterInvoker(directory cluster.Directory) baseClusterInvoker { @@ -83,15 +84,44 @@ func (invoker *baseClusterInvoker) checkWhetherDestroyed() error { } func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker { - //todo:sticky connect + + var selectedInvoker protocol.Invoker + url := invokers[0].GetUrl() + sticky := url.GetParam(constant.STICKY_KEY, "false") + //Get the service method sticky config if have + if v := url.GetMethodParam(invocation.MethodName(), constant.STICKY_KEY, sticky); len(v) != 0 { + sticky = v + } + + if invoker.stickyInvoker != nil && !isInvoked(invoker.stickyInvoker, invokers) { + invoker.stickyInvoker = nil + } + + if sticky == "true" && invoker.stickyInvoker != nil && (invoked == nil || !isInvoked(invoker.stickyInvoker, invoked)) { + if invoker.availablecheck && invoker.stickyInvoker.IsAvailable() { + return invoker.stickyInvoker + } + } + + selectedInvoker = invoker.doSelectInvoker(lb, invocation, invokers, invoked) + + if sticky == "true" { + invoker.stickyInvoker = selectedInvoker + } + return selectedInvoker + +} + +func (invoker *baseClusterInvoker) doSelectInvoker(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker { if len(invokers) == 1 { return invokers[0] } + selectedInvoker := lb.Select(invokers, invocation) //judge to if the selectedInvoker is invoked - if !selectedInvoker.IsAvailable() || !invoker.availablecheck || isInvoked(selectedInvoker, invoked) { + if (!selectedInvoker.IsAvailable() && invoker.availablecheck) || isInvoked(selectedInvoker, invoked) { // do reselect var reslectInvokers []protocol.Invoker @@ -106,13 +136,12 @@ func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation p } if len(reslectInvokers) > 0 { - return lb.Select(reslectInvokers, invocation) + selectedInvoker = lb.Select(reslectInvokers, invocation) } else { return nil } } return selectedInvoker - } func isInvoked(selectedInvoker protocol.Invoker, invoked []protocol.Invoker) bool { diff --git a/cluster/cluster_impl/base_cluster_invoker_test.go b/cluster/cluster_impl/base_cluster_invoker_test.go new file mode 100644 index 0000000000..41074d0ae8 --- /dev/null +++ b/cluster/cluster_impl/base_cluster_invoker_test.go @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cluster_impl + +import ( + "context" + "fmt" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/cluster/loadbalance" + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" +) + +func Test_StickyNormal(t *testing.T) { + invokers := []protocol.Invoker{} + for i := 0; i < 10; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url.SetParam("sticky", "true") + invokers = append(invokers, NewMockInvoker(url, 1)) + } + base := &baseClusterInvoker{} + invoked := []protocol.Invoker{} + result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + assert.Equal(t, result, result1) +} +func Test_StickyNormalWhenError(t *testing.T) { + invokers := []protocol.Invoker{} + for i := 0; i < 10; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url.SetParam("sticky", "true") + invokers = append(invokers, NewMockInvoker(url, 1)) + } + base := &baseClusterInvoker{} + invoked := []protocol.Invoker{} + result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + invoked = append(invoked, result) + result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked) + assert.NotEqual(t, result, result1) +} diff --git a/cluster/cluster_impl/failback_cluster_test.go b/cluster/cluster_impl/failback_cluster_test.go index c94347a125..1d2266cabe 100644 --- a/cluster/cluster_impl/failback_cluster_test.go +++ b/cluster/cluster_impl/failback_cluster_test.go @@ -67,7 +67,7 @@ func Test_FailbackSuceess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailback(t, invoker).(*failbackClusterInvoker) - invoker.EXPECT().GetUrl().Return(failbackUrl).Times(1) + invoker.EXPECT().GetUrl().Return(failbackUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) diff --git a/cluster/cluster_impl/failfast_cluster_test.go b/cluster/cluster_impl/failfast_cluster_test.go index 7a19e80ccd..1a4342e6c2 100644 --- a/cluster/cluster_impl/failfast_cluster_test.go +++ b/cluster/cluster_impl/failfast_cluster_test.go @@ -64,7 +64,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailfast(t, invoker) - invoker.EXPECT().GetUrl().Return(failfastUrl) + invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} @@ -84,7 +84,7 @@ func Test_FailfastInvokeFail(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := registerFailfast(t, invoker) - invoker.EXPECT().GetUrl().Return(failfastUrl) + invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes() mockResult := &protocol.RPCResult{Err: perrors.New("error")} diff --git a/cluster/cluster_impl/failsafe_cluster_test.go b/cluster/cluster_impl/failsafe_cluster_test.go index 9ee9d9fee3..7888b97c3a 100644 --- a/cluster/cluster_impl/failsafe_cluster_test.go +++ b/cluster/cluster_impl/failsafe_cluster_test.go @@ -64,7 +64,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := register_failsafe(t, invoker) - invoker.EXPECT().GetUrl().Return(failsafeUrl) + invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes() mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} @@ -83,7 +83,7 @@ func Test_FailSafeInvokeFail(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) clusterInvoker := register_failsafe(t, invoker) - invoker.EXPECT().GetUrl().Return(failsafeUrl) + invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes() mockResult := &protocol.RPCResult{Err: perrors.New("error")} diff --git a/common/constant/key.go b/common/constant/key.go index 17368b45ae..7538a2995a 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -55,6 +55,7 @@ const ( WEIGHT_KEY = "weight" WARMUP_KEY = "warmup" RETRIES_KEY = "retries" + STICKY_KEY = "sticky" BEAN_NAME = "bean.name" FAIL_BACK_TASKS_KEY = "failbacktasks" FORKS_KEY = "forks" diff --git a/config/method_config.go b/config/method_config.go index e3f0b1b01b..876abeeae0 100644 --- a/config/method_config.go +++ b/config/method_config.go @@ -36,6 +36,7 @@ type MethodConfig struct { TpsLimitStrategy string `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"` ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"` ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"` + Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"` } func (c *MethodConfig) Prefix() string { diff --git a/config/reference_config.go b/config/reference_config.go index 8703c459ba..a49dccc7e0 100644 --- a/config/reference_config.go +++ b/config/reference_config.go @@ -60,6 +60,7 @@ type ReferenceConfig struct { invoker protocol.Invoker urls []*common.URL Generic bool `yaml:"generic" json:"generic,omitempty" property:"generic"` + Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"` } func (c *ReferenceConfig) Prefix() string { @@ -170,6 +171,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values { urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) //getty invoke async or sync urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.async)) + urlMap.Set(constant.STICKY_KEY, strconv.FormatBool(refconfig.Sticky)) //application info urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name) @@ -190,6 +192,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values { for _, v := range refconfig.Methods { urlMap.Set("methods."+v.Name+"."+constant.LOADBALANCE_KEY, v.Loadbalance) urlMap.Set("methods."+v.Name+"."+constant.RETRIES_KEY, v.Retries) + urlMap.Set("methods."+v.Name+"."+constant.STICKY_KEY, strconv.FormatBool(v.Sticky)) } return urlMap diff --git a/config/reference_config_test.go b/config/reference_config_test.go index a81dbf06ce..85f15e7646 100644 --- a/config/reference_config_test.go +++ b/config/reference_config_test.go @@ -85,6 +85,7 @@ func doInitConsumer() { "serviceid": "soa.mock", "forks": "5", }, + Sticky: false, Registry: "shanghai_reg1,shanghai_reg2,hangzhou_reg1,hangzhou_reg2", InterfaceName: "com.MockService", Protocol: "mock", @@ -103,6 +104,7 @@ func doInitConsumer() { Name: "GetUser1", Retries: "2", Loadbalance: "random", + Sticky: true, }, }, }, @@ -254,6 +256,24 @@ func Test_Forking(t *testing.T) { consumerConfig = nil } +func Test_Sticky(t *testing.T) { + doInitConsumer() + extension.SetProtocol("dubbo", GetProtocol) + extension.SetProtocol("registry", GetProtocol) + m := consumerConfig.References["MockService"] + m.Url = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000" + + reference := consumerConfig.References["MockService"] + reference.Refer() + referenceSticky := reference.invoker.GetUrl().GetParam(constant.STICKY_KEY, "false") + assert.Equal(t, "false", referenceSticky) + + method0StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[0].Name, constant.STICKY_KEY, "false") + assert.Equal(t, "false", method0StickKey) + method1StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[1].Name, constant.STICKY_KEY, "false") + assert.Equal(t, "true", method1StickKey) +} + func GetProtocol() protocol.Protocol { if regProtocol != nil { return regProtocol diff --git a/go.mod b/go.mod index 6a9128af0c..c2a61f2db1 100644 --- a/go.mod +++ b/go.mod @@ -52,3 +52,5 @@ require ( google.golang.org/grpc v1.22.1 gopkg.in/yaml.v2 v2.2.2 ) + +go 1.13