Make proxy use roundrobin to choose replica (milvus-io#17063)
Fixes: milvus-io#17055

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
XuanYang-cn authored May 17, 2022
1 parent b37b87e commit 127dd34
Showing 7 changed files with 141 additions and 55 deletions.
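The gist of the change: the proxy used to always hit the first leader in each ShardLeadersList; it now keeps shard leaders as map[string][]queryNode and rotates each list on every read, so successive requests spread across replicas. A minimal, self-contained Go sketch of that rotation, not part of the diff (rotate stands in for the committed updateShardsWithRoundRobin):

package main

import "fmt"

// queryNode mirrors the struct this commit adds in
// internal/proxy/task_policies.go: one shard-leader replica.
type queryNode struct {
	nodeID  int64
	address string
}

// rotate moves the head of each per-channel leader list to the tail,
// the same move updateShardsWithRoundRobin performs below.
func rotate(shards map[string][]queryNode) map[string][]queryNode {
	for ch, leaders := range shards {
		if len(leaders) <= 1 {
			continue
		}
		shards[ch] = append(leaders[1:], leaders[0])
	}
	return shards
}

func main() {
	shards := map[string][]queryNode{
		"channel-1": {{1, "addr1"}, {2, "addr2"}, {3, "addr3"}},
	}
	for i := 0; i < 3; i++ {
		shards = rotate(shards)
		fmt.Println("request", i, "starts at node", shards["channel-1"][0].nodeID)
	}
	// Prints nodes 2, 3, 1: each request begins at a different replica.
}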
4 changes: 2 additions & 2 deletions internal/datacoord/channel_manager.go
@@ -712,7 +712,7 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error {

toReleaseChannel := c.getChannelByNodeAndName(nodeID, channelName)
if toReleaseChannel == nil {
return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
}

toReleaseUpdates := getReleaseOp(nodeID, toReleaseChannel)
@@ -731,7 +731,7 @@ func (c *ChannelManager) toDelete(nodeID UniqueID, channelName string) error {

ch := c.getChannelByNodeAndName(nodeID, channelName)
if ch == nil {
return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
}

if !c.isMarkedDrop(channelName) {
33 changes: 28 additions & 5 deletions internal/proxy/meta_cache.go
@@ -53,7 +53,7 @@ type Cache interface {
GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
ClearShards(collectionName string)
RemoveCollection(ctx context.Context, collectionName string)
RemovePartition(ctx context.Context, collectionName string, partitionName string)
@@ -70,7 +70,7 @@ type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
partInfo map[string]*partitionInfo
shardLeaders []*querypb.ShardLeadersList
shardLeaders map[string][]queryNode
createdTimestamp uint64
createdUtcTimestamp uint64
}
@@ -528,15 +528,20 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
}

// GetShards updates the cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
info, err := m.GetCollectionInfo(ctx, collectionName)
if err != nil {
return nil, err
}

if withCache {
if len(info.shardLeaders) > 0 {
return info.shardLeaders, nil
shards := updateShardsWithRoundRobin(info.shardLeaders)

m.mu.Lock()
m.collInfo[collectionName].shardLeaders = shards
m.mu.Unlock()
return shards, nil
}
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
@@ -557,7 +562,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
}

shards := resp.GetShards()
shards := parseShardLeaderList2QueryNode(resp.GetShards())

shards = updateShardsWithRoundRobin(shards)

m.mu.Lock()
m.collInfo[collectionName].shardLeaders = shards
@@ -566,6 +573,22 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
return shards, nil
}

func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
shard2QueryNodes := make(map[string][]queryNode)

for _, leaders := range shardsLeaders {
qns := make([]queryNode, len(leaders.GetNodeIds()))

for j := range qns {
qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
}

shard2QueryNodes[leaders.GetChannelName()] = qns
}

return shard2QueryNodes
}

// ClearShards clear the shard leader cache of a collection
func (m *MetaCache) ClearShards(collectionName string) {
log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
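parseShardLeaderList2QueryNode above zips QueryCoord's parallel NodeIds and NodeAddrs slices into per-channel replica lists. A self-contained sketch of that pairing; shardLeadersList here is a hand-rolled stand-in for the generated querypb.ShardLeadersList, not the real type:

package main

import "fmt"

// shardLeadersList stands in for querypb.ShardLeadersList: QueryCoord
// reports each channel's leaders as parallel ID and address slices.
type shardLeadersList struct {
	channelName string
	nodeIDs     []int64
	nodeAddrs   []string
}

type queryNode struct {
	nodeID  int64
	address string
}

// parse pairs nodeIDs[i] with nodeAddrs[i] per channel, mirroring
// parseShardLeaderList2QueryNode in the diff above.
func parse(lists []shardLeadersList) map[string][]queryNode {
	out := make(map[string][]queryNode)
	for _, l := range lists {
		qns := make([]queryNode, len(l.nodeIDs))
		for i := range qns {
			qns[i] = queryNode{l.nodeIDs[i], l.nodeAddrs[i]}
		}
		out[l.channelName] = qns
	}
	return out
}

func main() {
	lists := []shardLeadersList{
		{"channel-1", []int64{1, 2}, []string{"addr1", "addr2"}},
	}
	fmt.Println(parse(lists)) // map[channel-1:[{1 addr1} {2 addr2}]]
}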
10 changes: 4 additions & 6 deletions internal/proxy/meta_cache_test.go
@@ -344,17 +344,16 @@ func TestMetaCache_GetShards(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))

assert.Equal(t, 3, len(shards["channel-1"]))

// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))
assert.Equal(t, 3, len(shards["channel-1"]))
})
}

@@ -387,8 +386,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards))
require.Equal(t, 3, len(shards[0].GetNodeAddrs()))
require.Equal(t, 3, len(shards[0].GetNodeIds()))
require.Equal(t, 3, len(shards["channel-1"]))

globalMetaCache.ClearShards(collectionName)

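The cached path matters in these tests: once a fetch succeeds, GetShards with withCache == true is served from memory, which is why the second lookup still passes after qc.validShardLeaders is set to false. A toy illustration of that flow; shardCache and its fetch hook are invented for the sketch and are not proxy APIs:

package main

import (
	"errors"
	"fmt"
)

type queryNode struct {
	nodeID  int64
	address string
}

// shardCache is a toy stand-in for MetaCache: fetch once, then serve
// later lookups from memory.
type shardCache struct {
	leaders map[string][]queryNode
	fetch   func() (map[string][]queryNode, error)
}

func (c *shardCache) GetShards(withCache bool) (map[string][]queryNode, error) {
	if withCache && len(c.leaders) > 0 {
		return c.leaders, nil // cache hit: QueryCoord is not consulted
	}
	l, err := c.fetch()
	if err != nil {
		return nil, err
	}
	c.leaders = l
	return l, nil
}

func main() {
	calls := 0
	c := &shardCache{fetch: func() (map[string][]queryNode, error) {
		calls++
		if calls > 1 {
			// Comparable to the test flipping qc.validShardLeaders to false.
			return nil, errors.New("QueryCoord unavailable")
		}
		return map[string][]queryNode{
			"channel-1": {{1, "a"}, {2, "b"}, {3, "c"}},
		}, nil
	}}

	first, _ := c.GetShards(true)  // miss: fetches from the coordinator
	second, _ := c.GetShards(true) // hit: served from memory, no error
	fmt.Println(len(first["channel-1"]), len(second["channel-1"])) // 3 3
}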
42 changes: 31 additions & 11 deletions internal/proxy/task_policies.go
@@ -3,19 +3,19 @@ package proxy
import (
"context"
"errors"
"fmt"

qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"

"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"

"go.uber.org/zap"
)

type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)

type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error

// TODO add another policy to enable the use of cache
// defaultGetQueryNodePolicy creates QueryNode client for every address every time
@@ -40,23 +40,45 @@ var (
errInvalidShardLeaders = errors.New("Invalid shard leader")
)

func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
type queryNode struct {
nodeID UniqueID
address string
}

func (q queryNode) String() string {
return fmt.Sprintf("<NodeID: %d>", q.nodeID)
}

func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {

for channelID, leaders := range shardsLeaders {
if len(leaders) <= 1 {
continue
}

shardsLeaders[channelID] = append(leaders[1:], leaders[0])
}

return shardsLeaders
}

func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
var (
err = errBegin
current = 0
qn types.QueryNode
)
replicaNum := len(leaders.GetNodeIds())
replicaNum := len(leaders)

for err != nil && current < replicaNum {
currentID := leaders.GetNodeIds()[current]
currentID := leaders[current].nodeID
if err != errBegin {
log.Warn("retry with another QueryNode",
zap.Int("retries numbers", current),
zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
zap.Int64("nodeID", currentID))
}

qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
qn, err = getQueryNodePolicy(ctx, leaders[current].address)
if err != nil {
log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
zap.Error(err))
@@ -68,17 +90,15 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
err = query(currentID, qn)
if err != nil {
log.Warn("fail to Query with shard leader",
zap.String("leader", leaders.GetChannelName()),
zap.Int64("nodeID", currentID),
zap.Error(err))
}
current++
}

if current == replicaNum && err != nil {
log.Warn("no shard leaders available for channel",
zap.String("channel name", leaders.GetChannelName()),
zap.Int64s("leaders", leaders.GetNodeIds()), zap.Error(err))
log.Warn("no shard leaders available",
zap.String("leaders", fmt.Sprintf("%v", leaders)), zap.Error(err))
// needs to return the error from query
return err
}
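roundRobinPolicy is the failover half of the scheme: the leader list arrives pre-rotated from the cache, and the policy walks it until one replica answers, surfacing the last error only if all fail. A condensed, self-contained sketch of that loop; tryLeaders compresses the real function and folds the getQueryNodePolicy dial step into the query callback:

package main

import (
	"errors"
	"fmt"
)

type queryNode struct {
	nodeID  int64
	address string
}

// tryLeaders mirrors roundRobinPolicy's control flow: attempt each
// leader in order, stop at the first success, and surface the last
// error only if every replica fails.
func tryLeaders(leaders []queryNode, query func(queryNode) error) error {
	err := errors.New("begin") // sentinel, like errBegin in the diff
	for i := 0; err != nil && i < len(leaders); i++ {
		if err = query(leaders[i]); err != nil {
			fmt.Printf("node %d failed, trying next replica: %v\n",
				leaders[i].nodeID, err)
		}
	}
	return err
}

func main() {
	leaders := []queryNode{{1, "addr1"}, {2, "addr2"}}
	calls := 0
	err := tryLeaders(leaders, func(q queryNode) error {
		calls++
		if calls == 1 {
			return errors.New("node down") // first replica fails
		}
		return nil // second replica answers
	})
	fmt.Println("final err:", err) // final err: <nil>
}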
69 changes: 56 additions & 13 deletions internal/proxy/task_policies_test.go
@@ -5,11 +5,53 @@ import (
"fmt"
"testing"

"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.uber.org/zap"
)

func TestUpdateShardsWithRoundRobin(t *testing.T) {
in := map[string][]queryNode{
"channel-1": {
{1, "addr1"},
{2, "addr2"},
},
"channel-2": {
{20, "addr20"},
{21, "addr21"},
},
}

out := updateShardsWithRoundRobin(in)

assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
assert.Equal(t, "addr2", out["channel-1"][0].address)
assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
assert.Equal(t, "addr21", out["channel-2"][0].address)

t.Run("check print", func(t *testing.T) {
qns := []queryNode{
{1, "addr1"},
{2, "addr2"},
{20, "addr20"},
{21, "addr21"},
}

res := fmt.Sprintf("list: %v", qns)

log.Debug("Check String func",
zap.Any("Any", qns),
zap.Any("ok", qns[0]),
zap.String("ok2", res),
)

})
}

func TestRoundRobinPolicy(t *testing.T) {
var (
getQueryNodePolicy = mockGetQueryNodePolicy
@@ -31,11 +73,12 @@
t.Run(test.description, func(t *testing.T) {
query := (&mockQuery{isvalid: false}).query

leaders := &querypb.ShardLeadersList{
ChannelName: t.Name(),
NodeIds: test.leaderIDs,
NodeAddrs: make([]string, len(test.leaderIDs)),
leaders := make([]queryNode, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})

}

err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
require.Error(t, err)
})
@@ -55,10 +98,10 @@

for _, test := range allPassTests {
query := (&mockQuery{isvalid: true}).query
leaders := &querypb.ShardLeadersList{
ChannelName: t.Name(),
NodeIds: test.leaderIDs,
NodeAddrs: make([]string, len(test.leaderIDs)),
leaders := make([]queryNode, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})

}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
require.NoError(t, err)
@@ -77,10 +120,10 @@

for _, test := range passAtLast {
query := (&mockQuery{isvalid: true}).query
leaders := &querypb.ShardLeadersList{
ChannelName: t.Name(),
NodeIds: test.leaderIDs,
NodeAddrs: make([]string, len(test.leaderIDs)),
leaders := make([]queryNode, 0, len(test.leaderIDs))
for _, ID := range test.leaderIDs {
leaders = append(leaders, queryNode{ID, "random-addr"})

}
err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
require.NoError(t, err)
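The "check print" subtest above exercises the new fmt.Stringer implementation on queryNode, which keeps %v and zap output to node IDs rather than full structs. A minimal standalone illustration:

package main

import "fmt"

type queryNode struct {
	nodeID  int64
	address string
}

// String matches the method added in task_policies.go, so %v prints
// only the node ID.
func (q queryNode) String() string {
	return fmt.Sprintf("<NodeID: %d>", q.nodeID)
}

func main() {
	qns := []queryNode{{1, "addr1"}, {2, "addr2"}}
	fmt.Printf("list: %v\n", qns) // list: [<NodeID: 1> <NodeID: 2>]
}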
17 changes: 9 additions & 8 deletions internal/proxy/task_query.go
@@ -244,16 +244,17 @@ func (t *queryTask) Execute(ctx context.Context) error {
t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
for _, shard := range shards {
s := shard
for channelID, leaders := range shards {
channelID := channelID
leaders := leaders
t.runningGroup.Go(func() error {
log.Debug("proxy starting to query one shard",
zap.Int64("collectionID", t.CollectionID),
zap.String("collection name", t.collectionName),
zap.String("shard channel", s.GetChannelName()),
zap.String("shard channel", channelID),
zap.Uint64("timeoutTs", t.TimeoutTimestamp))

err := t.queryShard(t.runningGroupCtx, s)
err := t.queryShard(t.runningGroupCtx, leaders, channelID)
if err != nil {
return err
}
@@ -344,12 +345,12 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
return nil
}

func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
query := func(nodeID UniqueID, qn types.QueryNode) error {
req := &querypb.QueryRequest{
Req: t.RetrieveRequest,
IsShardLeader: true,
DmlChannel: leaders.GetChannelName(),
DmlChannel: channelID,
}

result, err := qn.Query(ctx, req)
@@ -364,14 +365,14 @@ func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeader
return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
}

log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", leaders.GetChannelName()))
log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
t.resultBuf <- result
return nil
}

err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
if err != nil {
log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders.GetNodeIds()))
log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
return err
}

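Execute now fans out one goroutine per DML channel through errgroup, passing queryShard the channel ID and its replica list instead of a ShardLeadersList. A stripped-down sketch of that fan-out; the goroutine body is a placeholder for t.queryShard:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	// channel -> replica addresses; placeholder data.
	shards := map[string][]string{
		"channel-1": {"addr1", "addr2"},
		"channel-2": {"addr3"},
	}

	g, gctx := errgroup.WithContext(context.Background())
	for channelID, leaders := range shards {
		channelID, leaders := channelID, leaders // capture loop vars (pre-Go 1.22)
		g.Go(func() error {
			// Placeholder for t.queryShard(gctx, leaders, channelID):
			// walk the replica list until one node answers.
			select {
			case <-gctx.Done():
				return gctx.Err()
			default:
				fmt.Printf("querying %s via %v\n", channelID, leaders)
				return nil
			}
		})
	}
	// Wait returns the first non-nil error; its context cancels the rest.
	if err := g.Wait(); err != nil {
		fmt.Println("query failed:", err)
	}
}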
