Skip to content

Commit

Permalink
Fix CollectionNotExists error when searching and retrieving vectors (#26524)
Browse files Browse the repository at this point in the history
Signed-off-by: xige-16 <xi.ge@zilliz.com>
  • Loading branch information
xige-16 authored Aug 22, 2023
1 parent 9131a0a commit 1e58362
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 17 deletions.
4 changes: 2 additions & 2 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
}
} else {
if loadProgress, refreshProgress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
request.GetPartitionNames(), request.GetCollectionName(), collectionID, request.GetDbName()); err != nil {
return getErrResponse(err), nil
}
}
Expand Down Expand Up @@ -1755,7 +1755,7 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt
}
} else {
if progress, _, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
request.GetPartitionNames(), request.GetCollectionName(), collectionID, request.GetDbName()); err != nil {
if errors.Is(err, ErrInsufficientMemory) {
return &milvuspb.GetLoadStateResponse{
Status: InSufficientMemoryStatus(request.GetCollectionName()),
Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/task_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,14 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return err
}
partitionKeys := ParsePartitionKeys(expr)
hashedPartitionNames, err := assignPartitionKeys(ctx, "", t.request.CollectionName, partitionKeys)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
if err != nil {
return err
}

partitionNames = append(partitionNames, hashedPartitionNames...)
}
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.CollectionName, partitionNames)
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ type searchTask struct {
lb LBPolicy
}

func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
for _, tag := range partitionNames {
if err := validatePartitionTag(tag, false); err != nil {
return nil, err
}
}

partitionsMap, err := globalMetaCache.GetPartitions(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}

// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, partitionNames)
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
Expand Down Expand Up @@ -579,6 +579,7 @@ func (t *searchTask) Requery() error {
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
DbName: t.request.GetDbName(),
CollectionName: t.request.GetCollectionName(),
Expr: expr,
OutputFields: t.request.GetOutputFields(),
Expand Down
10 changes: 5 additions & 5 deletions internal/proxy/task_statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
if err != nil { // err is not nil if collection not exists
return err
}
partIDs, err := getPartitionIDs(ctx, g.collectionName, g.partitionNames)
partIDs, err := getPartitionIDs(ctx, g.request.GetDbName(), g.collectionName, g.partitionNames)
if err != nil { // err is not nil if partition not exists
return err
}
Expand All @@ -131,7 +131,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
}

// check if collection/partitions are loaded into query node
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.request.GetDbName(), g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
log := log.Ctx(ctx).With(
zap.String("collectionName", g.collectionName),
zap.Int64("collectionID", g.CollectionID),
Expand Down Expand Up @@ -312,14 +312,14 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64

// checkFullLoaded check if collection / partition was fully loaded into QueryNode
// return loaded partitions, unloaded partitions and error
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
func checkFullLoaded(ctx context.Context, qc types.QueryCoord, dbName string, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
var loadedPartitionIDs []UniqueID
var unloadPartitionIDs []UniqueID

// TODO: Consider to check if partition loaded from cache to save rpc.
info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, collectionID)
info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collectionName, collectionID)
if err != nil {
return nil, nil, fmt.Errorf("GetCollectionInfo failed, collectionName = %s,collectionID = %d, err = %s", collectionName, collectionID, err)
return nil, nil, fmt.Errorf("GetCollectionInfo failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err)
}

// If request to search partitions
Expand Down
3 changes: 2 additions & 1 deletion internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1219,12 +1219,13 @@ func getPartitionProgress(
partitionNames []string,
collectionName string,
collectionID int64,
dbName string,
) (loadProgress int64, refreshProgress int64, err error) {
IDs2Names := make(map[int64]string)
partitionIDs := make([]int64, 0)
for _, partitionName := range partitionNames {
var partitionID int64
partitionID, err = globalMetaCache.GetPartitionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, partitionName)
partitionID, err = globalMetaCache.GetPartitionID(ctx, dbName, collectionName, partitionName)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1808,6 +1808,6 @@ func Test_GetPartitionProgressFailed(t *testing.T) {
Reason: "Unexpected error",
},
}, nil)
_, _, err := getPartitionProgress(context.TODO(), qc, &commonpb.MsgBase{}, []string{}, "", 1)
_, _, err := getPartitionProgress(context.TODO(), qc, &commonpb.MsgBase{}, []string{}, "", 1, "")
assert.Error(t, err)
}
34 changes: 31 additions & 3 deletions tests/integration/getvector/get_vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import (
type TestGetVectorSuite struct {
integration.MiniClusterSuite

dbName string

// test params
nq int
topK int
Expand All @@ -62,6 +64,14 @@ func (s *TestGetVectorSuite) run() {
dim = 128
)

if len(s.dbName) > 0 {
createDataBaseStatus, err := s.Cluster.Proxy.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: s.dbName,
})
s.Require().NoError(err)
s.Require().Equal(createDataBaseStatus.GetErrorCode(), commonpb.ErrorCode_Success)
}

pkFieldName := "pkField"
vecFieldName := "vecField"
pk := &schemapb.FieldSchema{
Expand Down Expand Up @@ -98,6 +108,7 @@ func (s *TestGetVectorSuite) run() {
s.Require().NoError(err)

createCollectionStatus, err := s.Cluster.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: s.dbName,
CollectionName: collection,
Schema: marshaledSchema,
ShardsNum: 2,
Expand All @@ -120,6 +131,7 @@ func (s *TestGetVectorSuite) run() {
fieldsData = append(fieldsData, vecFieldData)
hashKeys := integration.GenerateHashKeys(NB)
_, err = s.Cluster.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: s.dbName,
CollectionName: collection,
FieldsData: fieldsData,
HashKeys: hashKeys,
Expand All @@ -130,6 +142,7 @@ func (s *TestGetVectorSuite) run() {

// flush
flushResp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: s.dbName,
CollectionNames: []string{collection},
})
s.Require().NoError(err)
Expand All @@ -146,6 +159,7 @@ func (s *TestGetVectorSuite) run() {

// create index
_, err = s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
DbName: s.dbName,
CollectionName: collection,
FieldName: vecFieldName,
IndexName: "_default",
Expand All @@ -154,23 +168,24 @@ func (s *TestGetVectorSuite) run() {
s.Require().NoError(err)
s.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)

s.WaitForIndexBuilt(ctx, collection, vecFieldName)
s.WaitForIndexBuiltWithDB(ctx, s.dbName, collection, vecFieldName)

// load
_, err = s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: s.dbName,
CollectionName: collection,
})
s.Require().NoError(err)
s.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
s.WaitForLoad(ctx, collection)
s.WaitForLoadWithDB(ctx, s.dbName, collection)

// search
nq := s.nq
topk := s.topK

outputFields := []string{vecFieldName}
params := integration.GetSearchParams(s.indexType, s.metricType)
searchReq := integration.ConstructSearchRequest("", collection, "",
searchReq := integration.ConstructSearchRequest(s.dbName, collection, "",
vecFieldName, s.vecType, outputFields, s.metricType, params, nq, dim, topk, -1)

searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
Expand Down Expand Up @@ -248,6 +263,7 @@ func (s *TestGetVectorSuite) run() {
}

status, err := s.Cluster.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{
DbName: s.dbName,
CollectionName: collection,
})
s.Require().NoError(err)
Expand Down Expand Up @@ -365,6 +381,18 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
s.run()
}

// TestGetVector_With_DB_Name runs the get-vector suite against a collection that
// lives in an explicitly created database ("test_db") instead of the default one,
// exercising the DbName propagation on search / retrieve-vector paths fixed by
// this commit. Setting s.dbName makes run() call CreateDatabase first and attach
// the db name to every subsequent request.
func (s *TestGetVectorSuite) TestGetVector_With_DB_Name() {
	s.dbName = "test_db"
	s.nq = 10
	s.topK = 10
	s.indexType = integration.IndexHNSW
	s.metricType = metric.L2
	s.pkType = schemapb.DataType_Int64
	s.vecType = schemapb.DataType_FloatVector
	// searchFailed=false: the search against the non-default database is
	// expected to succeed once DbName is threaded through the proxy.
	s.searchFailed = false
	s.run()
}

//func (s *TestGetVectorSuite) TestGetVector_DISKANN() {
// s.nq = 10
// s.topK = 10
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/util_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,18 @@ const (
IndexDISKANN = indexparamcheck.IndexDISKANN
)

// WaitForIndexBuiltWithDB blocks until the index on the given field of the
// named collection, looked up in database dbName, is reported as built
// (polled via DescribeIndex in waitForIndexBuiltInternal).
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
	s.waitForIndexBuiltInternal(ctx, dbName, collection, field)
}

// WaitForIndexBuilt blocks until the index on the given field of the named
// collection is reported as built. It uses an empty db name, i.e. the
// default database; use WaitForIndexBuiltWithDB to target a specific one.
func (s *MiniClusterSuite) WaitForIndexBuilt(ctx context.Context, collection, field string) {
	s.waitForIndexBuiltInternal(ctx, "", collection, field)
}

func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field string) {
getIndexBuilt := func() bool {
resp, err := s.Cluster.Proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
DbName: dbName,
CollectionName: collection,
FieldName: field,
})
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/util_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,19 @@ const (
LimitKey = "limit"
)

// WaitForLoadWithDB blocks until the named collection, looked up in database
// dbName, finishes loading (polled via GetLoadingProgress in
// waitForLoadInternal).
func (s *MiniClusterSuite) WaitForLoadWithDB(ctx context.Context, dbName, collection string) {
	s.waitForLoadInternal(ctx, dbName, collection)
}

// WaitForLoad blocks until the named collection finishes loading. It uses an
// empty db name, i.e. the default database; use WaitForLoadWithDB to target
// a specific one.
func (s *MiniClusterSuite) WaitForLoad(ctx context.Context, collection string) {
	s.waitForLoadInternal(ctx, "", collection)
}

func (s *MiniClusterSuite) waitForLoadInternal(ctx context.Context, dbName, collection string) {
cluster := s.Cluster
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
DbName: dbName,
CollectionName: collection,
})
if err != nil {
Expand Down

0 comments on commit 1e58362

Please sign in to comment.