diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index 97dddaa9b8903..ff10535df60f3 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -175,17 +175,17 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, } // Though batchSave is not atomic enough, we can promise the atomicity outside. - // Recovering from failure, if we found collection is creating, we should removing all these related meta. + // Recovering from failure, if we found collection is creating, we should remove all these related meta. return etcd.SaveByBatchWithLimit(kvs, maxTxnNum/2, func(partialKvs map[string]string) error { return kc.Snapshot.MultiSave(partialKvs, ts) }) } -func (kc *Catalog) loadCollection(ctx context.Context, dbName string, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) { +func (kc *Catalog) loadCollectionFromDb(ctx context.Context, dbName string, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) { collKey := BuildCollectionKey(dbName, collectionID) collVal, err := kc.Snapshot.Load(collKey, ts) if err != nil { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("collection not found: %d", collectionID)) + return nil, common.NewCollectionNotExistError(fmt.Sprintf("collection not found: %d, error: %s", collectionID, err.Error())) } collMeta := &pb.CollectionInfo{} @@ -193,6 +193,21 @@ func (kc *Catalog) loadCollection(ctx context.Context, dbName string, collection return collMeta, err } +func (kc *Catalog) loadCollectionFromDefaultDb(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) { + if info, err := kc.loadCollectionFromDb(ctx, "default", collectionID, ts); err == nil { + return info, nil + } + // get collection from older version. + return kc.loadCollectionFromDb(ctx, "", collectionID, ts) +} + +func (kc *Catalog) loadCollection(ctx context.Context, dbName string, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*pb.CollectionInfo, error) { + if dbName == "default" { + return kc.loadCollectionFromDefaultDb(ctx, collectionID, ts) + } + return kc.loadCollectionFromDb(ctx, dbName, collectionID, ts) +} + func partitionVersionAfter210(collMeta *pb.CollectionInfo) bool { return len(collMeta.GetPartitionIDs()) <= 0 && len(collMeta.GetPartitionNames()) <= 0 && @@ -237,7 +252,8 @@ func (kc *Catalog) CreatePartition(ctx context.Context, dbName string, partition collMeta.PartitionNames = append(collMeta.PartitionNames, partition.PartitionName) collMeta.PartitionCreatedTimestamps = append(collMeta.PartitionCreatedTimestamps, partition.PartitionCreatedTimestamp) - k := BuildCollectionKey(dbName, partition.CollectionID) + // this partition exists in older version, should be also changed in place. 
+ k := BuildCollectionKey("", partition.CollectionID) v, err := proto.Marshal(collMeta) if err != nil { return err @@ -379,13 +395,21 @@ func (kc *Catalog) AlterAlias(ctx context.Context, alias *model.Alias, ts typeut } func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Collection, ts typeutil.Timestamp) error { - collectionKey := BuildCollectionKey(collectionInfo.DBName, collectionInfo.CollectionID) + collectionKeys := []string{BuildCollectionKey(collectionInfo.DBName, collectionInfo.CollectionID)} + if collectionInfo.DBName == "default" { + collectionKeys = append(collectionKeys, BuildCollectionKey("", collectionInfo.CollectionID)) + } var delMetakeysSnap []string for _, alias := range collectionInfo.Aliases { delMetakeysSnap = append(delMetakeysSnap, BuildAliasKey210(alias), + BuildAliasKey(alias), + BuildAliasKeyWithDb(collectionInfo.DBName, alias), ) + if collectionInfo.DBName == "default" { + delMetakeysSnap = append(delMetakeysSnap, BuildAliasKeyWithDb("", alias)) + } } // Snapshot will list all (k, v) pairs and then use Txn.MultiSave to save tombstone for these keys when it prepares // to remove a prefix, so though we have very few prefixes, the final operations may exceed the max txn number. @@ -407,7 +431,7 @@ func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Col } // if we found collection dropping, we should try removing related resources. - return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, []string{collectionKey}, ts) + return kc.Snapshot.MultiSaveAndRemoveWithPrefix(nil, collectionKeys, ts) } func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *model.Collection, ts typeutil.Timestamp) error { @@ -431,6 +455,10 @@ func (kc *Catalog) alterModifyCollection(oldColl *model.Collection, newColl *mod if err != nil { return err } + if newColl.DBName == "default" { + removal := BuildCollectionKey("", newColl.CollectionID) + return kc.Snapshot.MultiSaveAndRemoveWithPrefix(map[string]string{key: string(value)}, []string{removal}, ts) + } return kc.Snapshot.Save(key, string(value), ts) } @@ -634,21 +662,29 @@ func (kc *Catalog) listAliasesAfter210WithDb(ctx context.Context, dbName string, return aliases, nil } -func (kc *Catalog) listAliasesWithDb(ctx context.Context, dbName string, ts typeutil.Timestamp) ([]*model.Alias, error) { +func (kc *Catalog) listAliasesInDefaultDb(ctx context.Context, ts typeutil.Timestamp) ([]*model.Alias, error) { aliases1, err := kc.listAliasesBefore210(ctx, ts) if err != nil { return nil, err } - aliases2, err := kc.listAliasesAfter210WithDb(ctx, dbName, ts) + aliases2, err := kc.listAliasesAfter210WithDb(ctx, "default", ts) + if err != nil { + return nil, err + } + aliases3, err := kc.listAliasesAfter210WithDb(ctx, "", ts) if err != nil { return nil, err } aliases := append(aliases1, aliases2...) + aliases = append(aliases, aliases3...) 
return aliases, nil } func (kc *Catalog) ListAliases(ctx context.Context, dbName string, ts typeutil.Timestamp) ([]*model.Alias, error) { - return kc.listAliasesWithDb(ctx, dbName, ts) + if dbName != "default" { + return kc.listAliasesAfter210WithDb(ctx, dbName, ts) + } + return kc.listAliasesInDefaultDb(ctx, ts) } func (kc *Catalog) ListCredentials(ctx context.Context) ([]string, error) { diff --git a/internal/proxy/database_interceptor.go b/internal/proxy/database_interceptor.go index e488a026f66f3..d3fef6fa47ee8 100644 --- a/internal/proxy/database_interceptor.go +++ b/internal/proxy/database_interceptor.go @@ -19,170 +19,162 @@ func DatabaseInterceptor() grpc.UnaryServerInterceptor { func fillDatabase(ctx context.Context, req interface{}) (context.Context, interface{}) { switch r := req.(type) { case *milvuspb.CreateCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DropCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.HasCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.LoadCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ReleaseCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DescribeCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetStatisticsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetCollectionStatisticsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ShowCollectionsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.AlterCollectionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.CreatePartitionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DropPartitionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.HasPartitionRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.LoadPartitionsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ReleasePartitionsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetPartitionStatisticsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ShowPartitionsRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetLoadingProgressRequest: - // TODO - // 
r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetLoadStateRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.CreateIndexRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DescribeIndexRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DropIndexRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetIndexBuildProgressRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetIndexStateRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.InsertRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DeleteRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.SearchRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.FlushRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.QueryRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.CreateAliasRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DropAliasRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.AlterAliasRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.CalcDistanceRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.FlushAllRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetPersistentSegmentInfoRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetQuerySegmentInfoRequest: - r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.DummyRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetMetricsRequest: return ctx, r case *milvuspb.LoadBalanceRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetReplicasRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetCompactionStateRequest: // TODO - // r.DbName = 
GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ManualCompactionRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetCompactionPlansRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetFlushStateRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetFlushAllStateRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ImportRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.GetImportStateRequest: // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + // r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.ListImportTasksRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.RenameCollectionRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r case *milvuspb.TransferReplicaRequest: - // TODO - // r.DbName = GetCurDatabaseFromContextOrEmpty(ctx) + r.DbName = GetCurDatabaseFromContextOrDefault(ctx) return ctx, r default: return ctx, req diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 1bad6286197f8..9b8fe2695d1dc 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -1981,7 +1981,7 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get if err := validateCollectionName(request.CollectionName); err != nil { return getErrResponse(err), nil } - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), request.CollectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, request.GetDbName(), request.CollectionName) if err != nil { return getErrResponse(err), nil } @@ -2080,7 +2080,7 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) }() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), request.CollectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, request.GetDbName(), request.CollectionName) if err != nil { successResponse.State = commonpb.LoadState_LoadStateNotExist return successResponse, nil @@ -3771,7 +3771,7 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G metrics.TotalLabel).Inc() // list segments - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), req.GetCollectionName()) + collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.FailLabel).Inc() resp.Status.Reason = fmt.Errorf("getCollectionID failed, err:%w", err).Error() @@ -3854,7 +3854,7 @@ func (node *Proxy) 
GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.TotalLabel).Inc() - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), req.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.CollectionName) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method, metrics.FailLabel).Inc() resp.Status.Reason = err.Error() @@ -4140,7 +4140,7 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq ErrorCode: commonpb.ErrorCode_UnexpectedError, } - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), req.GetCollectionName()) + collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { log.Warn("failed to get collection id", zap.String("collection name", req.GetCollectionName()), @@ -4194,7 +4194,7 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq ) if req.GetCollectionName() != "" { - req.CollectionID, _ = globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), req.GetCollectionName()) + req.CollectionID, _ = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) } r, err := node.queryCoord.GetReplicas(ctx, req) diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 80ea0d554d585..d9d7e72fe6ecc 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -243,13 +243,13 @@ func TestMetaCache_GetCollection(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.Nil(t, err) - id, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + id, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(1)) assert.Equal(t, rootCoord.GetAccessCount(), 1) // should'nt be accessed to remote root coord. 
- schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -257,11 +257,11 @@ func TestMetaCache_GetCollection(t *testing.T) { Fields: []*schemapb.FieldSchema{}, Name: "collection1", }) - id, err = globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2") + id, err = globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(2)) - schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2") + schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -271,11 +271,11 @@ func TestMetaCache_GetCollection(t *testing.T) { }) // test to get from cache, this should trigger root request - id, err = globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + id, err = globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(1)) - schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -300,7 +300,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 1) // should'nt be accessed to remote root coord. 
- schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -312,7 +312,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.Nil(t, err) assert.Equal(t, collection, "collection1") - schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2") + schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -326,7 +326,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, collection, "collection1") - schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -345,13 +345,13 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { assert.Nil(t, err) rootCoord.Error = true - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NotNil(t, err) assert.Nil(t, schema) rootCoord.Error = false - schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ AutoID: true, @@ -377,10 +377,10 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.Nil(t, err) - id, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection3") + id, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection3") assert.NotNil(t, err) assert.Equal(t, id, int64(0)) - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection3") + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection3") assert.NotNil(t, err) assert.Nil(t, schema) } @@ -393,16 +393,16 @@ func TestMetaCache_GetPartitionID(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.Nil(t, err) - id, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1", "par1") + id, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1", "par1") assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(1)) - id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1", "par2") + id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1", "par2") assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(2)) - id, err = 
globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2", "par1") + id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2", "par1") assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(3)) - id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2", "par2") + id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2", "par2") assert.Nil(t, err) assert.Equal(t, id, typeutil.UniqueID(4)) } @@ -421,7 +421,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { defer wg.Done() for i := 0; i < cnt; i++ { //GetCollectionSchema will never fail - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.Nil(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ AutoID: true, @@ -436,7 +436,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { defer wg.Done() for i := 0; i < cnt; i++ { //GetPartitions may fail - globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") time.Sleep(10 * time.Millisecond) } } @@ -445,7 +445,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { defer wg.Done() for i := 0; i < cnt; i++ { //periodically invalid collection cache - globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") time.Sleep(10 * time.Millisecond) } } @@ -470,24 +470,24 @@ func TestMetaCache_GetPartitionError(t *testing.T) { assert.Nil(t, err) // Test the case where ShowPartitionsResponse is not aligned - id, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "errorCollection", "par1") + id, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "errorCollection", "par1") assert.NotNil(t, err) log.Debug(err.Error()) assert.Equal(t, id, typeutil.UniqueID(0)) - partitions, err2 := globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "errorCollection") + partitions, err2 := globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrDefault(ctx), "errorCollection") assert.NotNil(t, err2) log.Debug(err.Error()) assert.Equal(t, len(partitions), 0) // Test non existed tables - id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "nonExisted", "par1") + id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "nonExisted", "par1") assert.NotNil(t, err) log.Debug(err.Error()) assert.Equal(t, id, typeutil.UniqueID(0)) // Test non existed partition - id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1", "par3") + id, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1", "par3") assert.NotNil(t, err) log.Debug(err.Error()) assert.Equal(t, id, typeutil.UniqueID(0)) @@ -510,21 +510,21 @@ func TestMetaCache_GetShards(t *testing.T) { defer qc.Stop() t.Run("No collection in meta cache", func(t *testing.T) { - shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), "non-exists") + shards, err := globalMetaCache.GetShards(ctx, true, 
GetCurDatabaseFromContextOrDefault(ctx), "non-exists") assert.Error(t, err) assert.Empty(t, shards) }) t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) { qc.validShardLeaders = false - shards, err := globalMetaCache.GetShards(ctx, false, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + shards, err := globalMetaCache.GetShards(ctx, false, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.Error(t, err) assert.Empty(t, shards) }) t.Run("without shardLeaders in collection info", func(t *testing.T) { qc.validShardLeaders = true - shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) assert.NotEmpty(t, shards) assert.Equal(t, 1, len(shards)) @@ -532,7 +532,7 @@ func TestMetaCache_GetShards(t *testing.T) { // get from cache qc.validShardLeaders = false - shards, err = globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + shards, err = globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) assert.NotEmpty(t, shards) @@ -558,26 +558,26 @@ func TestMetaCache_ClearShards(t *testing.T) { defer qc.Stop() t.Run("Clear with no collection info", func(t *testing.T) { - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), "collection_not_exist") + globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrDefault(ctx), "collection_not_exist") }) t.Run("Clear valid collection empty cache", func(t *testing.T) { - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrDefault(ctx), collectionName) }) t.Run("Clear valid collection valid cache", func(t *testing.T) { qc.validShardLeaders = true - shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), collectionName) require.NoError(t, err) require.NotEmpty(t, shards) require.Equal(t, 1, len(shards)) require.Equal(t, 3, len(shards["channel-1"])) - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrDefault(ctx), collectionName) qc.validShardLeaders = false - shards, err = globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + shards, err = globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.Error(t, err) assert.Empty(t, shards) }) @@ -674,7 +674,7 @@ func TestMetaCache_LoadCache(t *testing.T) { assert.Nil(t, err) t.Run("test IsCollectionLoaded", func(t *testing.T) { - info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // no collectionInfo of collection1, should access RootCoord @@ -682,7 +682,7 @@ func TestMetaCache_LoadCache(t *testing.T) { // not loaded, should access QueryCoord assert.Equal(t, queryCoord.GetAccessCount(), 1) - info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err = 
globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // shouldn't access QueryCoord or RootCoord again @@ -690,7 +690,7 @@ func TestMetaCache_LoadCache(t *testing.T) { assert.Equal(t, queryCoord.GetAccessCount(), 1) // test collection2 not fully loaded - info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection2") + info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection2") assert.NoError(t, err) assert.False(t, info.isLoaded) // no collectionInfo of collection2, should access RootCoord @@ -700,8 +700,8 @@ func TestMetaCache_LoadCache(t *testing.T) { }) t.Run("test RemoveCollectionLoadCache", func(t *testing.T) { - globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") - info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") + info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // should access QueryCoord @@ -717,21 +717,21 @@ func TestMetaCache_RemoveCollection(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr) assert.Nil(t, err) - info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // no collectionInfo of collection1, should access RootCoord assert.Equal(t, rootCoord.GetAccessCount(), 1) - info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // shouldn't access RootCoord again assert.Equal(t, rootCoord.GetAccessCount(), 1) - globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") // no collectionInfo of collection2, should access RootCoord - info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // shouldn't access RootCoord again @@ -739,7 +739,7 @@ func TestMetaCache_RemoveCollection(t *testing.T) { globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1)) // no collectionInfo of collection2, should access RootCoord - info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + info, err = globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) assert.True(t, info.isLoaded) // shouldn't access RootCoord again @@ -777,7 +777,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { }, }, }, nil).Times(1) - nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), "collection1") 
assert.NoError(t, err) assert.Len(t, nodeInfos["channel-1"], 3) @@ -795,7 +795,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { }, nil).Times(1) assert.Eventually(t, func() bool { - nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) return len(nodeInfos["channel-1"]) == 2 }, 3*time.Second, 1*time.Second) @@ -814,7 +814,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { }, nil).Times(1) assert.Eventually(t, func() bool { - nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) return len(nodeInfos["channel-1"]) == 3 }, 3*time.Second, 1*time.Second) @@ -838,7 +838,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { }, nil).Times(1) assert.Eventually(t, func() bool { - nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), "collection1") + nodeInfos, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrDefault(ctx), "collection1") assert.NoError(t, err) return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3 }, 3*time.Second, 1*time.Second) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index ad9383acf5422..d7a0c23585078 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -875,7 +875,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("describe collection", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ @@ -1052,7 +1052,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("show partitions", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ @@ -1359,7 +1359,7 @@ func TestProxy(t *testing.T) { t.Run("get replicas", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{ @@ -1608,7 +1608,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("release collection", func(t *testing.T) { defer wg.Done() - _, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + _, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ @@ -1649,7 +1649,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("load partitions", func(t *testing.T) { defer wg.Done() - 
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.LoadPartitions(ctx, &milvuspb.LoadPartitionsRequest{ @@ -1722,7 +1722,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("show in-memory partitions", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ @@ -1879,7 +1879,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("show in-memory partitions after release partition", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ @@ -1957,7 +1957,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("show partitions after drop partition", func(t *testing.T) { defer wg.Done() - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ @@ -2004,7 +2004,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("drop collection", func(t *testing.T) { defer wg.Done() - _, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + _, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) resp, err := proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{ diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 2452aa26b1d6d..c673c77f44da6 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -596,7 +596,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { } collectionIDs := make([]UniqueID, 0) for _, collectionName := range sct.CollectionNames { - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, sct.GetDbName(), collectionName) if err != nil { log.Warn("Failed to get collection id.", zap.Any("collectionName", collectionName), zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) @@ -649,7 +649,7 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) return errors.New("failed to show collections") } - collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName) if err != nil { log.Warn("Failed to get collection info.", zap.Any("collectionName", collectionName), zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) @@ -872,11 +872,11 @@ func 
(dpt *dropPartitionTask) PreExecute(ctx context.Context) error { return err } - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), dpt.GetCollectionName()) + collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetDbName(), dpt.GetCollectionName()) if err != nil { return err } - partID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), dpt.GetCollectionName(), dpt.GetPartitionName()) + partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetDbName(), dpt.GetCollectionName(), dpt.GetPartitionName()) if err != nil { if err.Error() == ErrPartitionNotExist(dpt.GetPartitionName()).Error() { return nil @@ -1073,7 +1073,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { if spt.GetType() == milvuspb.ShowType_InMemory { collectionName := spt.CollectionName - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, spt.GetDbName(), collectionName) if err != nil { log.Warn("Failed to get collection id.", zap.Any("collectionName", collectionName), zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) @@ -1086,7 +1086,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { } partitionIDs := make([]UniqueID, 0) for _, partitionName := range spt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName, partitionName) + partitionID, err := globalMetaCache.GetPartitionID(ctx, spt.GetDbName(), collectionName, partitionName) if err != nil { log.Warn("Failed to get partition id.", zap.Any("partitionName", partitionName), zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) @@ -1132,7 +1132,7 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) return errors.New("failed to show partitions") } - partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName, partitionName) + partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, spt.GetDbName(), collectionName, partitionName) if err != nil { log.Warn("Failed to get partition id.", zap.Any("partitionName", partitionName), zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) @@ -1211,7 +1211,7 @@ func (ft *flushTask) Execute(ctx context.Context) error { flushColl2Segments := make(map[string]*schemapb.LongArray) coll2SealTimes := make(map[string]int64) for _, collName := range ft.CollectionNames { - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collName) + collID, err := globalMetaCache.GetCollectionID(ctx, ft.GetDbName(), collName) if err != nil { return err } @@ -1321,13 +1321,13 @@ func (lct *loadCollectionTask) PreExecute(ctx context.Context) error { func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) { log.Info("loadCollectionTask Execute", zap.String("role", typeutil.ProxyRole), zap.Int64("msgID", lct.Base.MsgID)) - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), lct.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, lct.GetDbName(), lct.CollectionName) if err != nil { return err } lct.collectionID = collID - collSchema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), 
lct.CollectionName) + collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lct.GetDbName(), lct.CollectionName) if err != nil { return err } @@ -1449,7 +1449,7 @@ func (rct *releaseCollectionTask) PreExecute(ctx context.Context) error { } func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) { - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), rct.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, rct.GetDbName(), rct.CollectionName) if err != nil { return err } @@ -1465,13 +1465,13 @@ func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) { rct.result, err = rct.queryCoord.ReleaseCollection(ctx, request) - globalMetaCache.RemoveCollection(ctx, GetCurDatabaseFromContextOrEmpty(ctx), rct.CollectionName) + globalMetaCache.RemoveCollection(ctx, rct.GetDbName(), rct.CollectionName) return err } func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error { - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), rct.CollectionName) + globalMetaCache.DeprecateShardCache(rct.GetDbName(), rct.CollectionName) return nil } @@ -1538,12 +1538,12 @@ func (lpt *loadPartitionsTask) PreExecute(ctx context.Context) error { func (lpt *loadPartitionsTask) Execute(ctx context.Context) error { var partitionIDs []int64 - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), lpt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, lpt.GetDbName(), lpt.CollectionName) if err != nil { return err } lpt.collectionID = collID - collSchema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), lpt.CollectionName) + collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lpt.GetDbName(), lpt.CollectionName) if err != nil { return err } @@ -1575,7 +1575,7 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error { return errors.New(errMsg) } for _, partitionName := range lpt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), lpt.CollectionName, partitionName) + partitionID, err := globalMetaCache.GetPartitionID(ctx, lpt.GetDbName(), lpt.CollectionName, partitionName) if err != nil { return err } @@ -1665,13 +1665,13 @@ func (rpt *releasePartitionsTask) PreExecute(ctx context.Context) error { func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { var partitionIDs []int64 - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), rpt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, rpt.GetDbName(), rpt.CollectionName) if err != nil { return err } rpt.collectionID = collID for _, partitionName := range rpt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), rpt.CollectionName, partitionName) + partitionID, err := globalMetaCache.GetPartitionID(ctx, rpt.GetDbName(), rpt.CollectionName, partitionName) if err != nil { return err } @@ -1691,7 +1691,7 @@ func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { } func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error { - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), rpt.CollectionName) + globalMetaCache.DeprecateShardCache(rpt.GetDbName(), rpt.CollectionName) return nil } @@ -2265,7 +2265,7 @@ func (t *TransferReplicaTask) PreExecute(ctx context.Context) error { func (t 
*TransferReplicaTask) Execute(ctx context.Context) error { var err error - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), t.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) if err != nil { return err } diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 894d0afff79a8..3d6f976c4c186 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -81,7 +81,7 @@ func (dt *deleteTask) getChannels() ([]pChan, error) { if len(dt.pChannels) != 0 { return dt.pChannels, nil } - collID, err := globalMetaCache.GetCollectionID(dt.ctx, GetCurDatabaseFromContextOrEmpty(dt.ctx), dt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.GetDbName(), dt.CollectionName) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err)) return err } - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collName) + collID, err := globalMetaCache.GetCollectionID(ctx, dt.GetDbName(), collName) if err != nil { log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err)) return err @@ -178,7 +178,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) return err } - partID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collName, partName) + partID, err := globalMetaCache.GetPartitionID(ctx, dt.GetDbName(), collName, partName) if err != nil { log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err)) return err @@ -188,7 +188,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { dt.DeleteRequest.PartitionID = common.InvalidPartitionID } - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collName) + schema, err := globalMetaCache.GetCollectionSchema(ctx, dt.GetDbName(), collName) if err != nil { log.Info("Failed to get collection schema", zap.String("collectionName", collName), zap.Error(err)) return err diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 3196f8537912e..d170b11feb36f 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -190,7 +190,7 @@ func (cit *createIndexTask) parseIndexParams() error { } func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) { - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), cit.req.GetCollectionName()) + schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.req.GetDbName(), cit.req.GetCollectionName()) if err != nil { log.Error("failed to get collection schema", zap.Error(err)) return nil, fmt.Errorf("failed to get collection schema: %s", err) @@ -276,7 +276,7 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { collName := cit.req.GetCollectionName() - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collName) + collID, err := globalMetaCache.GetCollectionID(ctx, cit.req.GetDbName(), collName) if err != nil { return err } @@ -402,7 +402,7 @@ func (dit *describeIndexTask) PreExecute(ctx context.Context) error { return err } - 
collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), dit.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, dit.GetDbName(), dit.CollectionName) if err != nil { return err } @@ -411,7 +411,7 @@ func (dit *describeIndexTask) PreExecute(ctx context.Context) error { } func (dit *describeIndexTask) Execute(ctx context.Context) error { - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), dit.GetCollectionName()) + schema, err := globalMetaCache.GetCollectionSchema(ctx, dit.GetDbName(), dit.GetCollectionName()) if err != nil { log.Error("failed to get collection schema", zap.Error(err)) return fmt.Errorf("failed to get collection schema: %s", err) @@ -524,7 +524,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error { } } - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), dit.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, dit.GetDbName(), dit.CollectionName) if err != nil { return err } @@ -626,7 +626,7 @@ func (gibpt *getIndexBuildProgressTask) PreExecute(ctx context.Context) error { func (gibpt *getIndexBuildProgressTask) Execute(ctx context.Context) error { collectionName := gibpt.CollectionName - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, gibpt.GetDbName(), collectionName) if err != nil { // err is not nil if collection not exists return err } @@ -714,7 +714,7 @@ func (gist *getIndexStateTask) PreExecute(ctx context.Context) error { } func (gist *getIndexStateTask) Execute(ctx context.Context) error { - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), gist.CollectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, gist.GetDbName(), gist.CollectionName) if err != nil { return err } diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 96e49932ac5b9..910d52eddad21 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -75,7 +75,7 @@ func (it *insertTask) getChannels() ([]pChan, error) { if len(it.pChannels) != 0 { return it.pChannels, nil } - collID, err := globalMetaCache.GetCollectionID(it.ctx, GetCurDatabaseFromContextOrEmpty(it.ctx), it.CollectionName) + collID, err := globalMetaCache.GetCollectionID(it.ctx, it.GetDbName(), it.CollectionName) if err != nil { return nil, err } @@ -182,7 +182,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } - collSchema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collSchema, err := globalMetaCache.GetCollectionSchema(ctx, it.GetDbName(), collectionName) if err != nil { log.Error("get collection schema from global meta cache failed", zap.String("collection name", collectionName), zap.Error(err)) return err @@ -400,19 +400,19 @@ func (it *insertTask) Execute(ctx context.Context) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID())) collectionName := it.CollectionName - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, it.GetDbName(), collectionName) if err != nil { return err } it.CollectionID = collID var partitionID UniqueID if len(it.PartitionName) > 0 { - partitionID, err = 
globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName, it.PartitionName) + partitionID, err = globalMetaCache.GetPartitionID(ctx, it.GetDbName(), collectionName, it.PartitionName) if err != nil { return err } } else { - partitionID, err = globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName, Params.CommonCfg.DefaultPartitionName) + partitionID, err = globalMetaCache.GetPartitionID(ctx, it.GetDbName(), collectionName, Params.CommonCfg.DefaultPartitionName) if err != nil { return err } diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 8bed3700a9af1..52021df87906f 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -185,7 +185,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Debug("Validate collection name.", zap.Any("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName) if err != nil { log.Ctx(ctx).Warn("Failed to get collection id.", zap.Any("collectionName", collectionName), zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) @@ -246,7 +246,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { return fmt.Errorf("collection:%v or partition:%v not loaded into memory when query", collectionName, t.request.GetPartitionNames()) } - schema, _ := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + schema, _ := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName) t.schema = schema if t.ids != nil { @@ -322,7 +322,7 @@ func (t *queryTask) Execute(ctx context.Context) error { log := log.Ctx(ctx) executeQuery := func() error { - shards, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName) + shards, err := globalMetaCache.GetShards(ctx, true, t.request.GetDbName(), t.collectionName) if err != nil { return err } @@ -344,7 +344,7 @@ func (t *queryTask) Execute(ctx context.Context) error { if queryError != nil { log.Warn("invalid shard leaders cache, updating shardleader caches and retry query", zap.Error(queryError)) - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName) + globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName) } return queryError }, retry.Attempts(Params.CommonCfg.GrpcRetryTimes)) @@ -400,7 +400,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { return nil } - schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), t.request.CollectionName) + schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.request.CollectionName) if err != nil { return err } @@ -442,13 +442,13 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query if err != nil { log.Ctx(ctx).Warn("QueryNode query return error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err)) - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName) + globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName) return common.NewCodeError(commonpb.ErrorCode_NotReadyServe, err) } errCode := result.GetStatus().GetErrorCode() if errCode 
diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go
index 423402f0b99ac..9caed76c97b1c 100644
--- a/internal/proxy/task_query_test.go
+++ b/internal/proxy/task_query_test.go
@@ -94,7 +94,7 @@ func TestQueryTask_all(t *testing.T) {
 	require.NoError(t, createColT.Execute(ctx))
 	require.NoError(t, createColT.PostExecute(ctx))
-	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	assert.NoError(t, err)
 	status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index 88c0ebdd9d2cc..11e2c4e3e00f0 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -64,7 +64,7 @@ func getPartitionIDs(ctx context.Context, collectionName string, partitionNames
 		}
 	}
-	partitionsMap, err := globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	partitionsMap, err := globalMetaCache.GetPartitions(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	if err != nil {
 		return nil, err
 	}
@@ -265,14 +265,14 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 	collectionName := t.request.CollectionName
 	t.collectionName = collectionName
-	collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
 	if err != nil { // err is not nil if collection not exists
 		return err
 	}
 	t.SearchRequest.DbID = 0 // todo
 	t.SearchRequest.CollectionID = collID
-	t.schema, _ = globalMetaCache.GetCollectionSchema(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	t.schema, _ = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
 	// translate partition name to partition ids. Use regex-pattern to match partition name.
 	t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames())
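The PreExecute hunk above ends at the comment about regex-matching partition names; getPartitionIDs feeds the map returned by GetPartitions into that matching. A rough sketch of the idea follows, with anchoring and error handling simplified rather than copied from the real helper.

package proxysketch

import (
	"fmt"
	"regexp"
)

// matchPartitionIDs is a simplified sketch of what getPartitionIDs does with
// the map returned by GetPartitions: each requested partition name is treated
// as a regex pattern and matched against the known partition names. The
// anchoring and the "not found" behavior here are illustrative only.
func matchPartitionIDs(partitionsMap map[string]int64, patterns []string) ([]int64, error) {
	ids := make([]int64, 0, len(patterns))
	for _, pattern := range patterns {
		re, err := regexp.Compile("^" + pattern + "$")
		if err != nil {
			return nil, fmt.Errorf("invalid partition name pattern %q: %w", pattern, err)
		}
		found := false
		for name, id := range partitionsMap {
			if re.MatchString(name) {
				ids = append(ids, id)
				found = true
			}
		}
		if !found {
			return nil, fmt.Errorf("partition name %q not found", pattern)
		}
	}
	return ids, nil
}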
@@ -402,7 +402,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
 	log := log.Ctx(ctx)
 	executeSearch := func() error {
-		shard2Leaders, err := globalMetaCache.GetShards(ctx, true, GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, true, t.request.GetDbName(), t.collectionName)
 		if err != nil {
 			return err
 		}
@@ -423,7 +423,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
 		}
 		if searchErr != nil {
 			log.Warn("first search failed, updating shardleader caches and retry search", zap.Int64("msgId", t.ID()), zap.Error(searchErr))
-			globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName)
+			globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
 		}
 		return searchErr
 	}, retry.Attempts(Params.CommonCfg.GrpcRetryTimes))
@@ -514,14 +514,14 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
 	if err != nil {
 		log.Ctx(ctx).Warn("QueryNode search return error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
-		globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName)
+		globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
 		return common.NewCodeError(commonpb.ErrorCode_NotReadyServe, err)
 	}
 	errCode := result.GetStatus().GetErrorCode()
 	if errCode == commonpb.ErrorCode_NotShardLeader {
 		log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
-		globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), t.collectionName)
+		globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
 		return common.NewCodeError(errCode, errInvalidShardLeaders)
 	}
 	if errCode != commonpb.ErrorCode_Success {
@@ -580,7 +580,7 @@ func (t *searchTask) collectSearchResults(ctx context.Context) error {
 // checkIfLoaded check if collection was loaded into QueryNode
 func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) (bool, error) {
-	info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	if err != nil {
 		return false, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
 	}
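Most hunks in this file swap a context-derived database name for t.request.GetDbName() in the shard-cache calls. The underlying reason is that the cache key now has to include the database: the same collection name may exist in several databases. A hypothetical layout (not the actual globalMetaCache structure) makes that keying explicit.

package proxysketch

import "sync"

// shardLeaderCache sketches why DeprecateShardCache now takes a database
// name: with multiple databases, the cache key must be the pair
// (dbName, collectionName) rather than the collection name alone.
type shardLeaderCache struct {
	mu     sync.RWMutex
	shards map[string]map[string]map[string][]int64 // dbName -> collectionName -> channel -> node IDs
}

func newShardLeaderCache() *shardLeaderCache {
	return &shardLeaderCache{shards: make(map[string]map[string]map[string][]int64)}
}

func (c *shardLeaderCache) DeprecateShardCache(dbName, collectionName string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if byColl, ok := c.shards[dbName]; ok {
		// Drop only this collection's leaders, not the whole database.
		delete(byColl, collectionName)
	}
}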
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index 01823be995ef3..08ebe4fb0f453 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -205,7 +205,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
 	t.Run("invalid IgnoreGrowing param", func(t *testing.T) {
 		collName := "test_invalid_param" + funcutil.GenRandomStr()
 		createColl(t, collName, rc)
-		collID, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrEmpty(ctx), collName)
+		collID, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrDefault(ctx), collName)
 		require.NoError(t, err)
 		status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
 			Base: &commonpb.MsgBase{
@@ -266,7 +266,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
 	t.Run("test checkIfLoaded error", func(t *testing.T) {
 		collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
 		createColl(t, collName, rc)
-		_, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrEmpty(ctx), collName)
+		_, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrDefault(ctx), collName)
 		require.NoError(t, err)
 		task := getSearchTask(t, collName)
 		task.collectionName = collName
@@ -292,7 +292,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
 	t.Run("search with timeout", func(t *testing.T) {
 		collName := "search_with_timeout" + funcutil.GenRandomStr()
 		createColl(t, collName, rc)
-		collID, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrEmpty(ctx), collName)
+		collID, err := globalMetaCache.GetCollectionID(context.TODO(), GetCurDatabaseFromContextOrDefault(ctx), collName)
 		require.NoError(t, err)
 		status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
 			Base: &commonpb.MsgBase{
@@ -1827,7 +1827,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
 	require.NoError(t, createColT.Execute(ctx))
 	require.NoError(t, createColT.PostExecute(ctx))
-	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	assert.NoError(t, err)
 	status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go
index 8a31cd34b9ac4..6ae294f14b107 100644
--- a/internal/proxy/task_statistic.go
+++ b/internal/proxy/task_statistic.go
@@ -112,7 +112,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
 	g.Base.MsgType = commonpb.MsgType_GetPartitionStatistics
 	g.Base.SourceID = Params.ProxyCfg.GetNodeID()
-	collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName)
+	collID, err := globalMetaCache.GetCollectionID(ctx, g.request.GetDbName(), g.collectionName)
 	if err != nil { // err is not nil if collection not exists
 		return err
 	}
@@ -258,7 +258,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
 	g.GetStatisticsRequest.PartitionIDs = g.loadedPartitionIDs
 	executeGetStatistics := func(withCache bool) error {
-		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, g.request.GetDbName(), g.collectionName)
 		if err != nil {
 			return err
 		}
@@ -277,7 +277,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
 			zap.Error(err),
 		)
 		// invalidate cache first, since ctx may be canceled or timeout here
-		globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName)
+		globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
 		err = executeGetStatistics(WithoutCache)
 	}
 	if err != nil {
@@ -297,19 +297,19 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
 	if err != nil {
 		log.Warn("QueryNode statistic return error", zap.Int64("msgID", g.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
-		globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName)
+		globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
 		return err
 	}
 	if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
 		log.Warn("QueryNode is not shardLeader",
zap.Int64("msgID", g.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName) + globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) return errInvalidShardLeaders } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("QueryNode statistic result error", zap.Int64("msgID", g.ID()), zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason())) - globalMetaCache.DeprecateShardCache(GetCurDatabaseFromContextOrEmpty(ctx), g.collectionName) + globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) } g.resultBuf <- result @@ -324,7 +324,7 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName st var unloadPartitionIDs []UniqueID // TODO: Consider to check if partition loaded from cache to save rpc. - info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) if err != nil { return nil, nil, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err) } @@ -646,7 +646,7 @@ func (g *getCollectionStatisticsTask) PreExecute(ctx context.Context) error { } func (g *getCollectionStatisticsTask) Execute(ctx context.Context) error { - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), g.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, g.GetDbName(), g.CollectionName) if err != nil { return err } @@ -734,12 +734,12 @@ func (g *getPartitionStatisticsTask) PreExecute(ctx context.Context) error { } func (g *getPartitionStatisticsTask) Execute(ctx context.Context) error { - collID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), g.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, g.GetDbName(), g.CollectionName) if err != nil { return err } g.collectionID = collID - partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), g.CollectionName, g.PartitionName) + partitionID, err := globalMetaCache.GetPartitionID(ctx, g.GetDbName(), g.CollectionName, g.PartitionName) if err != nil { return err } diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 1c319b2148fa6..be7faeb508d6d 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -95,7 +95,7 @@ func TestStatisticTask_all(t *testing.T) { require.NoError(t, createColT.Execute(ctx)) require.NoError(t, createColT.PostExecute(ctx)) - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 1993d82df4958..abe34bce5bbd5 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -818,7 +818,7 @@ func TestHasCollectionTask(t *testing.T) { assert.Equal(t, false, task.result.Value) // createCollection in RootCood and fill GlobalMetaCache 
 	rc.CreateCollection(ctx, createColReq)
-	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	// success to drop collection
 	err = task.Execute(ctx)
@@ -939,7 +939,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
 	}
 	rc.CreateCollection(ctx, createColReq)
-	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	//CreateCollection
 	task := &describeCollectionTask{
@@ -1002,7 +1002,7 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
 	}
 	rc.CreateCollection(ctx, createColReq)
-	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	//CreateCollection
 	task := &describeCollectionTask{
@@ -1386,7 +1386,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
 		})
 	})
-	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	assert.NoError(t, err)
 	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@@ -1641,7 +1641,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
 		})
 	})
-	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName)
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName)
 	assert.NoError(t, err)
 	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
@@ -2639,7 +2639,7 @@ func TestTransferReplicaTask(t *testing.T) {
 	mgr := newShardClientMgr()
 	InitMetaCache(ctx, rc, qc, mgr)
 	// make it avoid remote call on rc
-	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrEmpty(ctx), "collection1")
+	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrDefault(ctx), "collection1")
 	req := &milvuspb.TransferReplicaRequest{
 		Base: &commonpb.MsgBase{
@@ -2720,8 +2720,8 @@ func TestDescribeResourceGroupTask(t *testing.T) {
 	mgr := newShardClientMgr()
 	InitMetaCache(ctx, rc, qc, mgr)
 	// make it avoid remote call on rc
-	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrEmpty(ctx), "collection1")
-	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrEmpty(ctx), "collection2")
+	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrDefault(ctx), "collection1")
+	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrDefault(ctx), "collection2")
 	req := &milvuspb.DescribeResourceGroupRequest{
 		Base: &commonpb.MsgBase{
@@ -2765,8 +2765,8 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
 	mgr := newShardClientMgr()
 	InitMetaCache(ctx, rc, qc, mgr)
 	// make it avoid remote call on rc
-	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrEmpty(ctx), "collection1")
-	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrEmpty(ctx), "collection2")
+	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrDefault(ctx), "collection1")
+	globalMetaCache.GetCollectionSchema(context.Background(), GetCurDatabaseFromContextOrDefault(ctx), "collection2")
 	req := &milvuspb.DescribeResourceGroupRequest{
 		Base: &commonpb.MsgBase{
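The test changes above, together with the rename in internal/proxy/util.go just below, all rely on one contract: when a request's context does not name a database, the proxy falls back to the default database instead of an empty string. The hunk below only shows the first half of the renamed helper, so the sketch here fills in the presumed fallback with the "default" literal used elsewhere in this patch; the function names and the stand-in parameter are hypothetical, not the real util.go code.

package proxysketch

import "context"

// defaultDatabaseName mirrors the literal used throughout this patch
// ("default"); the real code may read it from a shared constant instead.
const defaultDatabaseName = "default"

// getCurDatabaseOrDefault sketches the renamed helper's contract: prefer the
// database carried by the request context, otherwise fall back to the default
// database rather than an empty string. getDatabaseFromContext stands in for
// the existing GetCurDatabaseFromContext, whose metadata key is not shown here.
func getCurDatabaseOrDefault(ctx context.Context, getDatabaseFromContext func(context.Context) (string, error)) string {
	if db, err := getDatabaseFromContext(ctx); err == nil && db != "" {
		return db
	}
	return defaultDatabaseName
}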
diff --git a/internal/proxy/util.go b/internal/proxy/util.go
index 24a40c0a6d7a0..5afb7a0a461e7 100644
--- a/internal/proxy/util.go
+++ b/internal/proxy/util.go
@@ -766,7 +766,7 @@ func GetCurDatabaseFromContext(ctx context.Context) (string, error) {
 	return header[0], nil
 }
-func GetCurDatabaseFromContextOrEmpty(ctx context.Context) string {
+func GetCurDatabaseFromContextOrDefault(ctx context.Context) string {
 	if db, err := GetCurDatabaseFromContext(ctx); err == nil {
 		return db
 	}
@@ -983,7 +983,7 @@ func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord,
 	IDs2Names := make(map[int64]string)
 	partitionIDs := make([]int64, 0)
 	for _, partitionName := range partitionNames {
-		partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrEmpty(ctx), collectionName, partitionName)
+		partitionID, err := globalMetaCache.GetPartitionID(ctx, GetCurDatabaseFromContextOrDefault(ctx), collectionName, partitionName)
 		if err != nil {
 			return 0, err
 		}
diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go
index 573a8bd3cee20..c43b4575ebb42 100644
--- a/internal/rootcoord/meta_table.go
+++ b/internal/rootcoord/meta_table.go
@@ -32,6 +32,7 @@ import (
 	"github.com/milvus-io/milvus/internal/metrics"
 	pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/tso"
 	"github.com/milvus-io/milvus/internal/util/contextutil"
 	"github.com/milvus-io/milvus/internal/util/funcutil"
 	"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -92,6 +93,8 @@ type MetaTable struct {
 	ctx     context.Context
 	catalog metastore.RootCoordCatalog
+	tsoAllocator tso.Allocator
+
 	collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta
 	// collections  *collectionDb
@@ -102,10 +105,11 @@ type MetaTable struct {
 	permissionLock sync.RWMutex
 }
-func NewMetaTable(ctx context.Context, catalog metastore.RootCoordCatalog) (*MetaTable, error) {
+func NewMetaTable(ctx context.Context, catalog metastore.RootCoordCatalog, tsoAllocator tso.Allocator) (*MetaTable, error) {
 	mt := &MetaTable{
-		ctx:     contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName),
-		catalog: catalog,
+		ctx:          contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName),
+		catalog:      catalog,
+		tsoAllocator: tsoAllocator,
 	}
 	if err := mt.reload(); err != nil {
 		return nil, err
 	}
@@ -129,6 +133,14 @@ func (mt *MetaTable) reload() error {
 	if err != nil {
 		return err
 	}
+
+	// create default database.
+	if !funcutil.SliceContain(dbs, "default") {
+		if err := mt.createDefaultDb(); err != nil {
+			return err
+		}
+	}
+
 	dbs = append(dbs, "")
 	// recover collections.
@@ -139,7 +151,12 @@ func (mt *MetaTable) reload() error {
 		}
 		for name, collection := range collections {
 			mt.collID2Meta[collection.CollectionID] = collection
-			mt.names.insert(db, name, collection.CollectionID)
+			if db == "" {
+				// insert into default database.
+				mt.names.insert("default", name, collection.CollectionID)
+			} else {
+				mt.names.insert(db, name, collection.CollectionID)
+			}
 			if collection.Available() {
 				collectionNum++
@@ -155,7 +172,11 @@ func (mt *MetaTable) reload() error {
 			return err
 		}
 		for _, alias := range aliases {
-			mt.aliases.insert(db, alias.Name, alias.CollectionID)
+			if db == "" {
+				mt.aliases.insert("default", alias.Name, alias.CollectionID)
+			} else {
+				mt.aliases.insert(db, alias.Name, alias.CollectionID)
+			}
 		}
 	}
@@ -164,6 +185,14 @@ func (mt *MetaTable) reload() error {
 	return nil
 }
+func (mt *MetaTable) createDefaultDb() error {
+	ts, err := mt.tsoAllocator.GenerateTSO(1)
+	if err != nil {
+		return err
+	}
+	return mt.CreateDatabase(mt.ctx, "default", ts)
+}
+
 func (mt *MetaTable) CreateDatabase(ctx context.Context, dbName string, ts typeutil.Timestamp) error {
 	mt.ddLock.Lock()
 	defer mt.ddLock.Unlock()
diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go
index 7d3a8aed804f9..a4896605fbd2d 100644
--- a/internal/rootcoord/root_coord.go
+++ b/internal/rootcoord/root_coord.go
@@ -371,7 +371,7 @@ func (c *Core) initMetaTable() error {
 		return retry.Unrecoverable(fmt.Errorf("not supported meta store: %s", Params.MetaStoreCfg.MetaStoreType))
 	}
-	if c.meta, err = NewMetaTable(c.ctx, catalog); err != nil {
+	if c.meta, err = NewMetaTable(c.ctx, catalog, c.tsoAllocator); err != nil {
 		return err
 	}
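NewMetaTable now takes a tso.Allocator because reload() needs a timestamp to bootstrap the default database when none is recorded yet; root_coord.go above passes c.tsoAllocator, and any other caller that constructs a MetaTable directly will need the extra argument. A minimal sketch of that bootstrap step follows, using a local one-method interface instead of the full tso.Allocator; the helper names are illustrative.

package rootcoordsketch

import "fmt"

// tsoAllocator models the single method the new reload path needs from
// tso.Allocator; the real interface is larger.
type tsoAllocator interface {
	GenerateTSO(count uint32) (uint64, error)
}

// createDefaultDbSketch shows the bootstrap step added to reload(): allocate
// one timestamp and use it to create the "default" database if it is not
// already present. createDatabase stands in for MetaTable.CreateDatabase.
func createDefaultDbSketch(alloc tsoAllocator, existingDbs []string,
	createDatabase func(dbName string, ts uint64) error) error {
	for _, db := range existingDbs {
		if db == "default" {
			return nil // already bootstrapped
		}
	}
	ts, err := alloc.GenerateTSO(1)
	if err != nil {
		return fmt.Errorf("allocate timestamp for default database: %w", err)
	}
	return createDatabase("default", ts)
}

Under this flow an upgraded cluster gets a "default" database on its first reload, and collections and aliases stored under the old empty database name are surfaced under it, matching the name and alias insertion changes above.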