diff --git a/cluster/gate/gate.go b/cluster/gate/gate.go index e1f30bb..a2174df 100644 --- a/cluster/gate/gate.go +++ b/cluster/gate/gate.go @@ -210,6 +210,7 @@ func (g *Gate) registerServiceInstance() { Kind: cluster.Gate.String(), Alias: g.opts.name, State: g.getState().String(), + Weight: g.opts.weight, Endpoint: g.linker.Endpoint().String(), } diff --git a/cluster/gate/options.go b/cluster/gate/options.go index d52dc3c..603b345 100644 --- a/cluster/gate/options.go +++ b/cluster/gate/options.go @@ -22,6 +22,7 @@ const ( defaultName = "gate" // 默认名称 defaultAddr = ":0" // 连接器监听地址 defaultTimeout = 3 * time.Second // 默认超时时间 + defaultWeight = 1 // 默认权重 ) const ( @@ -29,6 +30,7 @@ const ( defaultNameKey = "etc.cluster.gate.name" defaultAddrKey = "etc.cluster.gate.addr" defaultTimeoutKey = "etc.cluster.gate.timeout" + defaultWeightKey = "etc.cluster.gate.weight" ) type Option func(o *options) @@ -39,6 +41,7 @@ type options struct { name string // 实例名称 addr string // 监听地址 timeout time.Duration // RPC调用超时时间 + weight int // 权重 server network.Server // 网关服务器 locator locate.Locator // 用户定位器 registry registry.Registry // 服务注册器 @@ -50,6 +53,7 @@ func defaultOptions() *options { name: defaultName, addr: defaultAddr, timeout: defaultTimeout, + weight: defaultWeight, } if id := etc.Get(defaultIDKey).String(); id != "" { @@ -70,6 +74,10 @@ func defaultOptions() *options { opts.timeout = timeout } + if weight := etc.Get(defaultWeightKey).Int(); weight > 0 { + opts.weight = weight + } + return opts } @@ -107,3 +115,8 @@ func WithLocator(locator locate.Locator) Option { func WithRegistry(r registry.Registry) Option { return func(o *options) { o.registry = r } } + +// WithWeight 设置权重 +func WithWeight(weight int) Option { + return func(o *options) { o.weight = weight } +} diff --git a/cluster/mesh/mesh.go b/cluster/mesh/mesh.go index 9c12cd2..2dbfd8d 100644 --- a/cluster/mesh/mesh.go +++ b/cluster/mesh/mesh.go @@ -153,6 +153,7 @@ func (m *Mesh) registerServiceInstances() { Kind: cluster.Mesh.String(), Alias: m.opts.name, State: m.getState().String(), + Weight: m.opts.weight, Endpoint: m.transporter.Endpoint().String(), Services: make([]string, 0, len(m.services)), } diff --git a/cluster/mesh/options.go b/cluster/mesh/options.go index 3bcf3c6..e25b1bf 100644 --- a/cluster/mesh/options.go +++ b/cluster/mesh/options.go @@ -16,6 +16,7 @@ const ( defaultName = "mesh" // 默认节点名称 defaultCodec = "proto" // 默认编解码器名称 defaultTimeout = 3 * time.Second // 默认超时时间 + defaultWeight = 1 // 默认权重 ) const ( @@ -23,6 +24,7 @@ const ( defaultNameKey = "etc.cluster.mesh.name" defaultCodecKey = "etc.cluster.mesh.codec" defaultTimeoutKey = "etc.cluster.mesh.timeout" + defaultWeightKey = "etc.cluster.mesh.weight" ) type Option func(o *options) @@ -37,6 +39,7 @@ type options struct { registry registry.Registry // 服务注册器 encryptor crypto.Encryptor // 消息加密器 transporter transport.Transporter // 消息传输器 + weight int // 权重 } func defaultOptions() *options { @@ -45,6 +48,7 @@ func defaultOptions() *options { name: defaultName, codec: encoding.Invoke(defaultCodec), timeout: defaultTimeout, + weight: defaultWeight, } if id := etc.Get(defaultIDKey).String(); id != "" { @@ -65,6 +69,10 @@ func defaultOptions() *options { opts.timeout = timeout } + if weight := etc.Get(defaultWeightKey).Int(); weight > 0 { + opts.weight = weight + } + return opts } @@ -107,3 +115,8 @@ func WithEncryptor(encryptor crypto.Encryptor) Option { func WithTransporter(transporter transport.Transporter) Option { return func(o *options) { o.transporter = transporter } } + +// WithWeight 设置权重 +func WithWeight(weight int) Option { + return func(o *options) { o.weight = weight } +} diff --git a/cluster/node/node.go b/cluster/node/node.go index 98089c2..9bb00b5 100644 --- a/cluster/node/node.go +++ b/cluster/node/node.go @@ -300,6 +300,7 @@ func (n *Node) registerServiceInstances() { Routes: routes, Events: events, Endpoint: n.linker.Endpoint().String(), + Weight: n.opts.weight, }) if n.transporter != nil { @@ -316,6 +317,7 @@ func (n *Node) registerServiceInstances() { State: n.getState().String(), Services: services, Endpoint: n.transporter.Endpoint().String(), + Weight: n.opts.weight, }) } diff --git a/cluster/node/options.go b/cluster/node/options.go index 6035dca..1bce2ff 100644 --- a/cluster/node/options.go +++ b/cluster/node/options.go @@ -17,6 +17,7 @@ const ( defaultAddr = ":0" // 连接器监听地址 defaultCodec = "proto" // 默认编解码器名称 defaultTimeout = 3 * time.Second // 默认超时时间 + defaultWeight = 1 // 默认权重 ) const ( @@ -25,6 +26,7 @@ const ( defaultAddrKey = "etc.cluster.node.addr" defaultCodecKey = "etc.cluster.node.codec" defaultTimeoutKey = "etc.cluster.node.timeout" + defaultWeightKey = "etc.cluster.node.weight" ) // SchedulingModel 调度模型 @@ -43,6 +45,7 @@ type options struct { registry registry.Registry // 服务注册器 encryptor crypto.Encryptor // 消息加密器 transporter transport.Transporter // 消息传输器 + weight int // 权重 } func defaultOptions() *options { @@ -52,6 +55,7 @@ func defaultOptions() *options { addr: defaultAddr, codec: encoding.Invoke(defaultCodec), timeout: defaultTimeout, + weight: defaultWeight, } if id := etc.Get(defaultIDKey).String(); id != "" { @@ -76,6 +80,10 @@ func defaultOptions() *options { opts.timeout = timeout } + if weight := etc.Get(defaultWeightKey).Int(); weight > 0 { + opts.weight = weight + } + return opts } @@ -128,3 +136,8 @@ func WithEncryptor(encryptor crypto.Encryptor) Option { func WithTransporter(transporter transport.Transporter) Option { return func(o *options) { o.transporter = transporter } } + +// WithWeight 设置权重 +func WithWeight(weight int) Option { + return func(o *options) { o.weight = weight } +} diff --git a/internal/dispatcher/abstract.go b/internal/dispatcher/abstract.go index 29c4102..f480f0d 100644 --- a/internal/dispatcher/abstract.go +++ b/internal/dispatcher/abstract.go @@ -5,6 +5,7 @@ import ( "github.com/dobyte/due/v2/core/endpoint" "github.com/dobyte/due/v2/errors" "sync/atomic" + "sync" ) type serviceEndpoint struct { @@ -20,6 +21,25 @@ type abstract struct { endpoints2 map[string]*serviceEndpoint // 所有端口(包含work、busy、hang、shut状态的实例) endpoints3 []*serviceEndpoint // 所有端口(包含work、busy状态的实例) endpoints4 map[string]*serviceEndpoint // 所有端口(包含work、busy状态的实例) + // 加权轮询相关字段 + currentQueue *wrrQueue // 当前队列 + nextQueue *wrrQueue // 下一个队列 + step int // GCD步长 + wrrMu sync.Mutex // 加权轮询锁 +} + +// 加权轮询队列节点 +type wrrEntry struct { + weight int // 当前权重 + orgWeight int // 原始权重 + endpoint *serviceEndpoint + next *wrrEntry +} + +// 加权轮询队列 +type wrrQueue struct { + head *wrrEntry + tail *wrrEntry } // FindEndpoint 查询路由服务端点 @@ -29,7 +49,7 @@ func (a *abstract) FindEndpoint(insID ...string) (*endpoint.Endpoint, error) { case RoundRobin: return a.roundRobinDispatch() case WeightRoundRobin: - return a.randomDispatch() + return a.weightRoundRobinDispatch() default: return a.randomDispatch() } @@ -111,3 +131,106 @@ func (a *abstract) roundRobinDispatch() (*endpoint.Endpoint, error) { return a.endpoints3[index].endpoint, nil } + +// 加权轮询分配 +func (a *abstract) weightRoundRobinDispatch() (*endpoint.Endpoint, error) { + a.wrrMu.Lock() + defer a.wrrMu.Unlock() + + // 如果当前队列为空,交换当前队列和下一个队列 + if a.currentQueue.isEmpty() { + a.currentQueue, a.nextQueue = a.nextQueue, a.currentQueue + } + + // 从当前队列中取出一个节点 + entry := a.currentQueue.pop() + if entry == nil { + return nil, errors.ErrNotFoundEndpoint + } + + // 减少当前权重 + entry.weight -= a.step + + // 如果权重大于0,放回当前队列 + if entry.weight > 0 { + a.currentQueue.push(entry) + } else { + // 重置权重并放入下一个队列 + entry.weight = entry.orgWeight + a.nextQueue.push(entry) + } + + return entry.endpoint.endpoint, nil +} + +// 初始化 WRR 队列 +func (a *abstract) initWRRQueue() { + a.currentQueue = &wrrQueue{} + a.nextQueue = &wrrQueue{} + + // 计算最大公约数作为步长 + a.step = 0 + for _, sep := range a.endpoints4 { + weight := a.dispatcher.instances[sep.insID].Weight + if a.step == 0 { + a.step = weight + } else { + a.step = gcd(a.step, weight) + } + + // 创建队列节点 + entry := &wrrEntry{ + weight: weight, + orgWeight: weight, + endpoint: sep, + } + a.currentQueue.push(entry) + } +} + +// 判断队列是否为空 +func (q *wrrQueue) isEmpty() bool { + return q.head == nil +} + +// 将节点加入队列尾部 +func (q *wrrQueue) push(entry *wrrEntry) { + entry.next = nil + + if q.tail == nil { + // 空队列 + q.head = entry + q.tail = entry + return + } + + // 添加到队列尾部 + q.tail.next = entry + q.tail = entry +} + +// 从队列头部取出节点 +func (q *wrrQueue) pop() *wrrEntry { + if q.head == nil { + return nil + } + + entry := q.head + q.head = entry.next + + if q.head == nil { + // 队列已空 + q.tail = nil + } + + entry.next = nil + return entry +} + +// 计算最大公约数 +func gcd(a, b int) int { + for b != 0 { + a, b = b, a%b + } + return a +} \ No newline at end of file diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index e924e80..4a51038 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -23,6 +23,7 @@ type Dispatcher struct { routes map[int32]*Route events map[int]*Event endpoints map[string]*endpoint.Endpoint + instances map[string]*registry.ServiceInstance } func NewDispatcher(strategy BalanceStrategy) *Dispatcher { @@ -85,6 +86,7 @@ func (d *Dispatcher) ReplaceServices(services ...*registry.ServiceInstance) { routes := make(map[int32]*Route, len(services)) events := make(map[int]*Event, len(services)) endpoints := make(map[string]*endpoint.Endpoint) + instances := make(map[string]*registry.ServiceInstance, len(services)) log.Debugf("services change: %v", xconv.Json(services)) @@ -97,6 +99,7 @@ func (d *Dispatcher) ReplaceServices(services ...*registry.ServiceInstance) { } endpoints[service.ID] = ep + instances[service.ID] = service for _, item := range service.Routes { route, ok := routes[item.ID] @@ -121,5 +124,15 @@ func (d *Dispatcher) ReplaceServices(services ...*registry.ServiceInstance) { d.routes = routes d.events = events d.endpoints = endpoints + d.instances = instances + + if d.strategy == WeightRoundRobin { + for _, route := range routes { + route.initWRRQueue() + } + for _, event := range events { + event.initWRRQueue() + } + } d.rw.Unlock() } diff --git a/internal/dispatcher/dispatcher_test.go b/internal/dispatcher/dispatcher_test.go index ed2a336..7b4effb 100644 --- a/internal/dispatcher/dispatcher_test.go +++ b/internal/dispatcher/dispatcher_test.go @@ -6,6 +6,8 @@ import ( "github.com/dobyte/due/v2/internal/dispatcher" "github.com/dobyte/due/v2/registry" "testing" + "math" + "fmt" ) func TestDispatcher_ReplaceServices(t *testing.T) { @@ -85,3 +87,220 @@ func TestDispatcher_ReplaceServices(t *testing.T) { // t.Log(event.FindEndpoint()) //} } + +func TestDispatcher_WeightRoundRobin(t *testing.T) { + var ( + // 创建三个服务实例,权重分别为4、2、1 + instance1 = ®istry.ServiceInstance{ + ID: "xa", + Name: "gate-1", + Kind: cluster.Node.String(), + Alias: "gate-1", + State: cluster.Work.String(), + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8001", false).String(), + Weight: 4, // 权重4 + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + } + instance2 = ®istry.ServiceInstance{ + ID: "xb", + Name: "gate-2", + Kind: cluster.Node.String(), + Alias: "gate-2", + State: cluster.Work.String(), + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8002", false).String(), + Weight: 2, // 权重2 + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + } + instance3 = ®istry.ServiceInstance{ + ID: "xc", + Name: "gate-3", + Kind: cluster.Node.String(), + Alias: "gate-3", + State: cluster.Work.String(), + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8003", false).String(), + Weight: 1, // 权重1 + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + } + ) + + // 创建加权轮询调度器 + d := dispatcher.NewDispatcher(dispatcher.WeightRoundRobin) + d.ReplaceServices(instance1, instance2, instance3) + + // 统计每个实例被选中的次数 + counts := make(map[string]int) + totalRounds := 70 // 选择一个能被所有权重和(7)整除的数 + + // 执行多轮测试 + for i := 0; i < totalRounds; i++ { + route, err := d.FindRoute(1) + if err != nil { + t.Errorf("find route failed: %v", err) + return + } + + ep, err := route.FindEndpoint() + if err != nil { + t.Errorf("find endpoint failed: %v", err) + return + } + + // 从endpoint中解析实例ID并计数 + parsedEp, err := endpoint.ParseEndpoint(ep.String()) + if err != nil { + t.Errorf("parse endpoint failed: %v", err) + return + } + addr := parsedEp.Address() + counts[addr]++ + } + + // 验证分配结果 + expectedRatios := map[string]float64{ + "127.0.0.1:8001": 4.0 / 7.0, // 权重4 + "127.0.0.1:8002": 2.0 / 7.0, // 权重2 + "127.0.0.1:8003": 1.0 / 7.0, // 权重1 + } + + t.Log("Distribution results:") + for addr, count := range counts { + ratio := float64(count) / float64(totalRounds) + expected := expectedRatios[addr] + t.Logf("Server %s: selected %d times, ratio=%.3f, expected=%.3f", + addr, count, ratio, expected) + + // 验证分配比例是否符合权重比例(允许5%的误差) + if delta := math.Abs(ratio - expected); delta > 0.05 { + t.Errorf("distribution ratio for %s is %.3f, want %.3f (±0.05)", + addr, ratio, expected) + } + } + + // 验证总次数 + total := 0 + for _, count := range counts { + total += count + } + if total != totalRounds { + t.Errorf("total rounds = %d, want %d", total, totalRounds) + } +} + +func BenchmarkDispatcher_WeightRoundRobin(b *testing.B) { + var ( + // 创建测试服务实例 + instances = []*registry.ServiceInstance{ + { + ID: "xa", + Name: "gate-1", + Kind: cluster.Node.String(), + Alias: "gate-1", + State: cluster.Work.String(), + Weight: 4, + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8001", false).String(), + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + }, + { + ID: "xb", + Name: "gate-2", + Kind: cluster.Node.String(), + Alias: "gate-2", + State: cluster.Work.String(), + Weight: 2, + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8002", false).String(), + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + }, + { + ID: "xc", + Name: "gate-3", + Kind: cluster.Node.String(), + Alias: "gate-3", + State: cluster.Work.String(), + Weight: 1, + Endpoint: endpoint.NewEndpoint("grpc", "127.0.0.1:8003", false).String(), + Routes: []registry.Route{{ + ID: 1, + Stateful: false, + }}, + }, + } + ) + + // 运行不同规模的基准测试 + benchmarks := []struct { + name string + concurrency int // 并发数 + instanceCount int // 服务实例数量 + }{ + {"Concurrency1_Instances3", 1, 3}, + {"Concurrency10_Instances3", 10, 3}, + {"Concurrency100_Instances3", 100, 3}, + {"Concurrency1_Instances10", 1, 10}, + {"Concurrency10_Instances10", 10, 10}, + {"Concurrency100_Instances10", 100, 10}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + // 准备足够数量的实例 + testInstances := make([]*registry.ServiceInstance, bm.instanceCount) + for i := 0; i < bm.instanceCount; i++ { + if i < len(instances) { + testInstances[i] = instances[i] + } else { + // 复制最后一个实例并修改ID和端口 + last := instances[len(instances)-1] + testInstances[i] = ®istry.ServiceInstance{ + ID: fmt.Sprintf("x%d", i), + Name: fmt.Sprintf("gate-%d", i+1), + Kind: last.Kind, + Alias: fmt.Sprintf("gate-%d", i+1), + State: last.State, + Weight: 1, + Endpoint: endpoint.NewEndpoint("grpc", fmt.Sprintf("127.0.0.1:%d", 8000+i), false).String(), + Routes: last.Routes, + } + } + } + + // 创建调度器 + d := dispatcher.NewDispatcher(dispatcher.WeightRoundRobin) + d.ReplaceServices(testInstances...) + + // 重置计时器 + b.ResetTimer() + + // 并发执行基准测试 + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + route, err := d.FindRoute(1) + if err != nil { + b.Fatal(err) + } + _, err = route.FindEndpoint() + if err != nil { + b.Fatal(err) + } + } + }) + + // 报告内存分配统计 + b.ReportAllocs() + }) + } +} \ No newline at end of file diff --git a/registry/consul/registrar.go b/registry/consul/registrar.go index 23774c2..236b592 100644 --- a/registry/consul/registrar.go +++ b/registry/consul/registrar.go @@ -21,6 +21,7 @@ const ( metaFieldState = "state" metaFieldRoutes = "routes" metaFieldEvents = "events" + metaFieldWeight = "weight" metaFieldServices = "services" metaFieldEndpoint = "endpoint" ) @@ -79,6 +80,7 @@ func (r *registrar) register(ctx context.Context, ins *registry.ServiceInstance) registration.Meta[metaFieldState] = ins.State registration.Meta[metaFieldEndpoint] = ins.Endpoint registration.Meta[metaFieldEvents] = xconv.Json(ins.Events) + registration.Meta[metaFieldWeight] = xconv.String(ins.Weight) registration.Meta[metaFieldServices] = xconv.Json(ins.Services) for field, value := range marshalMetaRoutes(ins.Routes) { diff --git a/registry/consul/registry.go b/registry/consul/registry.go index 26671a2..26b7d1a 100644 --- a/registry/consul/registry.go +++ b/registry/consul/registry.go @@ -4,6 +4,7 @@ import ( "context" "github.com/dobyte/due/v2/encoding/json" "github.com/dobyte/due/v2/registry" + "github.com/dobyte/due/v2/utils/xconv" "github.com/hashicorp/consul/api" "sync" "time" @@ -147,6 +148,8 @@ func (r *Registry) services(ctx context.Context, serviceName string, waitIndex u ins.Alias = v case metaFieldState: ins.State = v + case metaFieldWeight: + ins.Weight = xconv.Int(v) case metaFieldEvents: if err = json.Unmarshal([]byte(v), &ins.Events); err != nil { continue diff --git a/registry/nacos/registrar.go b/registry/nacos/registrar.go index bd9dd80..a21f064 100644 --- a/registry/nacos/registrar.go +++ b/registry/nacos/registrar.go @@ -5,6 +5,7 @@ import ( "github.com/dobyte/due/v2/encoding/json" "github.com/dobyte/due/v2/errors" "github.com/dobyte/due/v2/registry" + "github.com/dobyte/due/v2/utils/xconv" "github.com/nacos-group/nacos-sdk-go/v2/vo" "net" "net/url" @@ -19,6 +20,7 @@ const ( metaFieldState = "state" metaFieldRoutes = "routes" metaFieldEvents = "events" + metaFieldWeight = "weight" metaFieldServices = "services" metaFieldEndpoint = "endpoint" ) @@ -73,6 +75,7 @@ func (r *registrar) register(ctx context.Context, ins *registry.ServiceInstance) metaFieldEvents: string(events), metaFieldServices: string(services), metaFieldEndpoint: ins.Endpoint, + metaFieldWeight: xconv.String(ins.Weight), }, } diff --git a/registry/nacos/registry.go b/registry/nacos/registry.go index 71a305a..2488ae8 100644 --- a/registry/nacos/registry.go +++ b/registry/nacos/registry.go @@ -10,6 +10,7 @@ import ( "github.com/nacos-group/nacos-sdk-go/v2/common/constant" "github.com/nacos-group/nacos-sdk-go/v2/model" "github.com/nacos-group/nacos-sdk-go/v2/vo" + "github.com/dobyte/due/v2/utils/xconv" "net" "net/url" "strconv" @@ -219,6 +220,7 @@ func parseInstances(instances []model.Instance) ([]*registry.ServiceInstance, er ins.Routes = make([]registry.Route, 0) ins.Events = make([]int, 0) ins.Services = make([]string, 0) + ins.Weight = xconv.Int(instance.Metadata[metaFieldWeight]) if v := instance.Metadata[metaFieldRoutes]; v != "" { if err := json.Unmarshal([]byte(v), &ins.Routes); err != nil { diff --git a/registry/registry.go b/registry/registry.go index 18c9762..ef7ac01 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -50,6 +50,8 @@ type ServiceInstance struct { Services []string `json:"services,omitempty"` // 微服务实体暴露端口 Endpoint string `json:"endpoint,omitempty"` + // 微服务路由加权轮询权重 + Weight int `json:"weight,omitempty"` } type Route struct {