From 4f8c540c7786a10c0d91baae6b2759a098758b62 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 4 Jan 2024 17:28:46 +0800 Subject: [PATCH] enhance: cache collection schema attributes to reduce proxy cpu (#29668) See also #29113 The collection schema is crucial when performing search/query, but some of its derived information is recalculated for every request. This PR changes the schema field of the cached collection info into a utility `schemaInfo` type that stores stable results (e.g. the pk field and whether a partition key is enabled), and provides a field-name-to-ID map for search/query services. --------- Signed-off-by: Congqi Xia --- .../proxy/httpserver/handler_v1.go | 2 +- internal/proxy/meta_cache.go | 56 +++++++++++++++++-- internal/proxy/meta_cache_test.go | 18 +++--- internal/proxy/mock_cache.go | 16 +++--- internal/proxy/task.go | 4 +- internal/proxy/task_delete.go | 12 ++-- internal/proxy/task_delete_test.go | 21 ++++--- internal/proxy/task_index.go | 6 +- internal/proxy/task_index_test.go | 2 +- internal/proxy/task_insert.go | 6 +- internal/proxy/task_query.go | 12 ++-- internal/proxy/task_query_test.go | 19 ++++++- internal/proxy/task_search.go | 24 +++----- internal/proxy/task_search_test.go | 17 ++++-- internal/proxy/task_test.go | 30 +++++----- internal/proxy/task_upsert.go | 14 ++--- internal/proxy/task_upsert_test.go | 36 ++++++------ internal/proxy/util.go | 8 +-- 18 files changed, 185 insertions(+), 118 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 42312df67d34a..6afffb23a3ff8 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -120,7 +120,7 @@ func (h *HandlersV1) checkDatabase(ctx context.Context, c *gin.Context, dbName s func (h *HandlersV1) describeCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string) (*schemapb.CollectionSchema, error) { collSchema, err := 
proxy.GetCachedCollectionSchema(ctx, dbName, collectionName) if err == nil { - return collSchema, nil + return collSchema.CollectionSchema, nil } req := milvuspb.DescribeCollectionRequest{ DbName: dbName, diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index fac890f8905af..c372b2ee7a5c7 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -69,7 +69,7 @@ type Cache interface { // GetPartitionsIndex returns a partition names in partition key indexed order. GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) // GetCollectionSchema get collection's schema. - GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error) + GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) DeprecateShardCache(database, collectionName string) expireShardLeaderCache(ctx context.Context) @@ -99,9 +99,8 @@ type collectionBasicInfo struct { } type collectionInfo struct { - collID typeutil.UniqueID - schema *schemapb.CollectionSchema - // partInfo map[string]*partitionInfo + collID typeutil.UniqueID + schema *schemaInfo partInfo *partitionInfos leaderMutex sync.RWMutex shardLeaders *shardLeaders @@ -110,6 +109,51 @@ type collectionInfo struct { consistencyLevel commonpb.ConsistencyLevel } +// schemaInfo is a helper type that wraps *schemapb.CollectionSchema +// with extra field mappings and methods +type schemaInfo struct { + *schemapb.CollectionSchema + fieldMap *typeutil.ConcurrentMap[string, int64] // field name to id mapping + hasPartitionKeyField bool + pkField *schemapb.FieldSchema +} + +func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo { + fieldMap := typeutil.NewConcurrentMap[string, int64]() + hasPartitionkey := false + var pkField *schemapb.FieldSchema + for _, 
field := range schema.GetFields() { + fieldMap.Insert(field.GetName(), field.GetFieldID()) + if field.GetIsPartitionKey() { + hasPartitionkey = true + } + if field.GetIsPrimaryKey() { + pkField = field + } + } + return &schemaInfo{ + CollectionSchema: schema, + fieldMap: fieldMap, + hasPartitionKeyField: hasPartitionkey, + pkField: pkField, + } +} + +func (s *schemaInfo) MapFieldID(name string) (int64, bool) { + return s.fieldMap.Get(name) +} + +func (s *schemaInfo) IsPartitionKeyCollection() bool { + return s.hasPartitionKeyField +} + +func (s *schemaInfo) GetPkField() (*schemapb.FieldSchema, error) { + if s.pkField == nil { + return nil, merr.WrapErrServiceInternal("pk field not found") + } + return s.pkField, nil +} + // partitionInfos contains the cached collection partition informations. type partitionInfos struct { partitionInfos []*partitionInfo @@ -396,7 +440,7 @@ func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collect return collInfo, nil } -func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error) { +func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) { m.mu.RLock() var collInfo *collectionInfo var ok bool @@ -445,7 +489,7 @@ func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, if !ok { m.collInfo[database][collectionName] = &collectionInfo{} } - m.collInfo[database][collectionName].schema = coll.Schema + m.collInfo[database][collectionName].schema = newSchemaInfo(coll.Schema) m.collInfo[database][collectionName].collID = coll.CollectionID m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index c662935617392..f08743184162a 100644 --- a/internal/proxy/meta_cache_test.go +++ 
b/internal/proxy/meta_cache_test.go @@ -208,7 +208,7 @@ func TestMetaCache_GetCollection(t *testing.T) { schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -220,7 +220,7 @@ func TestMetaCache_GetCollection(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection2", @@ -234,7 +234,7 @@ func TestMetaCache_GetCollection(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -290,7 +290,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -302,7 +302,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, 
schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection2", @@ -316,7 +316,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -340,7 +340,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -349,7 +349,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { rootCoord.Error = true // should be cached with no error assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", @@ -410,7 +410,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { // GetCollectionSchema will never fail schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.NoError(t, err) - assert.Equal(t, schema, &schemapb.CollectionSchema{ + assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{}, Name: "collection1", diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index b2bbaddd013ea..bd5c707b56efd 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -8,8 +8,6 @@ import ( internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" mock "github.com/stretchr/testify/mock" - schemapb 
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -226,19 +224,19 @@ func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Contex } // GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName -func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemapb.CollectionSchema, error) { +func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemaInfo, error) { ret := _m.Called(ctx, database, collectionName) - var r0 *schemapb.CollectionSchema + var r0 *schemaInfo var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemapb.CollectionSchema, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemaInfo, error)); ok { return rf(ctx, database, collectionName) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemapb.CollectionSchema); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemaInfo); ok { r0 = rf(ctx, database, collectionName) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*schemapb.CollectionSchema) + r0 = ret.Get(0).(*schemaInfo) } } @@ -271,12 +269,12 @@ func (_c *MockCache_GetCollectionSchema_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockCache_GetCollectionSchema_Call { +func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemaInfo, _a1 error) *MockCache_GetCollectionSchema_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemapb.CollectionSchema, error)) *MockCache_GetCollectionSchema_Call { +func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemaInfo, error)) 
*MockCache_GetCollectionSchema_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 8de4be8ccbb1c..76ee6f45d90e6 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1487,7 +1487,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { ), DbID: 0, CollectionID: collID, - Schema: collSchema, + Schema: collSchema.CollectionSchema, ReplicaNumber: t.ReplicaNumber, FieldIndexID: fieldIndexIDs, Refresh: t.Refresh, @@ -1738,7 +1738,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { DbID: 0, CollectionID: collID, PartitionIDs: partitionIDs, - Schema: collSchema, + Schema: collSchema.CollectionSchema, ReplicaNumber: t.ReplicaNumber, FieldIndexID: fieldIndexIDs, Refresh: t.Refresh, diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 52782dea3a17b..2a0df673d8a38 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -230,7 +230,7 @@ type deleteRunner struct { tsoAllocatorIns tsoAllocator // delete info - schema *schemapb.CollectionSchema + schema *schemaInfo collectionID UniqueID partitionID UniqueID partitionKeyMode bool @@ -264,8 +264,8 @@ func (dr *deleteRunner) Init(ctx context.Context) error { return ErrWithLog(log, "Failed to get collection schema", err) } - dr.partitionKeyMode = hasParitionKeyModeField(dr.schema) - // get prititionIDs of delete + dr.partitionKeyMode = dr.schema.IsPartitionKeyCollection() + // get partitionIDs of delete dr.partitionID = common.InvalidPartitionID if len(dr.req.PartitionName) > 0 { if dr.partitionKeyMode { @@ -300,12 +300,12 @@ func (dr *deleteRunner) Init(ctx context.Context) error { } func (dr *deleteRunner) Run(ctx context.Context) error { - plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr) + plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr) if err != nil { return fmt.Errorf("failed to create expr plan, expr = %s", 
dr.req.GetExpr()) } - isSimple, pk, numRow := getPrimaryKeysFromPlan(dr.schema, plan) + isSimple, pk, numRow := getPrimaryKeysFromPlan(dr.schema.CollectionSchema, plan) if isSimple { // if could get delete.primaryKeys from delete expr err := dr.simpleDelete(ctx, pk, numRow) @@ -379,7 +379,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe zap.Int64("nodeID", nodeID)) // set plan - _, outputFieldIDs := translatePkOutputFields(dr.schema) + _, outputFieldIDs := translatePkOutputFields(dr.schema.CollectionSchema) outputFieldIDs = append(outputFieldIDs, common.TimeStampField) plan.OutputFieldIds = outputFieldIDs diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index cbd44cb8e83c5..816b6aaa903c5 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -234,7 +234,7 @@ func TestDeleteRunner_Init(t *testing.T) { // channels := []string{"test_channel"} dbName := "test_1" - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Name: collectionName, Description: "", AutoID: false, @@ -253,6 +253,7 @@ func TestDeleteRunner_Init(t *testing.T) { }, }, } + schema := newSchemaInfo(collSchema) t.Run("empty collection name", func(t *testing.T) { dr := deleteRunner{} @@ -312,7 +313,7 @@ func TestDeleteRunner_Init(t *testing.T) { mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{ + ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Name: collectionName, Description: "", AutoID: false, @@ -325,7 +326,7 @@ func TestDeleteRunner_Init(t *testing.T) { IsPartitionKey: true, }, }, - }, nil) + }), nil) globalMetaCache = cache assert.Error(t, dr.Init(context.Background())) @@ -440,7 +441,7 @@ func TestDeleteRunner_Run(t *testing.T) { queue.Start() defer queue.Close() - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Name: collectionName, 
Description: "", AutoID: false, @@ -459,6 +460,7 @@ func TestDeleteRunner_Run(t *testing.T) { }, }, } + schema := newSchemaInfo(collSchema) metaCache := NewMockCache(t) metaCache.EXPECT().GetCollectionID(mock.Anything, dbName, collectionName).Return(collectionID, nil).Maybe() @@ -474,6 +476,7 @@ func TestDeleteRunner_Run(t *testing.T) { req: &milvuspb.DeleteRequest{ Expr: "????", }, + schema: schema, } assert.Error(t, dr.Run(context.Background())) }) @@ -838,7 +841,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { queue.Start() defer queue.Close() - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Name: "test_delete", Description: "", AutoID: false, @@ -859,7 +862,9 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { } // test partitionKey mode - schema.Fields[1].IsPartitionKey = true + collSchema.Fields[1].IsPartitionKey = true + + schema := newSchemaInfo(collSchema) partitionMaps := make(map[string]int64) partitionMaps["test_0"] = 1 partitionMaps["test_1"] = 2 @@ -930,7 +935,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { globalMetaCache = mockCache defer func() { globalMetaCache = nil }() - plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr) + plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr) assert.NoError(t, err) queryFunc := dr.getStreamingQueryAndDelteFunc(plan) assert.Error(t, queryFunc(ctx, 1, qn)) @@ -973,7 +978,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { globalMetaCache = mockCache defer func() { globalMetaCache = nil }() - plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr) + plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr) assert.NoError(t, err) queryFunc := dr.getStreamingQueryAndDelteFunc(plan) assert.Error(t, queryFunc(ctx, 1, qn)) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 
3704d8d0c1d5b..6981b385a0683 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -294,7 +294,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel log.Error("failed to get collection schema", zap.Error(err)) return nil, fmt.Errorf("failed to get collection schema: %s", err) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema) + schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) if err != nil { log.Error("failed to parse collection schema", zap.Error(err)) return nil, fmt.Errorf("failed to parse collection schema: %s", err) @@ -616,7 +616,7 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error { log.Error("failed to get collection schema", zap.Error(err)) return fmt.Errorf("failed to get collection schema: %s", err) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema) + schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) if err != nil { log.Error("failed to parse collection schema", zap.Error(err)) return fmt.Errorf("failed to parse collection schema: %s", err) @@ -740,7 +740,7 @@ func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error { log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err)) return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName()) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema) + schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) if err != nil { log.Error("failed to parse collection schema", zap.String("collection_name", schema.GetName()), zap.Error(err)) return fmt.Errorf("failed to parse collection schema: %s", dit.GetCollectionName()) diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 38c8e507940db..d0626786c0cdb 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -245,7 +245,7 @@ func 
TestCreateIndexTask_PreExecute(t *testing.T) { mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(newTestSchema(), nil) + ).Return(newSchemaInfo(newTestSchema()), nil) globalMetaCache = mockCache diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index aa710e3d6575c..36dbe61174981 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -116,7 +116,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err)) return err } - it.schema = schema + it.schema = schema.CollectionSchema rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs @@ -164,7 +164,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } // set field ID to insert field data - err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema) + err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema.CollectionSchema) if err != nil { log.Info("set fieldID to fieldData failed", zap.Error(err)) @@ -199,7 +199,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck(), withMaxCapCheck()). 
- Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil { + Validate(it.insertMsg.GetFieldsData(), schema.CollectionSchema, it.insertMsg.NRows()); err != nil { return err } diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 80569240d07f2..47d59d6174125 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -52,7 +52,7 @@ type queryTask struct { ids *schemapb.IDs collectionName string queryParams *queryParams - schema *schemapb.CollectionSchema + schema *schemaInfo userOutputFields []string @@ -206,25 +206,25 @@ func (t *queryTask) createPlan(ctx context.Context) error { cntMatch := matchCountRule(t.request.GetOutputFields()) if cntMatch { var err error - t.plan, err = createCntPlan(t.request.GetExpr(), schema) + t.plan, err = createCntPlan(t.request.GetExpr(), schema.CollectionSchema) t.userOutputFields = []string{"count(*)"} return err } var err error if t.plan == nil { - t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr) + t.plan, err = planparserv2.CreateRetrievePlan(schema.CollectionSchema, t.request.Expr) if err != nil { return err } } - t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true) + t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, true) if err != nil { return err } - outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema) + outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema) if err != nil { return err } @@ -453,7 +453,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0) tr.CtxRecord(ctx, "reduceResultStart") - reducer := createMilvusReducer(ctx, t.queryParams, t.RetrieveRequest, t.schema, t.plan, t.collectionName) + reducer 
:= createMilvusReducer(ctx, t.queryParams, t.RetrieveRequest, t.schema.CollectionSchema, t.plan, t.collectionName) t.result, err = reducer.Reduce(toReduceResults) if err != nil { diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 266e2830d7f8f..6d4a9665b1e75 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -849,10 +849,22 @@ func Test_createCntPlan(t *testing.T) { func Test_queryTask_createPlan(t *testing.T) { t.Run("match count rule", func(t *testing.T) { + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "a", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + schema := newSchemaInfo(collSchema) tsk := &queryTask{ request: &milvuspb.QueryRequest{ OutputFields: []string{"count(*)"}, }, + schema: schema, } err := tsk.createPlan(context.TODO()) assert.NoError(t, err) @@ -866,13 +878,14 @@ func Test_queryTask_createPlan(t *testing.T) { request: &milvuspb.QueryRequest{ OutputFields: []string{"a"}, }, + schema: &schemaInfo{}, } err := tsk.createPlan(context.TODO()) assert.Error(t, err) }) t.Run("invalid expression", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { FieldID: 100, @@ -882,6 +895,7 @@ func Test_queryTask_createPlan(t *testing.T) { }, }, } + schema := newSchemaInfo(collSchema) tsk := &queryTask{ schema: schema, @@ -895,7 +909,7 @@ func Test_queryTask_createPlan(t *testing.T) { }) t.Run("invalid output fields", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { FieldID: 100, @@ -905,6 +919,7 @@ func Test_queryTask_createPlan(t *testing.T) { }, }, } + schema := newSchemaInfo(collSchema) tsk := &queryTask{ schema: schema, diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 
9f96439610c47..f5aed98aca70a 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -55,7 +55,7 @@ type searchTask struct { tr *timerecord.TimeRecorder collectionName string - schema *schemapb.CollectionSchema + schema *schemaInfo requery bool userOutputFields []string @@ -179,20 +179,14 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn }, offset, nil } -func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string) (outputFieldIDs []UniqueID, err error) { +func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) { outputFieldIDs = make([]UniqueID, 0, len(outputFields)) for _, name := range outputFields { - hitField := false - for _, field := range schema.GetFields() { - if field.Name == name { - outputFieldIDs = append(outputFieldIDs, field.GetFieldID()) - hitField = true - break - } - } - if !hitField { + id, ok := schema.MapFieldID(name) + if !ok { return nil, fmt.Errorf("Field %s not exist", name) } + outputFieldIDs = append(outputFieldIDs, id) } return outputFieldIDs, nil } @@ -294,7 +288,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) if err != nil || len(annsField) == 0 { - vecFields := typeutil.GetVectorFieldSchemas(t.schema) + vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema) if len(vecFields) == 0 { return errors.New(AnnsFieldKey + " not found in schema") } @@ -311,7 +305,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } t.offset = offset - plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo) + plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo) if err != nil { log.Warn("failed to create query plan", zap.Error(err), 
zap.String("dsl", t.request.Dsl), // may be very large if large term passed. @@ -489,7 +483,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { zap.Int64s("partitionIDs", t.GetPartitionIDs()), zap.Int("number of valid search results", len(validSearchResults))) tr.CtxRecord(ctx, "reduceResultStart") - primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema) + primaryFieldSchema, err := t.schema.GetPkField() if err != nil { log.Warn("failed to get primary field schema", zap.Error(err)) return err @@ -582,7 +576,7 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) { } func (t *searchTask) Requery() error { - pkField, err := typeutil.GetPrimaryFieldSchema(t.schema) + pkField, err := t.schema.GetPkField() if err != nil { return err } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 8427c0c68cc7f..0031bb82406b5 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1933,8 +1933,10 @@ func TestSearchTask_Requery(t *testing.T) { collectionName := "col" collectionID := UniqueID(0) cache := NewMockCache(t) + collSchema := constructCollectionSchema(pkField, vecField, dim, collection) + schema := newSchemaInfo(collSchema) cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe() - cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe() + cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe() cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe() cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, 
mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() @@ -1942,7 +1944,8 @@ func TestSearchTask_Requery(t *testing.T) { globalMetaCache = cache t.Run("Test normal", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) + collSchema := constructCollectionSchema(pkField, vecField, dim, collection) + schema := newSchemaInfo(collSchema) qn := mocks.NewMockQueryNodeClient(t) qn.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn( func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (*internalpb.RetrieveResults, error) { @@ -2033,7 +2036,9 @@ func TestSearchTask_Requery(t *testing.T) { }) t.Run("Test no primary key", func(t *testing.T) { - schema := &schemapb.CollectionSchema{} + collSchema := &schemapb.CollectionSchema{} + schema := newSchemaInfo(collSchema) + node := mocks.NewMockProxy(t) qt := &searchTask{ @@ -2056,7 +2061,8 @@ func TestSearchTask_Requery(t *testing.T) { }) t.Run("Test requery failed", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) + collSchema := constructCollectionSchema(pkField, vecField, dim, collection) + schema := newSchemaInfo(collSchema) qn := mocks.NewMockQueryNodeClient(t) qn.EXPECT().Query(mock.Anything, mock.Anything). Return(nil, fmt.Errorf("mock err 1")) @@ -2089,7 +2095,8 @@ func TestSearchTask_Requery(t *testing.T) { }) t.Run("Test postExecute with requery failed", func(t *testing.T) { - schema := constructCollectionSchema(pkField, vecField, dim, collection) + collSchema := constructCollectionSchema(pkField, vecField, dim, collection) + schema := newSchemaInfo(collSchema) qn := mocks.NewMockQueryNodeClient(t) qn.EXPECT().Query(mock.Anything, mock.Anything). 
Return(nil, fmt.Errorf("mock err 1")) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 391bfce7491ba..4bb083ee5fd89 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -434,7 +434,7 @@ func TestTranslateOutputFields(t *testing.T) { var userOutputFields []string var err error - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Name: "TestTranslateOutputFields", Description: "TestTranslateOutputFields", AutoID: false, @@ -446,6 +446,7 @@ func TestTranslateOutputFields(t *testing.T) { {Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector}, }, } + schema := newSchemaInfo(collSchema) outputFields, userOutputFields, err = translateOutputFields([]string{}, schema, false) assert.Equal(t, nil, err) @@ -527,7 +528,7 @@ func TestTranslateOutputFields(t *testing.T) { assert.Error(t, err) t.Run("enable dynamic schema", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ + collSchema := &schemapb.CollectionSchema{ Name: "TestTranslateOutputFields", Description: "TestTranslateOutputFields", AutoID: false, @@ -540,6 +541,7 @@ func TestTranslateOutputFields(t *testing.T) { {Name: common.MetaFieldName, FieldID: 102, DataType: schemapb.DataType_JSON, IsDynamic: true}, }, } + schema := newSchemaInfo(collSchema) outputFields, userOutputFields, err = translateOutputFields([]string{"A", idFieldName}, schema, true) assert.Equal(t, nil, err) @@ -1322,7 +1324,7 @@ func TestDropPartitionTask(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{}, nil) + ).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil) globalMetaCache = mockCache task := &dropPartitionTask{ @@ -1373,7 +1375,7 @@ func TestDropPartitionTask(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{}, nil) + 
).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil) globalMetaCache = mockCache task.PartitionName = "partition1" err = task.PreExecute(ctx) @@ -1400,7 +1402,7 @@ func TestDropPartitionTask(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{}, nil) + ).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil) globalMetaCache = mockCache err = task.PreExecute(ctx) assert.NoError(t, err) @@ -1426,7 +1428,7 @@ func TestDropPartitionTask(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{}, nil) + ).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil) globalMetaCache = mockCache err = task.PreExecute(ctx) assert.Error(t, err) @@ -2136,7 +2138,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{ + ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { FieldID: 100, @@ -2153,7 +2155,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { AutoID: false, }, }, - }, nil) + }), nil) globalMetaCache = cache field, err := cit.getIndexedField(context.Background()) @@ -2179,7 +2181,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{ + ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { Name: fieldName, @@ -2188,7 +2190,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { Name: fieldName, // duplicate }, }, - }, nil) + }), nil) globalMetaCache = cache _, err := cit.getIndexedField(context.Background()) assert.Error(t, err) @@ -2200,13 +2202,13 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { 
mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{ + ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { Name: fieldName + fieldName, }, }, - }, nil) + }), nil) globalMetaCache = cache _, err := cit.getIndexedField(context.Background()) assert.Error(t, err) @@ -2348,7 +2350,7 @@ func Test_createIndexTask_PreExecute(t *testing.T) { mock.Anything, // context.Context mock.AnythingOfType("string"), mock.AnythingOfType("string"), - ).Return(&schemapb.CollectionSchema{ + ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { FieldID: 100, @@ -2365,7 +2367,7 @@ func Test_createIndexTask_PreExecute(t *testing.T) { AutoID: false, }, }, - }, nil) + }), nil) globalMetaCache = cache cit.req.ExtraParams = []*commonpb.KeyValuePair{ { diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index f08188116e960..f7568b6bfcb0d 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -59,7 +59,7 @@ type upsertTask struct { chTicker channelsTimeTicker vChannels []vChan pChannels []pChan - schema *schemapb.CollectionSchema + schema *schemaInfo partitionKeyMode bool partitionKeys *schemapb.FieldData } @@ -172,7 +172,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { it.result.SuccIndex = sliceIndex if it.schema.EnableDynamicField { - err := checkDynamicFieldData(it.schema, it.upsertMsg.InsertMsg) + err := checkDynamicFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg) if err != nil { return err } @@ -181,7 +181,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { // check primaryFieldData whether autoID is true or not // only allow support autoID == false var err error - it.result.IDs, err = checkPrimaryFieldData(it.schema, it.result, it.upsertMsg.InsertMsg, false) + it.result.IDs, err = checkPrimaryFieldData(it.schema.CollectionSchema, 
it.result, it.upsertMsg.InsertMsg, false) log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName)) if err != nil { log.Warn("check primary field data and hash primary key failed when upsert", @@ -189,7 +189,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } // set field ID to insert field data - err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema) + err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema) if err != nil { log.Warn("insert set fieldID to fieldData failed when upsert", zap.Error(err)) @@ -197,8 +197,8 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { } if it.partitionKeyMode { - fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema) - it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg) + fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema.CollectionSchema) + it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg) if err != nil { log.Warn("get partition keys from insert request failed", zap.String("collectionName", collectionName), @@ -214,7 +214,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { } if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()). 
- Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema, it.upsertMsg.InsertMsg.NRows()); err != nil { + Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema, it.upsertMsg.InsertMsg.NRows()); err != nil { return err } diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index dd6cfda6915e6..26e1946168be6 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -73,6 +73,24 @@ func TestUpsertTask_CheckAligned(t *testing.T) { numRows := 20 dim := 128 + collSchema := &schemapb.CollectionSchema{ + Name: "TestUpsertTask_checkRowNums", + Description: "TestUpsertTask_checkRowNums", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + boolFieldSchema, + int8FieldSchema, + int16FieldSchema, + int32FieldSchema, + int64FieldSchema, + floatFieldSchema, + doubleFieldSchema, + floatVectorFieldSchema, + binaryVectorFieldSchema, + varCharFieldSchema, + }, + } + schema := newSchemaInfo(collSchema) case2 := upsertTask{ req: &milvuspb.UpsertRequest{ NumRows: uint32(numRows), @@ -80,23 +98,7 @@ func TestUpsertTask_CheckAligned(t *testing.T) { }, rowIDs: generateInt64Array(numRows), timestamps: generateUint64Array(numRows), - schema: &schemapb.CollectionSchema{ - Name: "TestUpsertTask_checkRowNums", - Description: "TestUpsertTask_checkRowNums", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - boolFieldSchema, - int8FieldSchema, - int16FieldSchema, - int32FieldSchema, - int64FieldSchema, - floatFieldSchema, - doubleFieldSchema, - floatVectorFieldSchema, - binaryVectorFieldSchema, - varCharFieldSchema, - }, - }, + schema: schema, upsertMsg: &msgstream.UpsertMsg{ InsertMsg: &msgstream.InsertMsg{ InsertRequest: msgpb.InsertRequest{}, diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 059b4acf8bbfb..df726c2d21bc1 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -978,7 +978,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int 
// output_fields=["*"] ==> [A,B,C,D] // output_fields=["*",A] ==> [A,B,C,D] // output_fields=["*",C] ==> [A,B,C,D] -func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema, addPrimary bool) ([]string, []string, error) { +func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) { var primaryFieldName string allFieldNameMap := make(map[string]bool) resultFieldNameMap := make(map[string]bool) @@ -1006,7 +1006,7 @@ func translateOutputFields(outputFields []string, schema *schemapb.CollectionSch userOutputFieldsMap[outputFieldName] = true } else { if schema.EnableDynamicField { - schemaH, err := typeutil.CreateSchemaHelper(schema) + schemaH, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) if err != nil { return nil, nil, err } @@ -1447,7 +1447,7 @@ func assignPartitionKeys(ctx context.Context, dbName string, collName string, ke return nil, err } - partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema) + partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema.CollectionSchema) if err != nil { return nil, err } @@ -1600,7 +1600,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. } } -func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemapb.CollectionSchema, error) { +func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) { if globalMetaCache != nil { return globalMetaCache.GetCollectionSchema(ctx, dbName, colName) }