diff --git a/configs/embedded-milvus.yaml b/configs/embedded-milvus.yaml index dfea43ce9f136..a15d007d69958 100644 --- a/configs/embedded-milvus.yaml +++ b/configs/embedded-milvus.yaml @@ -96,14 +96,6 @@ rootCoord: # seconds (24 hours). # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go importTaskRetention: 86400 - # (in seconds) During index building phase of an import task, Milvus will check the building status of a task's - # segments' indices every `importIndexCheckInterval` seconds. Default 300 seconds (5 minutes). - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importIndexCheckInterval: 300 - # (in seconds) Maximum time to wait before pushing flushed segments online (make them searchable) during importing. - # Default 1200 seconds (20 minutes). - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importIndexWaitLimit: 1200 # Related configuration of proxy, used to validate client requests and reduce the returned results. proxy: diff --git a/configs/milvus.yaml b/configs/milvus.yaml index ef3052a896cab..bbed908edb89d 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -132,14 +132,6 @@ rootCoord: # seconds (24 hours). # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go importTaskRetention: 86400 - # (in seconds) Check an import task's segment loading state in queryNodes every `importSegmentStateCheckInterval` - # seconds. Default 10 seconds. - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importSegmentStateCheckInterval: 10 - # (in seconds) Maximum time to wait for segments in a single import task to be loaded in queryNodes. - # Default 60 seconds (1 minute). - # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go - importSegmentStateWaitLimit: 60 # (in seconds) Check the building status of a task's segments' indices every `importIndexCheckInterval` seconds. # Default 10 seconds. # Note: If default value is to be changed, change also the default in: internal/util/paramtable/component_param.go diff --git a/internal/datacoord/cluster.go b/internal/datacoord/cluster.go index 1cc7296749464..81140c7469f76 100644 --- a/internal/datacoord/cluster.go +++ b/internal/datacoord/cluster.go @@ -149,11 +149,6 @@ func (c *Cluster) ReCollectSegmentStats(ctx context.Context, nodeID int64) { c.sessionManager.ReCollectSegmentStats(ctx, nodeID) } -// AddSegment triggers a AddSegment call from session manager. 
-func (c *Cluster) AddSegment(ctx context.Context, nodeID int64, req *datapb.AddSegmentRequest) { - c.sessionManager.AddSegment(ctx, nodeID, req) -} - // GetSessions returns all sessions func (c *Cluster) GetSessions() []*Session { return c.sessionManager.GetSessions() diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index 14cfaa4b6d364..b6c2a5c0ddc92 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -641,74 +641,3 @@ func TestCluster_ReCollectSegmentStats(t *testing.T) { time.Sleep(500 * time.Millisecond) }) } - -func TestCluster_AddSegment(t *testing.T) { - kv := getMetaKv(t) - defer func() { - kv.RemoveWithPrefix("") - kv.Close() - }() - - t.Run("add segment succeed", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - var mockSessionCreator = func(ctx context.Context, addr string) (types.DataNode, error) { - return newMockDataNodeClient(1, nil) - } - sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator)) - channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.Nil(t, err) - cluster := NewCluster(sessionManager, channelManager) - defer cluster.Close() - addr := "localhost:8080" - info := &NodeInfo{ - Address: addr, - NodeID: 1, - } - nodes := []*NodeInfo{info} - err = cluster.Startup(ctx, nodes) - assert.Nil(t, err) - - err = cluster.Watch("chan-1", 1) - assert.NoError(t, err) - - assert.NotPanics(t, func() { - cluster.AddSegment(ctx, 1, &datapb.AddSegmentRequest{ - Base: &commonpb.MsgBase{ - SourceID: 0, - }, - }) - }) - time.Sleep(500 * time.Millisecond) - }) - - t.Run("add segment failed", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - sessionManager := NewSessionManager() - channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.Nil(t, err) - cluster := NewCluster(sessionManager, channelManager) - defer cluster.Close() - addr := "localhost:8080" - info := &NodeInfo{ - Address: addr, - NodeID: 1, - } - nodes := []*NodeInfo{info} - err = cluster.Startup(ctx, nodes) - assert.Nil(t, err) - - err = cluster.Watch("chan-1", 1) - assert.NoError(t, err) - - assert.NotPanics(t, func() { - cluster.AddSegment(ctx, 1, &datapb.AddSegmentRequest{ - Base: &commonpb.MsgBase{ - SourceID: 0, - }, - }) - }) - time.Sleep(500 * time.Millisecond) - }) -} diff --git a/internal/datacoord/compaction_trigger.go b/internal/datacoord/compaction_trigger.go index 8bdeff719e520..0a053b73cac9f 100644 --- a/internal/datacoord/compaction_trigger.go +++ b/internal/datacoord/compaction_trigger.go @@ -236,7 +236,8 @@ func (t *compactionTrigger) handleGlobalSignal(signal *compactionSignal) { return (signal.collectionID == 0 || segment.CollectionID == signal.collectionID) && isSegmentHealthy(segment) && isFlush(segment) && - !segment.isCompacting // not compacting now + !segment.isCompacting && // not compacting now + !segment.isImporting // not importing now }) // m is list of chanPartSegments, which is channel-partition organized segments for _, group := range m { @@ -474,7 +475,8 @@ func (t *compactionTrigger) getCandidateSegments(channel string, partitionID Uni !isFlush(s) || s.GetInsertChannel() != channel || s.GetPartitionID() != partitionID || - s.isCompacting { + s.isCompacting || + s.isImporting { continue } res = append(res, s) diff --git a/internal/datacoord/garbage_collector.go b/internal/datacoord/garbage_collector.go index eafdf905cefb0..f4e701ecaf319 100644 --- 
a/internal/datacoord/garbage_collector.go +++ b/internal/datacoord/garbage_collector.go @@ -17,6 +17,7 @@ package datacoord import ( + "context" "path" "sync" "time" @@ -29,7 +30,9 @@ import ( "github.com/milvus-io/milvus/api/commonpb" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/types" "github.com/minio/minio-go/v7" "go.uber.org/zap" ) @@ -59,6 +62,8 @@ type garbageCollector struct { segRefer *SegmentReferenceManager indexCoord types.IndexCoord + rcc types.RootCoord + startOnce sync.Once stopOnce sync.Once wg sync.WaitGroup @@ -66,18 +71,16 @@ type garbageCollector struct { } // newGarbageCollector create garbage collector with meta and option -func newGarbageCollector(meta *meta, - segRefer *SegmentReferenceManager, - indexCoord types.IndexCoord, - opt GcOption) *garbageCollector { +func newGarbageCollector(meta *meta, segRefer *SegmentReferenceManager, + indexCoord types.IndexCoord, opt GcOption) *garbageCollector { log.Info("GC with option", zap.Bool("enabled", opt.enabled), zap.Duration("interval", opt.checkInterval), zap.Duration("missingTolerance", opt.missingTolerance), zap.Duration("dropTolerance", opt.dropTolerance)) return &garbageCollector{ - meta: meta, - segRefer: segRefer, + meta: meta, + segRefer: segRefer, indexCoord: indexCoord, - option: opt, - closeCh: make(chan struct{}), + option: opt, + closeCh: make(chan struct{}), } } @@ -221,7 +224,7 @@ func (gc *garbageCollector) clearEtcd() { func (gc *garbageCollector) isExpire(dropts Timestamp) bool { droptime := time.Unix(0, int64(dropts)) - return time.Since(droptime) > gc.option.dropTolerance + return time.Since(droptime) >= gc.option.dropTolerance } func getLogs(sinfo *SegmentInfo) []*datapb.Binlog { diff --git a/internal/datacoord/garbage_collector_test.go b/internal/datacoord/garbage_collector_test.go index 9936ec1029e30..d9e4fcf2b01b1 100644 --- a/internal/datacoord/garbage_collector_test.go +++ b/internal/datacoord/garbage_collector_test.go @@ -107,7 +107,7 @@ func Test_garbageCollector_scan(t *testing.T) { bucketName := `datacoord-ut` + strings.ToLower(funcutil.RandomString(8)) rootPath := `gc` + funcutil.RandomString(8) //TODO change to Params - cli, inserts, stats, delta, others, err := initUtOSSEnv(bucketName, rootPath, 4) + cli, inserts, stats, delta, others, err := initUtOSSEnv(bucketName, rootPath, 5) require.NoError(t, err) mockAllocator := newMockAllocator() @@ -120,6 +120,7 @@ func Test_garbageCollector_scan(t *testing.T) { segRefer, err := NewSegmentReferenceManager(etcdKV, nil) assert.NoError(t, err) assert.NotNil(t, segRefer) + mockRootCoord := newMockRootCoordService() indexCoord := mocks.NewMockIndexCoord(t) @@ -232,6 +233,42 @@ func Test_garbageCollector_scan(t *testing.T) { gc.close() }) + t.Run("clear import failed segments", func(t *testing.T) { + segment := buildSegment(1, 10, ImportFailedSegmentID, "ch") + segment.State = commonpb.SegmentState_Importing + segment.Binlogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, inserts[0])} + segment.Statslogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, stats[0])} + segment.Deltalogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, delta[0])} + err = meta.AddSegment(segment) + require.NoError(t, err) + + gc := newGarbageCollector(meta, segRefer, mockRootCoord, GcOption{ + cli: cli, + enabled: true, + checkInterval: time.Minute * 30, + missingTolerance: time.Hour * 
24, + dropTolerance: 0, + rootPath: rootPath, + }) + gc.clearEtcd() + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) + + gc.close() + + gc2 := newGarbageCollector(meta, segRefer, nil, GcOption{ + cli: cli, + enabled: true, + checkInterval: time.Minute * 30, + missingTolerance: time.Hour * 24, + dropTolerance: 0, + rootPath: rootPath, + }) + gc2.clearEtcd() + gc2.close() + }) t.Run("missing gc all", func(t *testing.T) { gc := newGarbageCollector(meta, segRefer, indexCoord, GcOption{ cli: cli, @@ -244,6 +281,28 @@ func Test_garbageCollector_scan(t *testing.T) { gc.start() gc.scan() gc.clearEtcd() + + // bad path shall remains since datacoord cannot determine file is garbage or not if path is not valid + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) + + gc.close() + }) + + t.Run("list object with error", func(t *testing.T) { + gc := newGarbageCollector(meta, segRefer, mockRootCoord, GcOption{ + cli: cli, + enabled: true, + checkInterval: time.Minute * 30, + missingTolerance: 0, + dropTolerance: 0, + rootPath: rootPath, + }) + gc.start() + gc.scan() + // bad path shall remains since datacoord cannot determine file is garbage or not if path is not valid validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:2]) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:2]) diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index cf768241fe3ba..fd591b603b6c6 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -259,12 +259,15 @@ func (m *meta) UpdateFlushSegmentsInfo( m.Lock() defer m.Unlock() - log.Info("update flush segments info", zap.Int64("segmentId", segmentID), + log.Info("update flush segments info", + zap.Int64("segmentId", segmentID), zap.Int("binlog", len(binlogs)), - zap.Int("statslog", len(statslogs)), - zap.Int("deltalogs", len(deltalogs)), + zap.Int("stats log", len(statslogs)), + zap.Int("delta logs", len(deltalogs)), zap.Bool("flushed", flushed), zap.Bool("dropped", dropped), + zap.Any("check points", checkpoints), + zap.Any("start position", startPositions), zap.Bool("importing", importing)) segment := m.segments.GetSegment(segmentID) if importing { @@ -747,6 +750,14 @@ func (m *meta) SetSegmentCompacting(segmentID UniqueID, compacting bool) { m.segments.SetIsCompacting(segmentID, compacting) } +// SetSegmentIsImporting sets the importing state for a segment. 
+func (m *meta) SetSegmentIsImporting(segmentID UniqueID, importing bool) { + m.Lock() + defer m.Unlock() + + m.segments.SetIsImporting(segmentID, importing) +} + func (m *meta) CompleteMergeCompaction(compactionLogs []*datapb.CompactionSegmentBinlogs, result *datapb.CompactionResult) error { m.Lock() defer m.Unlock() diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index 4e3003c3269a6..ba05a6e2a2bdb 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -717,6 +717,55 @@ func Test_meta_SetSegmentCompacting(t *testing.T) { } } +func Test_meta_SetSegmentIsImporting(t *testing.T) { + type fields struct { + client kv.TxnKV + segments *SegmentsInfo + } + type args struct { + segmentID UniqueID + isImporting bool + } + tests := []struct { + name string + fields fields + args args + }{ + { + "test set segment importing", + fields{ + memkv.NewMemoryKV(), + &SegmentsInfo{ + map[int64]*SegmentInfo{ + 1: { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + State: commonpb.SegmentState_Flushed, + }, + isImporting: false, + }, + }, + }, + }, + args{ + segmentID: 1, + isImporting: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &meta{ + client: tt.fields.client, + segments: tt.fields.segments, + } + m.SetSegmentIsImporting(tt.args.segmentID, tt.args.isImporting) + segment := m.GetSegment(tt.args.segmentID) + assert.Equal(t, tt.args.isImporting, segment.isImporting) + }) + } +} + func Test_meta_GetSegmentsOfCollection(t *testing.T) { type fields struct { segments *SegmentsInfo diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index 82b3df5f201a2..d519d8d748763 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -36,6 +36,8 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" ) +const ImportFailedSegmentID = 102 + func newMemoryMeta(allocator allocator) (*meta, error) { memoryKV := memkv.NewMemoryKV() return newMeta(context.TODO(), memoryKV) @@ -240,7 +242,7 @@ func (c *mockDataNodeClient) Import(ctx context.Context, in *datapb.ImportTaskRe return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (c *mockDataNodeClient) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (c *mockDataNodeClient) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } @@ -254,6 +256,22 @@ type mockRootCoordService struct { cnt int64 } +func (m *mockRootCoordService) GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + segIDs := make([]int64, 0) + segIDs = append(segIDs, ImportFailedSegmentID) + return &internalpb.GetImportFailedSegmentIDsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + SegmentIDs: segIDs, + }, nil +} + +func (m *mockRootCoordService) CheckSegmentIndexReady(context.Context, *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + panic("implement me") +} + func (m *mockRootCoordService) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { panic("implement me") } diff --git a/internal/datacoord/segment_info.go b/internal/datacoord/segment_info.go index 8e1104592a9b0..6c6f952d46308 100644 --- a/internal/datacoord/segment_info.go +++ 
b/internal/datacoord/segment_info.go @@ -37,6 +37,7 @@ type SegmentInfo struct { allocations []*Allocation lastFlushTime time.Time isCompacting bool + isImporting bool // a cache to avoid calculate twice size int64 lastWrittenTime time.Time @@ -57,6 +58,17 @@ func NewSegmentInfo(info *datapb.SegmentInfo) *SegmentInfo { } } +// NewImportSegmentInfo works the same as NewSegmentInfo except that isImport is explicitly set to true. +func NewImportSegmentInfo(info *datapb.SegmentInfo) *SegmentInfo { + return &SegmentInfo{ + SegmentInfo: info, + currRows: 0, + allocations: make([]*Allocation, 0, 16), + lastFlushTime: time.Now().Add(-1 * flushInterval), + isImporting: true, + } +} + // NewSegmentsInfo creates a `SegmentsInfo` instance, which makes sure internal map is initialized // note that no mutex is wrapped so external concurrent control is needed func NewSegmentsInfo() *SegmentsInfo { @@ -109,6 +121,14 @@ func (s *SegmentsInfo) SetState(segmentID UniqueID, state commonpb.SegmentState) } } +// SetDroppedAt sets Segment DroppedAt time for SegmentInfo with provided segmentID +// if SegmentInfo not found, do nothing +func (s *SegmentsInfo) SetDroppedAt(segmentID UniqueID, time uint64) { + if segment, ok := s.segments[segmentID]; ok { + s.segments[segmentID] = segment.Clone(SetDroppedAt(time)) + } +} + // SetDmlPosition sets DmlPosition info (checkpoint for recovery) for SegmentInfo with provided segmentID // if SegmentInfo not found, do nothing func (s *SegmentsInfo) SetDmlPosition(segmentID UniqueID, pos *internalpb.MsgPosition) { @@ -179,13 +199,20 @@ func (s *SegmentsInfo) AddSegmentBinlogs(segmentID UniqueID, field2Binlogs map[U } } -// SetIsCompacting sets compactino status for segment +// SetIsCompacting sets compaction status for segment func (s *SegmentsInfo) SetIsCompacting(segmentID UniqueID, isCompacting bool) { if segment, ok := s.segments[segmentID]; ok { s.segments[segmentID] = segment.ShadowClone(SetIsCompacting(isCompacting)) } } +// SetIsImporting sets the import status for a segment. 
+func (s *SegmentsInfo) SetIsImporting(segmentID UniqueID, isImporting bool) { + if segment, ok := s.segments[segmentID]; ok { + s.segments[segmentID] = segment.ShadowClone(SetIsImporting(isImporting)) + } +} + // Clone deep clone the segment info and return a new instance func (s *SegmentInfo) Clone(opts ...SegmentInfoOption) *SegmentInfo { info := proto.Clone(s.SegmentInfo).(*datapb.SegmentInfo) @@ -195,6 +222,7 @@ func (s *SegmentInfo) Clone(opts ...SegmentInfoOption) *SegmentInfo { allocations: s.allocations, lastFlushTime: s.lastFlushTime, isCompacting: s.isCompacting, + isImporting: s.isImporting, //cannot copy size, since binlog may be changed lastWrittenTime: s.lastWrittenTime, } @@ -207,12 +235,13 @@ func (s *SegmentInfo) Clone(opts ...SegmentInfoOption) *SegmentInfo { // ShadowClone shadow clone the segment and return a new instance func (s *SegmentInfo) ShadowClone(opts ...SegmentInfoOption) *SegmentInfo { cloned := &SegmentInfo{ - SegmentInfo: s.SegmentInfo, - currRows: s.currRows, - allocations: s.allocations, - lastFlushTime: s.lastFlushTime, - isCompacting: s.isCompacting, - size: s.size, + SegmentInfo: s.SegmentInfo, + currRows: s.currRows, + allocations: s.allocations, + lastFlushTime: s.lastFlushTime, + isCompacting: s.isCompacting, + isImporting: s.isImporting, + size: s.size, lastWrittenTime: s.lastWrittenTime, } @@ -246,6 +275,13 @@ func SetState(state commonpb.SegmentState) SegmentInfoOption { } } +// SetDroppedAt is the option to set droppedAt time for segment info +func SetDroppedAt(time uint64) SegmentInfoOption { + return func(segment *SegmentInfo) { + segment.DroppedAt = time + } +} + // SetDmlPosition is the option to set dml position for segment info func SetDmlPosition(pos *internalpb.MsgPosition) SegmentInfoOption { return func(segment *SegmentInfo) { @@ -304,6 +340,13 @@ func SetIsCompacting(isCompacting bool) SegmentInfoOption { } } +// SetIsImporting is the option to set import state for segment info. 
+func SetIsImporting(isImporting bool) SegmentInfoOption { + return func(segment *SegmentInfo) { + segment.isImporting = isImporting + } +} + func addSegmentBinlogs(field2Binlogs map[UniqueID][]*datapb.Binlog) SegmentInfoOption { return func(segment *SegmentInfo) { for fieldID, binlogPaths := range field2Binlogs { diff --git a/internal/datacoord/segment_manager.go b/internal/datacoord/segment_manager.go index e3d50c5ffe257..18587597541d7 100644 --- a/internal/datacoord/segment_manager.go +++ b/internal/datacoord/segment_manager.go @@ -375,7 +375,12 @@ func (s *SegmentManager) openNewSegment(ctx context.Context, collectionID Unique MaxRowNum: int64(maxNumOfRows), LastExpireTime: 0, } - segment := NewSegmentInfo(segmentInfo) + var segment *SegmentInfo + if segmentState == commonpb.SegmentState_Importing { + segment = NewImportSegmentInfo(segmentInfo) + } else { + segment = NewSegmentInfo(segmentInfo) + } if err := s.meta.AddSegment(segment); err != nil { log.Error("failed to add segment to DataCoord", zap.Error(err)) return nil, err diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index c444abd60281d..7ab74f22fb6e7 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -133,7 +133,6 @@ type Server struct { dnEventCh <-chan *sessionutil.SessionEvent //icEventCh <-chan *sessionutil.SessionEvent qcEventCh <-chan *sessionutil.SessionEvent - rcEventCh <-chan *sessionutil.SessionEvent dataNodeCreator dataNodeCreatorFunc rootCoordClientCreator rootCoordCreatorFunc @@ -463,16 +462,6 @@ func (s *Server) initServiceDiscovery() error { } s.qcEventCh = s.session.WatchServices(typeutil.QueryCoordRole, qcRevision+1, nil) - rcSessions, rcRevision, err := s.session.GetSessions(typeutil.RootCoordRole) - if err != nil { - log.Error("DataCoord get RootCoord session failed", zap.Error(err)) - return err - } - for _, session := range rcSessions { - serverIDs = append(serverIDs, session.ServerID) - } - s.rcEventCh = s.session.WatchServices(typeutil.RootCoordRole, rcRevision+1, nil) - s.segReferManager, err = NewSegmentReferenceManager(s.kvClient, serverIDs) return err } @@ -756,12 +745,6 @@ func (s *Server) watchService(ctx context.Context) { return } s.processSessionEvent(ctx, "QueryCoord", event) - case event, ok := <-s.rcEventCh: - if !ok { - s.stopServiceWatch() - return - } - s.processSessionEvent(ctx, "RootCoord", event) } } } diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 72d468f177104..6b41795b566dc 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -52,11 +52,19 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/minio/minio-go/v7" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" ) func TestMain(m *testing.M) { @@ -144,6 +152,19 @@ func TestAssignSegmentID(t *testing.T) { assert.EqualValues(t, 1000, assign.Count) }) + t.Run("assign segment for bulk load", 
func(t *testing.T) { + svr := newTestServer(t, nil) + defer closeTestServer(t, svr) + reportImportAttempts = 2 + svr.rootCoordClient = &mockRootCoord{ + RootCoord: svr.rootCoordClient, + collID: collID, + } + svr.CompleteBulkLoad(context.TODO(), &datapb.CompleteBulkLoadRequest{ + SegmentIds: []int64{1001, 1002, 1003}, + }) + }) + t.Run("with closed server", func(t *testing.T) { req := &datapb.SegmentIDRequest{ Count: 100, @@ -166,7 +187,7 @@ func TestAssignSegmentID(t *testing.T) { t.Run("assign segment with invalid collection", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClient = &mockDescribeCollRoot{ + svr.rootCoordClient = &mockRootCoord{ RootCoord: svr.rootCoordClient, collID: collID, } @@ -193,12 +214,12 @@ func TestAssignSegmentID(t *testing.T) { }) } -type mockDescribeCollRoot struct { +type mockRootCoord struct { types.RootCoord collID UniqueID } -func (r *mockDescribeCollRoot) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (r *mockRootCoord) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { if req.CollectionID != r.collID { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ @@ -210,6 +231,19 @@ func (r *mockDescribeCollRoot) DescribeCollection(ctx context.Context, req *milv return r.RootCoord.DescribeCollection(ctx, req) } +func (r *mockRootCoord) CheckSegmentIndexReady(context.Context, *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil +} + +func (r *mockRootCoord) ReportImport(context.Context, *rootcoordpb.ImportResult) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "something bad", + }, nil +} + func TestFlush(t *testing.T) { req := &datapb.FlushRequest{ Base: &commonpb.MsgBase{ @@ -815,12 +849,10 @@ func TestServer_watchQueryCoord(t *testing.T) { dnCh := make(chan *sessionutil.SessionEvent) //icCh := make(chan *sessionutil.SessionEvent) qcCh := make(chan *sessionutil.SessionEvent) - rcCh := make(chan *sessionutil.SessionEvent) svr.dnEventCh = dnCh //svr.icEventCh = icCh svr.qcEventCh = qcCh - svr.rcEventCh = rcCh segRefer, err := NewSegmentReferenceManager(etcdKV, nil) assert.NoError(t, err) @@ -862,69 +894,6 @@ func TestServer_watchQueryCoord(t *testing.T) { assert.True(t, closed) } -func TestServer_watchRootCoord(t *testing.T) { - Params.Init() - etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg) - assert.Nil(t, err) - etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath) - assert.NotNil(t, etcdKV) - factory := dependency.NewDefaultFactory(true) - svr := CreateServer(context.TODO(), factory) - svr.session = &sessionutil.Session{ - TriggerKill: true, - } - svr.kvClient = etcdKV - - dnCh := make(chan *sessionutil.SessionEvent) - //icCh := make(chan *sessionutil.SessionEvent) - qcCh := make(chan *sessionutil.SessionEvent) - rcCh := make(chan *sessionutil.SessionEvent) - - svr.dnEventCh = dnCh - //svr.icEventCh = icCh - svr.qcEventCh = qcCh - svr.rcEventCh = rcCh - - segRefer, err := NewSegmentReferenceManager(etcdKV, nil) - assert.NoError(t, err) - assert.NotNil(t, segRefer) - svr.segReferManager = segRefer - - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT) - defer signal.Reset(syscall.SIGINT) - closed := false - sigQuit := make(chan struct{}, 1) - - 
svr.serverLoopWg.Add(1) - go func() { - svr.watchService(context.Background()) - }() - - go func() { - <-sc - closed = true - sigQuit <- struct{}{} - }() - - rcCh <- &sessionutil.SessionEvent{ - EventType: sessionutil.SessionAddEvent, - Session: &sessionutil.Session{ - ServerID: 3, - }, - } - rcCh <- &sessionutil.SessionEvent{ - EventType: sessionutil.SessionDelEvent, - Session: &sessionutil.Session{ - ServerID: 3, - }, - } - close(rcCh) - <-sigQuit - svr.serverLoopWg.Wait() - assert.True(t, closed) -} - func TestServer_ShowConfigurations(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) @@ -2735,7 +2704,10 @@ func TestDataCoord_Import(t *testing.T) { t.Run("normal case", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - + svr.sessionManager.AddSession(&NodeInfo{ + NodeID: 0, + Address: "localhost:8080", + }) err := svr.channelManager.AddNode(0) assert.Nil(t, err) err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) @@ -2860,18 +2832,31 @@ func TestDataCoord_AddSegment(t *testing.T) { t.Run("test add segment", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - + seg := buildSegment(100, 100, 100, "ch1") + svr.meta.AddSegment(seg) + svr.sessionManager.AddSession(&NodeInfo{ + NodeID: 110, + Address: "localhost:8080", + }) err := svr.channelManager.AddNode(110) assert.Nil(t, err) err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 100}) assert.Nil(t, err) - status, err := svr.AddSegment(context.TODO(), &datapb.AddSegmentRequest{ + status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{ SegmentId: 100, ChannelName: "ch1", CollectionId: 100, PartitionId: 100, RowNum: int64(1), + SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{ + Base: &commonpb.MsgBase{ + SourceID: Params.DataNodeCfg.GetNodeID(), + }, + SegmentID: 100, + CollectionID: 100, + Importing: true, + }, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode()) @@ -2886,7 +2871,7 @@ func TestDataCoord_AddSegment(t *testing.T) { err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 100}) assert.Nil(t, err) - status, err := svr.AddSegment(context.TODO(), &datapb.AddSegmentRequest{ + status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{ SegmentId: 100, ChannelName: "non-channel", CollectionId: 100, @@ -2901,9 +2886,9 @@ func TestDataCoord_AddSegment(t *testing.T) { svr := newTestServer(t, nil) closeTestServer(t, svr) - status, err := svr.AddSegment(context.TODO(), &datapb.AddSegmentRequest{}) + status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) + assert.Equal(t, commonpb.ErrorCode_DataCoordNA, status.GetErrorCode()) }) } @@ -3087,13 +3072,6 @@ func Test_initGarbageCollection(t *testing.T) { server := newTestServer2(t, nil) Params.DataCoordCfg.EnableGarbageCollection = true - t.Run("err_minio_bad_address", func(t *testing.T) { - Params.MinioCfg.Address = "host:9000:bad" - err := server.initGarbageCollection() - assert.Error(t, err) - assert.Contains(t, err.Error(), "too many colons in address") - }) - // mock CheckBucketFn getCheckBucketFnBak := getCheckBucketFn getCheckBucketFn = func(cli *minio.Client) func() error { @@ -3102,20 +3080,28 @@ func Test_initGarbageCollection(t *testing.T) { defer func() { getCheckBucketFn = getCheckBucketFnBak }() - Params.MinioCfg.Address = 
"minio:9000" + storage.CheckBucketRetryAttempts = 1 t.Run("ok", func(t *testing.T) { + Params.CommonCfg.StorageType = "minio" err := server.initGarbageCollection() assert.NoError(t, err) }) t.Run("iam_ok", func(t *testing.T) { + Params.CommonCfg.StorageType = "minio" Params.MinioCfg.UseIAM = true err := server.initGarbageCollection() assert.Error(t, err) - assert.Contains(t, err.Error(), "404 Not Found") }) t.Run("local storage init", func(t *testing.T) { Params.CommonCfg.StorageType = "local" err := server.initGarbageCollection() assert.NoError(t, err) }) + t.Run("err_minio_bad_address", func(t *testing.T) { + Params.CommonCfg.StorageType = "minio" + Params.MinioCfg.Address = "host:9000:bad" + err := server.initGarbageCollection() + assert.Error(t, err) + assert.Contains(t, err.Error(), "too many colons in address") + }) } diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 9e5f037fabb06..48b3bf8988ce3 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/retry" @@ -38,8 +39,13 @@ import ( "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/samber/lo" "go.uber.org/zap" + "go.uber.org/zap" ) +var ImportFlushedCheckInterval = 5 * time.Second +var ImportFlushedWaitLimit = 2 * time.Minute +var reportImportAttempts uint = 20 + // checks whether server in Healthy State func (s *Server) isClosed() bool { return atomic.LoadInt64(&s.isServing) != ServerStateHealthy @@ -439,7 +445,7 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath s.segmentManager.DropSegment(ctx, req.SegmentID) s.flushCh <- req.SegmentID - if Params.DataCoordCfg.EnableCompaction { + if !req.Importing && Params.DataCoordCfg.EnableCompaction { cctx, cancel := context.WithTimeout(s.ctx, 5*time.Second) defer cancel() @@ -1094,11 +1100,12 @@ func (s *Server) Import(ctx context.Context, itr *datapb.ImportTaskRequest) (*da return resp, nil } - nodes := s.channelManager.store.GetNodes() + nodes := s.sessionManager.getLiveNodeIDs() if len(nodes) == 0 { log.Error("import failed as all DataNodes are offline") return resp, nil } + log.Info("available DataNodes are", zap.Int64s("node ID", nodes)) avaNodes := getDiff(nodes, itr.GetWorkingNodes()) if len(avaNodes) > 0 { @@ -1215,8 +1222,9 @@ func (s *Server) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegm return resp, nil } -func (s *Server) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { - log.Info("DataCoord putting segment to the right DataNode", +// SaveImportSegment saves the segment binlog paths and puts this segment to its belonging DataNode as a flushed segment. 
+func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) {
+	log.Info("DataCoord putting segment to the right DataNode and saving binlog path",
 		zap.Int64("segment ID", req.GetSegmentId()),
 		zap.Int64("collection ID", req.GetCollectionId()),
 		zap.Int64("partition ID", req.GetPartitionId()),
@@ -1228,16 +1236,106 @@ func (s *Server) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest)
 	}
 	if s.isClosed() {
 		log.Warn("failed to add segment for closed server")
+		errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA
 		errResp.Reason = msgDataCoordIsUnhealthy(Params.DataCoordCfg.GetNodeID())
 		return errResp, nil
 	}
+	// Look for the DataNode that watches the channel.
 	ok, nodeID := s.channelManager.getNodeIDByChannelName(req.GetChannelName())
 	if !ok {
 		log.Error("no DataNode found for channel", zap.String("channel name", req.GetChannelName()))
 		errResp.Reason = fmt.Sprint("no DataNode found for channel ", req.GetChannelName())
 		return errResp, nil
 	}
-	s.cluster.AddSegment(s.ctx, nodeID, req)
+	// Start saving binlog paths.
+	rsp, err := s.SaveBinlogPaths(context.Background(), req.GetSaveBinlogPathReq())
+	if err := VerifyResponse(rsp, err); err != nil {
+		log.Error("failed to SaveBinlogPaths", zap.Error(err))
+		return &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+		}, nil
+	}
+	// Call DataNode to add the new segment to its own flow graph.
+	cli, err := s.sessionManager.getClient(ctx, nodeID)
+	if err != nil {
+		log.Error("failed to get DataNode client for SaveImportSegment",
+			zap.Int64("DataNode ID", nodeID),
+			zap.Error(err))
+		return &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+		}, nil
+	}
+	resp, err := cli.AddImportSegment(ctx,
+		&datapb.AddImportSegmentRequest{
+			Base: &commonpb.MsgBase{
+				SourceID:  Params.DataNodeCfg.GetNodeID(),
+				Timestamp: req.GetBase().GetTimestamp(),
+			},
+			SegmentId:     req.GetSegmentId(),
+			ChannelName:   req.GetChannelName(),
+			CollectionId:  req.GetCollectionId(),
+			PartitionId:   req.GetPartitionId(),
+			RowNum:        req.GetRowNum(),
+			StatsLog:      req.GetSaveBinlogPathReq().GetField2StatslogPaths(),
+			DmlPositionId: req.GetDmlPositionId(),
+		})
+	if err := VerifyResponse(resp, err); err != nil {
+		log.Error("failed to add segment", zap.Int64("DataNode ID", nodeID), zap.Error(err))
+		return &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+		}, nil
+	}
+	log.Info("successfully added segment", zap.Int64("DataNode ID", nodeID), zap.Any("add segment req", req))
+	return &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}, nil
+}
+
+func (s *Server) CompleteBulkLoad(ctx context.Context, req *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) {
+	log.Info("received CompleteBulkLoad request", zap.Int64("task ID", req.GetTaskId()))
+	// Check index status.
+	checkIndexStatus, err := s.rootCoordClient.CheckSegmentIndexReady(ctx, &internalpb.CheckSegmentIndexReadyRequest{
+		TaskID: req.GetTaskId(),
+		ColID:  req.GetCollectionId(),
+		SegIDs: req.GetSegmentIds(),
+	})
+	if err != nil {
+		log.Warn(fmt.Sprintf("failed to wait for all index builds to complete: %s, continuing anyway", err.Error()))
+	}
+	if checkIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
+		log.Warn(fmt.Sprintf("failed to wait for all index builds to complete: %s, continuing anyway", checkIndexStatus.GetReason()))
+	}
+
+	// Update import task state to `ImportState_ImportCompleted`.
+	// Retry on errors.
+	err = retry.Do(s.ctx, func() error {
+		status, err := s.rootCoordClient.ReportImport(ctx, &rootcoordpb.ImportResult{
+			TaskId: req.GetTaskId(),
+			State:  commonpb.ImportState_ImportCompleted,
+		})
+		return VerifyResponse(status, err)
+	}, retry.Attempts(reportImportAttempts))
+	if err != nil {
+		log.Error("failed to report import, unable to update the import task state",
+			zap.Int64("task ID", req.GetTaskId()),
+			zap.Error(err))
+	}
+	// Remove the `isImport` states of these segments, regardless of whether the index building check succeeded, timed out, or failed.
+	s.UnsetIsImportingState(ctx, &datapb.UnsetIsImportingStateRequest{
+		SegmentIds: req.GetSegmentIds(),
+	})
+
+	return &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}, nil
+}
+
+// UnsetIsImportingState unsets the isImporting states of the given segments.
+func (s *Server) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) {
+	log.Info("unsetting isImport state of segments", zap.Int64s("segments", req.GetSegmentIds()))
+	for _, segID := range req.GetSegmentIds() {
+		s.meta.SetSegmentIsImporting(segID, false)
+	}
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
 	}, nil
diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go
index 9cadc30f2d2bc..3a8768bbdfa05 100644
--- a/internal/datacoord/session_manager.go
+++ b/internal/datacoord/session_manager.go
@@ -264,29 +264,6 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR
 	return rst
 }
 
-// AddSegment calls DataNode with ID == `nodeID` to put the segment into this node.
-func (c *SessionManager) AddSegment(ctx context.Context, nodeID int64, req *datapb.AddSegmentRequest) {
-	go c.execAddSegment(ctx, nodeID, req)
-}
-
-func (c *SessionManager) execAddSegment(ctx context.Context, nodeID int64, req *datapb.AddSegmentRequest) {
-	cli, err := c.getClient(ctx, nodeID)
-	if err != nil {
-		log.Warn("failed to get client for AddSegment", zap.Int64("DataNode ID", nodeID), zap.Error(err))
-		return
-	}
-	ctx, cancel := context.WithTimeout(ctx, addSegmentTimeout)
-	defer cancel()
-	req.Base.SourceID = Params.DataCoordCfg.GetNodeID()
-	resp, err := cli.AddSegment(ctx, req)
-	if err := VerifyResponse(resp, err); err != nil {
-		log.Warn("failed to add segment", zap.Int64("DataNode ID", nodeID), zap.Error(err))
-		return
-	}
-
-	log.Info("success to add segment", zap.Int64("DataNode ID", nodeID), zap.Any("add segment req", req))
-}
-
 func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.DataNode, error) {
 	c.sessions.RLock()
 	session, ok := c.sessions.data[nodeID]
diff --git a/internal/datanode/compactor_test.go b/internal/datanode/compactor_test.go
index 06c9a2c5114fb..3877f072d57eb 100644
--- a/internal/datanode/compactor_test.go
+++ b/internal/datanode/compactor_test.go
@@ -56,7 +56,15 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
 		_, _, _, err = task.getSegmentMeta(100)
 		assert.Error(t, err)
 
-		err = replica.addNewSegment(100, 1, 10, "a", new(internalpb.MsgPosition), nil)
+		err = replica.addSegment(addSegmentReq{
+			segType:     datapb.SegmentType_New,
+			segID:       100,
+			collID:      1,
+			partitionID: 10,
+			channelName: "a",
+			startPos:    new(internalpb.MsgPosition),
+			endPos:      nil,
+		})
 		require.NoError(t, err)
 
 		collID, partID, meta, err := task.getSegmentMeta(100)
diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go
index a40f09d6fe61c..74ea17f839b5e 100644
--- a/internal/datanode/data_node.go
+++ b/internal/datanode/data_node.go
@@ -79,6 +79,8 @@ const (
 	ConnectEtcdMaxRetryTime = 100
 )
 
+var getFlowGraphServiceAttempts = uint(50)
+
 // makes sure DataNode implements types.DataNode
 var _ types.DataNode = (*DataNode)(nil)
 
@@ -900,9 +902,17 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 		AutoIds:  make([]int64, 0),
 		RowCount: 0,
 	}
+	// func to report import state to RootCoord
 	reportFunc := func(res *rootcoordpb.ImportResult) error {
-		_, err := node.rootCoord.ReportImport(ctx, res)
-		return err
+		status, err := node.rootCoord.ReportImport(ctx, res)
+		if err != nil {
+			log.Error("failed to report import state to RootCoord", zap.Error(err))
+			return err
+		}
+		if status != nil && status.ErrorCode != commonpb.ErrorCode_Success {
+			return errors.New(status.GetReason())
+		}
+		return nil
 	}
 
 	if !node.isHealthy() {
@@ -915,7 +925,10 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 		msg := msgDataNodeIsUnhealthy(Params.DataNodeCfg.GetNodeID())
 		importResult.State = commonpb.ImportState_ImportFailed
 		importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: "failed_reason", Value: msg})
-		reportFunc(importResult)
+		reportErr := reportFunc(importResult)
+		if reportErr != nil {
+			log.Warn("failed to report import state to RootCoord", zap.Error(reportErr))
+		}
 		return &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 			Reason:    msg,
@@ -938,7 +951,9 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 		log.Warn(msg)
 		importResult.State = commonpb.ImportState_ImportFailed
 		importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: "failed_reason", Value: msg})
-		reportFunc(importResult)
+		if reportErr := reportFunc(importResult); reportErr != nil {
+			log.Warn("failed to report import state to RootCoord", zap.Error(reportErr))
+		}
 		if err != nil {
 			return &commonpb.Status{
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -955,7 +970,10 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 	if err != nil {
 		importResult.State = commonpb.ImportState_ImportFailed
 		importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: "failed_reason", Value: err.Error()})
-		reportFunc(importResult)
+		reportErr := reportFunc(importResult)
+		if reportErr != nil {
+			log.Warn("failed to report import state to RootCoord", zap.Error(reportErr))
+		}
 		return &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 			Reason:    err.Error(),
@@ -970,7 +988,10 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 	if err != nil {
 		importResult.State = commonpb.ImportState_ImportFailed
 		importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: "failed_reason", Value: err.Error()})
-		reportFunc(importResult)
+		reportErr := reportFunc(importResult)
+		if reportErr != nil {
+			log.Warn("failed to report import state to RootCoord", zap.Error(reportErr))
+		}
 		return &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 			Reason:    err.Error(),
@@ -983,8 +1004,8 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
 	return resp, nil
 }
 
-// AddSegment adds the segment to the current DataNode.
-func (node *DataNode) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) {
+// AddImportSegment adds the import segment to the current DataNode.
+func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*commonpb.Status, error) { log.Info("adding segment to DataNode flow graph", zap.Int64("segment ID", req.GetSegmentId()), zap.Int64("collection ID", req.GetCollectionId()), @@ -992,30 +1013,54 @@ func (node *DataNode) AddSegment(ctx context.Context, req *datapb.AddSegmentRequ zap.String("channel name", req.GetChannelName()), zap.Int64("# of rows", req.GetRowNum())) // Fetch the flow graph on the given v-channel. - ds, ok := node.flowgraphManager.getFlowgraphService(req.GetChannelName()) - if !ok { + var ds *dataSyncService + // Retry in case the channel hasn't been watched yet. + err := retry.Do(ctx, func() error { + var ok bool + ds, ok = node.flowgraphManager.getFlowgraphService(req.GetChannelName()) + if !ok { + return errors.New("channel not found") + } + return nil + }, retry.Attempts(getFlowGraphServiceAttempts)) + if err != nil { log.Error("channel not found in current DataNode", zap.String("channel name", req.GetChannelName()), zap.Int64("node ID", Params.DataNodeCfg.GetNodeID())) return &commonpb.Status{ // TODO: Add specific error code. ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "channel not found in current DataNode", }, nil } // Add the new segment to the replica. if !ds.replica.hasSegment(req.GetSegmentId(), true) { - log.Info("add a new segment to replica") - err := ds.replica.addNewSegment(req.GetSegmentId(), - req.GetCollectionId(), - req.GetPartitionId(), - req.GetChannelName(), - &internalpb.MsgPosition{ - ChannelName: req.GetChannelName(), - }, - &internalpb.MsgPosition{ - ChannelName: req.GetChannelName(), - }) - if err != nil { + log.Info("adding a new segment to replica", + zap.Int64("segment ID", req.GetSegmentId())) + // Add segment as a flushed segment, but set `importing` to true to add extra information of the segment. + // By 'extra information' we mean segment info while adding a `SegmentType_New` typed segment. + if err := ds.replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: req.GetSegmentId(), + collID: req.GetCollectionId(), + partitionID: req.GetPartitionId(), + channelName: req.GetChannelName(), + numOfRows: req.GetRowNum(), + statsBinLogs: req.GetStatsLog(), + startPos: &internalpb.MsgPosition{ + ChannelName: req.GetChannelName(), + MsgID: req.GetDmlPositionId(), + Timestamp: req.GetBase().GetTimestamp(), + }, + endPos: &internalpb.MsgPosition{ + ChannelName: req.GetChannelName(), + MsgID: req.GetDmlPositionId(), + Timestamp: req.GetBase().GetTimestamp(), + }, + recoverTs: req.GetBase().GetTimestamp(), + importing: true, + }); err != nil { log.Error("failed to add segment to flow graph", zap.Error(err)) return &commonpb.Status{ @@ -1025,22 +1070,21 @@ func (node *DataNode) AddSegment(ctx context.Context, req *datapb.AddSegmentRequ }, nil } } - // Update # of rows of the given segment. 
- ds.replica.updateStatistics(req.GetSegmentId(), req.GetRowNum()) + ds.flushingSegCache.Remove(req.GetSegmentId()) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil } func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, schema *schemapb.CollectionSchema, ts Timestamp) importutil.ImportFlushFunc { - return func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - if shardNum >= len(req.GetImportTask().GetChannelNames()) { - log.Error("import task returns invalid shard number", - zap.Int("shard num", shardNum), + return func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + if shardID >= len(req.GetImportTask().GetChannelNames()) { + log.Error("import task returns invalid shard ID", + zap.Int("shard ID", shardID), zap.Int("# of channels", len(req.GetImportTask().GetChannelNames())), zap.Strings("channel names", req.GetImportTask().GetChannelNames()), ) - return fmt.Errorf("syncSegmentID Failed: invalid shard number %d", shardNum) + return fmt.Errorf("syncSegmentID Failed: invalid shard ID %d", shardID) } tr := timerecord.NewTimeRecorder("import callback function") @@ -1055,10 +1099,12 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root } // ask DataCoord to alloc a new segment - log.Info("import task flush segment", zap.Any("ChannelNames", req.ImportTask.ChannelNames), zap.Int("shardNum", shardNum)) + log.Info("import task flush segment", + zap.Any("channel names", req.GetImportTask().GetChannelNames()), + zap.Int("shard ID", shardID)) segReqs := []*datapb.SegmentIDRequest{ { - ChannelName: req.ImportTask.ChannelNames[shardNum], + ChannelName: req.ImportTask.ChannelNames[shardID], Count: uint32(rowNum), CollectionID: req.GetImportTask().GetCollectionId(), PartitionID: req.GetImportTask().GetPartitionId(), @@ -1104,6 +1150,7 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root data := BufferData{buffer: &InsertData{ Data: fields, }} + data.updateSize(int64(rowNum)) meta := &etcdpb.CollectionMeta{ ID: req.GetImportTask().GetCollectionId(), Schema: schema, @@ -1140,8 +1187,8 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root kvs[key] = blob.Value[:] field2Insert[fieldID] = &datapb.Binlog{ EntriesNum: data.size, - TimestampFrom: 0, //TODO - TimestampTo: 0, //TODO, + TimestampFrom: ts, + TimestampTo: ts, LogPath: key, LogSize: int64(len(blob.Value)), } @@ -1165,9 +1212,9 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root key := path.Join(node.chunkManager.RootPath(), common.SegmentStatslogPath, k) kvs[key] = blob.Value field2Stats[fieldID] = &datapb.Binlog{ - EntriesNum: 0, - TimestampFrom: 0, //TODO - TimestampTo: 0, //TODO, + EntriesNum: data.size, + TimestampFrom: ts, + TimestampTo: ts, LogPath: key, LogSize: int64(len(blob.Value)), } @@ -1189,51 +1236,85 @@ func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, res *root fieldStats = append(fieldStats, &datapb.FieldBinlog{FieldID: k, Binlogs: []*datapb.Binlog{v}}) } - log.Info("now adding segment to the correct DataNode flow graph") - // Ask DataCoord to add segment to the corresponding DataNode flow graph. 
- node.dataCoord.AddSegment(context.Background(), &datapb.AddSegmentRequest{ - Base: &commonpb.MsgBase{ - SourceID: Params.DataNodeCfg.GetNodeID(), - }, - SegmentId: segmentID, - ChannelName: segReqs[0].GetChannelName(), - CollectionId: req.GetImportTask().GetCollectionId(), - PartitionId: req.GetImportTask().GetPartitionId(), - RowNum: int64(rowNum), - }) - - binlogReq := &datapb.SaveBinlogPathsRequest{ - Base: &commonpb.MsgBase{ - MsgType: 0, //TODO msg type - MsgID: 0, //TODO msg id - Timestamp: 0, //TODO time stamp - SourceID: Params.DataNodeCfg.GetNodeID(), - }, - SegmentID: segmentID, - CollectionID: req.GetImportTask().GetCollectionId(), - Field2BinlogPaths: fieldInsert, - Field2StatslogPaths: fieldStats, - Importing: true, + // Fetch the flow graph on the given v-channel. + var ds *dataSyncService + // Retry in case the channel hasn't been watched yet. + err = retry.Do(context.Background(), func() error { + var ok bool + ds, ok = node.flowgraphManager.getFlowgraphService(segReqs[0].GetChannelName()) + if !ok { + return errors.New("channel not found") + } + return nil + }, retry.Attempts(getFlowGraphServiceAttempts)) + if err != nil { + log.Error("channel not found in current DataNode", + zap.Error(err), + zap.String("channel name", segReqs[0].GetChannelName()), + zap.Int64("node ID", Params.DataNodeCfg.GetNodeID())) + return err } - + // Get the current dml channel position ID, that will be used in segments start positions and end positions. + posID, err := ds.getDmlChannelPositionByBroadcast(context.Background(), segReqs[0].GetChannelName(), req.GetBase().GetTimestamp()) + if err != nil { + return errors.New("failed to get channel position") + } + log.Info("adding segment to the correct DataNode flow graph and saving binlog paths") err = retry.Do(context.Background(), func() error { - rsp, err := node.dataCoord.SaveBinlogPaths(context.Background(), binlogReq) - // should be network issue, return error and retry + // Ask DataCoord to save binlog path and add segment to the corresponding DataNode flow graph. + resp, err := node.dataCoord.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{ + Base: &commonpb.MsgBase{ + SourceID: Params.DataNodeCfg.GetNodeID(), + // Pass current timestamp downstream. + Timestamp: ts, + }, + SegmentId: segmentID, + ChannelName: segReqs[0].GetChannelName(), + CollectionId: req.GetImportTask().GetCollectionId(), + PartitionId: req.GetImportTask().GetPartitionId(), + RowNum: int64(rowNum), + // Pass the DML position ID downstream. + DmlPositionId: posID, + SaveBinlogPathReq: &datapb.SaveBinlogPathsRequest{ + Base: &commonpb.MsgBase{ + MsgType: 0, + MsgID: 0, + Timestamp: 0, + SourceID: Params.DataNodeCfg.GetNodeID(), + }, + SegmentID: segmentID, + CollectionID: req.GetImportTask().GetCollectionId(), + Field2BinlogPaths: fieldInsert, + Field2StatslogPaths: fieldStats, + // Set start positions of a SaveBinlogPathRequest explicitly. + StartPositions: []*datapb.SegmentStartPosition{ + { + StartPosition: &internalpb.MsgPosition{ + ChannelName: segReqs[0].GetChannelName(), + MsgID: posID, + Timestamp: ts, + }, + SegmentID: segmentID, + }, + }, + Importing: true, + }, + }) + // Only retrying when DataCoord is unhealthy or err != nil, otherwise return immediately. 
if err != nil { return fmt.Errorf(err.Error()) } - - // TODO should retry only when datacoord status is unhealthy - if rsp.ErrorCode != commonpb.ErrorCode_Success { - return fmt.Errorf("data service save bin log path failed, reason = %s", rsp.Reason) + if resp.ErrorCode != commonpb.ErrorCode_Success && resp.ErrorCode != commonpb.ErrorCode_DataCoordNA { + return retry.Unrecoverable(fmt.Errorf("failed to save import segment, reason = %s", resp.Reason)) + } else if resp.ErrorCode == commonpb.ErrorCode_DataCoordNA { + return fmt.Errorf("failed to save import segment: %s", resp.GetReason()) } return nil }) if err != nil { - log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + log.Warn("failed to save import segment", zap.Error(err)) return err } - log.Info("segment imported and persisted", zap.Int64("segmentID", segmentID)) res.Segments = append(res.Segments, segmentID) res.RowCount += int64(rowNum) diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index 75d0ba1539428..206bc3a439be3 100644 --- a/internal/datanode/data_node_test.go +++ b/internal/datanode/data_node_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus/internal/common" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" @@ -36,9 +37,6 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/api/commonpb" - "github.com/milvus-io/milvus/api/milvuspb" - "github.com/milvus-io/milvus/api/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" @@ -46,8 +44,6 @@ import ( "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/sessionutil" - - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -217,7 +213,15 @@ func TestDataNode(t *testing.T) { fgservice, ok := node1.flowgraphManager.getFlowgraphService(dmChannelName) assert.True(t, ok) - err = fgservice.replica.addNewSegment(0, 1, 1, dmChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = fgservice.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 0, + collID: 1, + partitionID: 1, + channelName: dmChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) assert.Nil(t, err) req := &datapb.FlushSegmentsRequest{ @@ -425,6 +429,28 @@ func TestDataNode(t *testing.T) { ] }`) + chName1 := "fake-by-dev-rootcoord-dml-testimport-1" + chName2 := "fake-by-dev-rootcoord-dml-testimport-2" + err := node.flowgraphManager.addAndStart(node, &datapb.VchannelInfo{ + CollectionID: 100, + ChannelName: chName1, + UnflushedSegmentIds: []int64{}, + FlushedSegmentIds: []int64{}, + }) + require.Nil(t, err) + err = node.flowgraphManager.addAndStart(node, &datapb.VchannelInfo{ + CollectionID: 100, + ChannelName: chName2, + UnflushedSegmentIds: []int64{}, + FlushedSegmentIds: []int64{}, + }) + require.Nil(t, err) + + _, ok := node.flowgraphManager.getFlowgraphService(chName1) + assert.True(t, ok) + _, ok = node.flowgraphManager.getFlowgraphService(chName2) + assert.True(t, ok) + filePath := "import/rows_1.json" err = node.chunkManager.Write(filePath, content) assert.NoError(t, err) @@ -432,11 +458,31 @@ func 
TestDataNode(t *testing.T) { ImportTask: &datapb.ImportTask{ CollectionId: 100, PartitionId: 100, - ChannelNames: []string{"ch1", "ch2"}, + ChannelNames: []string{chName1, chName2}, Files: []string{filePath}, RowBased: true, }, } + node.rootCoord.(*RootCoordFactory).ReportImportErr = true + _, err = node.Import(context.WithValue(ctx, ctxKey{}, ""), req) + assert.NoError(t, err) + node.rootCoord.(*RootCoordFactory).ReportImportErr = false + + node.rootCoord.(*RootCoordFactory).ReportImportNotSuccess = true + _, err = node.Import(context.WithValue(ctx, ctxKey{}, ""), req) + assert.NoError(t, err) + node.rootCoord.(*RootCoordFactory).ReportImportNotSuccess = false + + node.dataCoord.(*DataCoordFactory).AddSegmentError = true + _, err = node.Import(context.WithValue(ctx, ctxKey{}, ""), req) + assert.NoError(t, err) + node.dataCoord.(*DataCoordFactory).AddSegmentError = false + + node.dataCoord.(*DataCoordFactory).AddSegmentNotSuccess = true + _, err = node.Import(context.WithValue(ctx, ctxKey{}, ""), req) + assert.NoError(t, err) + node.dataCoord.(*DataCoordFactory).AddSegmentNotSuccess = false + stat, err := node.Import(context.WithValue(ctx, ctxKey{}, ""), req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, stat.GetErrorCode()) @@ -652,7 +698,7 @@ func TestDataNode_AddSegment(t *testing.T) { _, ok = node.flowgraphManager.getFlowgraphService(chName2) assert.True(t, ok) - stat, err := node.AddSegment(context.WithValue(ctx, ctxKey{}, ""), &datapb.AddSegmentRequest{ + stat, err := node.AddImportSegment(context.WithValue(ctx, ctxKey{}, ""), &datapb.AddImportSegmentRequest{ SegmentId: 100, CollectionId: 100, PartitionId: 100, @@ -663,7 +709,8 @@ func TestDataNode_AddSegment(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, stat.GetErrorCode()) assert.Equal(t, "", stat.GetReason()) - stat, err = node.AddSegment(context.WithValue(ctx, ctxKey{}, ""), &datapb.AddSegmentRequest{ + getFlowGraphServiceAttempts = 3 + stat, err = node.AddImportSegment(context.WithValue(ctx, ctxKey{}, ""), &datapb.AddImportSegmentRequest{ SegmentId: 100, CollectionId: 100, PartitionId: 100, @@ -978,11 +1025,35 @@ func TestDataNode_ResendSegmentStats(t *testing.T) { fgService, ok := node.flowgraphManager.getFlowgraphService(dmChannelName) assert.True(t, ok) - err = fgService.replica.addNewSegment(0, 1, 1, dmChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = fgService.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 0, + collID: 1, + partitionID: 1, + channelName: dmChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) assert.Nil(t, err) - err = fgService.replica.addNewSegment(1, 1, 2, dmChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = fgService.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: 1, + partitionID: 2, + channelName: dmChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) assert.Nil(t, err) - err = fgService.replica.addNewSegment(2, 1, 3, dmChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = fgService.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 2, + collID: 1, + partitionID: 3, + channelName: dmChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) assert.Nil(t, err) req := &datapb.ResendSegmentStatsRequest{ diff --git a/internal/datanode/data_sync_service.go 
b/internal/datanode/data_sync_service.go index 2d5a742b88a17..c863e73097092 100644 --- a/internal/datanode/data_sync_service.go +++ b/internal/datanode/data_sync_service.go @@ -28,10 +28,12 @@ import ( "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/concurrency" "github.com/milvus-io/milvus/internal/util/flowgraph" + "github.com/milvus-io/milvus/internal/util/funcutil" ) // dataSyncService controls a flowgraph for a specific collection @@ -210,8 +212,16 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro // avoid closure capture iteration variable segment := us future := dsService.ioPool.Submit(func() (interface{}, error) { - if err := dsService.replica.addNormalSegment(segment.GetID(), segment.GetCollectionID(), segment.GetPartitionID(), segment.GetInsertChannel(), - segment.GetNumOfRows(), segment.GetStatslogs(), cp, vchanInfo.GetSeekPosition().GetTimestamp()); err != nil { + if err := dsService.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: segment.GetID(), + collID: segment.CollectionID, + partitionID: segment.PartitionID, + channelName: segment.GetInsertChannel(), + numOfRows: segment.GetNumOfRows(), + statsBinLogs: segment.Statslogs, + cp: cp, + recoverTs: vchanInfo.GetSeekPosition().GetTimestamp()}); err != nil { return nil, err } return nil, nil @@ -238,8 +248,16 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo) erro // avoid closure capture iteration variable segment := fs future := dsService.ioPool.Submit(func() (interface{}, error) { - if err := dsService.replica.addFlushedSegment(segment.GetID(), segment.GetCollectionID(), segment.GetPartitionID(), segment.GetInsertChannel(), - segment.GetNumOfRows(), segment.GetStatslogs(), vchanInfo.GetSeekPosition().GetTimestamp()); err != nil { + if err := dsService.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: segment.GetID(), + collID: segment.CollectionID, + partitionID: segment.PartitionID, + channelName: segment.GetInsertChannel(), + numOfRows: segment.GetNumOfRows(), + statsBinLogs: segment.Statslogs, + recoverTs: vchanInfo.GetSeekPosition().GetTimestamp(), + }); err != nil { return nil, err } return nil, nil @@ -372,3 +390,56 @@ func (dsService *dataSyncService) getSegmentInfos(segmentIDs []int64) ([]*datapb } return infoResp.Infos, nil } + +func (dsService *dataSyncService) getDmlChannelPositionByBroadcast(ctx context.Context, channelName string, ts uint64) ([]byte, error) { + msgPack := msgstream.MsgPack{} + baseMsg := msgstream.BaseMsg{ + Ctx: ctx, + BeginTimestamp: ts, + EndTimestamp: ts, + HashValues: []uint32{0}, + } + msg := &msgstream.InsertMsg{ + BaseMsg: baseMsg, + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_TimeTick, + MsgID: 0, + Timestamp: ts, + SourceID: Params.DataNodeCfg.GetNodeID(), + }, + }, + } + msgPack.Msgs = append(msgPack.Msgs, msg) + + pChannelName := funcutil.ToPhysicalChannel(channelName) + log.Info("ddNode convert vChannel to pChannel", + zap.String("vChannelName", channelName), + zap.String("pChannelName", pChannelName), + ) + + dmlStream, err := dsService.msFactory.NewMsgStream(ctx) + if err != nil { + return nil, err + } + 
dmlStream.SetRepackFunc(msgstream.DefaultRepackFunc) + dmlStream.AsProducer([]string{pChannelName}) + dmlStream.Start() + defer dmlStream.Close() + + result := make(map[string][]byte) + + ids, err := dmlStream.BroadcastMark(&msgPack) + if err != nil { + log.Error("BroadcastMark failed", zap.Error(err), zap.String("channelName", channelName)) + return nil, err + } + for cn, idList := range ids { + // idList should have length 1, just flat by iteration + for _, id := range idList { + result[cn] = id.Serialize() + } + } + + return result[pChannelName], nil +} diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go index ec78dd940f65e..b0f1bad63fbf5 100644 --- a/internal/datanode/data_sync_service_test.go +++ b/internal/datanode/data_sync_service_test.go @@ -139,6 +139,11 @@ func TestDataSyncService_newDataSyncService(te *testing.T) { 0, 0, "", 0, 0, 0, "", 0, "replica nil"}, + {true, false, &mockMsgStreamFactory{true, true}, + 1, "by-dev-rootcoord-dml-test_v1", + 1, 1, "by-dev-rootcoord-dml-test_v1", 0, + 1, 2, "by-dev-rootcoord-dml-test_v1", 0, + "add normal segments"}, } cm := storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) defer cm.RemoveWithPrefix("") @@ -419,13 +424,42 @@ func TestClearGlobalFlushingCache(t *testing.T) { flushingSegCache: cache, } - err = replica.addNewSegment(1, 1, 1, "", &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: 1, + partitionID: 1, + channelName: "", + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}}) assert.NoError(t, err) - err = replica.addFlushedSegment(2, 1, 1, "", 0, nil, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 2, + collID: 1, + partitionID: 1, + channelName: "", + numOfRows: 0, + statsBinLogs: nil, + recoverTs: 0, + }) assert.NoError(t, err) - err = replica.addNormalSegment(3, 1, 1, "", 0, nil, nil, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 3, + collID: 1, + partitionID: 1, + channelName: "", + numOfRows: 0, + statsBinLogs: nil, + cp: nil, + recoverTs: 0, + }) assert.NoError(t, err) cache.checkOrCache(1) @@ -439,3 +473,28 @@ func TestClearGlobalFlushingCache(t *testing.T) { assert.False(t, cache.checkIfCached(3)) assert.True(t, cache.checkIfCached(4)) } + +func TestGetDmlChannelPositionByBroadcast(t *testing.T) { + delay := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), delay) + defer cancel() + factory := dependency.NewDefaultFactory(true) + + dataCoord := &DataCoordFactory{} + dsService := &dataSyncService{ + dataCoord: dataCoord, + msFactory: factory, + } + + dmlChannelName := "fake-by-dev-rootcoord-dml-channel_12345v0" + + insertStream, _ := factory.NewMsgStream(ctx) + insertStream.AsProducer([]string{dmlChannelName}) + + var insertMsgStream = insertStream + insertMsgStream.Start() + + id, err := dsService.getDmlChannelPositionByBroadcast(ctx, dmlChannelName, 0) + assert.NoError(t, err) + assert.NotNil(t, id) +} diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go index 5fc652d599fb8..7c75404b25926 100644 --- a/internal/datanode/flow_graph_insert_buffer_node.go +++ b/internal/datanode/flow_graph_insert_buffer_node.go @@ -476,8 +476,16 @@ func (ibNode *insertBufferNode) updateSegStatesInReplica(insertMsgs 
[]*msgstream partitionID := msg.GetPartitionID() if !ibNode.replica.hasSegment(currentSegID, true) { - err = ibNode.replica.addNewSegment(currentSegID, collID, partitionID, msg.GetShardName(), - startPos, endPos) + err = ibNode.replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: currentSegID, + collID: collID, + partitionID: partitionID, + channelName: msg.GetShardName(), + startPos: startPos, + endPos: endPos, + }) if err != nil { log.Error("add segment wrong", zap.Int64("segID", currentSegID), diff --git a/internal/datanode/flow_graph_insert_buffer_node_test.go b/internal/datanode/flow_graph_insert_buffer_node_test.go index 5c491e4fdf4c5..29f6cdbf4f29f 100644 --- a/internal/datanode/flow_graph_insert_buffer_node_test.go +++ b/internal/datanode/flow_graph_insert_buffer_node_test.go @@ -24,16 +24,6 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/util/retry" - - "github.com/milvus-io/milvus/internal/util/typeutil" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/flowgraph" @@ -82,8 +72,16 @@ func TestFlowGraphInsertBufferNodeCreate(t *testing.T) { replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) - - err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: collMeta.ID, + partitionID: 0, + channelName: insertChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) require.NoError(t, err) factory := dependency.NewDefaultFactory(true) @@ -170,7 +168,16 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) { replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) - err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: collMeta.ID, + partitionID: 0, + channelName: insertChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) require.NoError(t, err) factory := dependency.NewDefaultFactory(true) @@ -922,8 +929,16 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) - - err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{Timestamp: 101}) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: collMeta.ID, + partitionID: 0, + channelName: insertChannelName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{Timestamp: 101}, + }) require.NoError(t, err) factory := dependency.NewDefaultFactory(true) diff --git a/internal/datanode/flow_graph_manager_test.go b/internal/datanode/flow_graph_manager_test.go index 99c773a20a006..bffa846a569cf 100644 --- a/internal/datanode/flow_graph_manager_test.go +++ b/internal/datanode/flow_graph_manager_test.go @@ -95,7 +95,15 @@ func TestFlowGraphManager(t *testing.T) { fg, ok := fm.getFlowgraphService(vchanName) require.True(t, ok) - err = 
fg.replica.addNewSegment(100, 1, 10, vchanName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{}) + err = fg.replica.addSegment(addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 100, + collID: 1, + partitionID: 10, + channelName: vchanName, + startPos: &internalpb.MsgPosition{}, + endPos: &internalpb.MsgPosition{}, + }) require.NoError(t, err) tests := []struct { diff --git a/internal/datanode/flush_manager_test.go b/internal/datanode/flush_manager_test.go index ded516629a317..9eef954e0f20f 100644 --- a/internal/datanode/flush_manager_test.go +++ b/internal/datanode/flush_manager_test.go @@ -614,11 +614,17 @@ func TestDropVirtualChannelFunc(t *testing.T) { } dropFunc := dropVirtualChannelFunc(dsService, retry.Attempts(1)) t.Run("normal run", func(t *testing.T) { - replica.addNewSegment(2, 1, 10, "vchan_01", &internalpb.MsgPosition{ - ChannelName: "vchan_01", - MsgID: []byte{1, 2, 3}, - Timestamp: 10, - }, nil) + replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 2, + collID: 1, + partitionID: 10, + channelName: "vchan_01", startPos: &internalpb.MsgPosition{ + ChannelName: "vchan_01", + MsgID: []byte{1, 2, 3}, + Timestamp: 10, + }, endPos: nil}) assert.NotPanics(t, func() { dropFunc([]*segmentFlushPack{ { diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index 8e75772d37346..e4afe1751b55f 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -52,6 +52,26 @@ import ( const ctxTimeInMillisecond = 5000 const debug = false +// As used in data_sync_service_test.go +var segID2SegInfo = map[int64]*datapb.SegmentInfo{ + 1: { + ID: 1, + CollectionID: 1, + PartitionID: 1, + InsertChannel: "by-dev-rootcoord-dml-test_v1", + }, + 2: { + ID: 2, + CollectionID: 1, + InsertChannel: "by-dev-rootcoord-dml-test_v1", + }, + 3: { + ID: 3, + CollectionID: 1, + InsertChannel: "by-dev-rootcoord-dml-test_v1", + }, +} + var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {} func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode { @@ -159,6 +179,9 @@ type RootCoordFactory struct { collectionName string collectionID UniqueID pkType schemapb.DataType + + ReportImportErr bool + ReportImportNotSuccess bool } type DataCoordFactory struct { @@ -175,6 +198,9 @@ type DataCoordFactory struct { GetSegmentInfosError bool GetSegmentInfosNotSuccess bool + + AddSegmentError bool + AddSegmentNotSuccess bool } func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { @@ -225,7 +251,19 @@ func (ds *DataCoordFactory) UpdateSegmentStatistics(ctx context.Context, req *da }, nil } -func (ds *DataCoordFactory) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (ds *DataCoordFactory) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil +} + +func (ds *DataCoordFactory) CompleteBulkLoad(context.Context, *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil +} + +func (ds *DataCoordFactory) UnsetIsImportingState(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil @@ -245,9 +283,13 @@ func (ds *DataCoordFactory) GetSegmentInfo(ctx context.Context, req 
*datapb.GetS } var segmentInfos []*datapb.SegmentInfo for _, segmentID := range req.SegmentIDs { - segmentInfos = append(segmentInfos, &datapb.SegmentInfo{ - ID: segmentID, - }) + if segInfo, ok := segID2SegInfo[segmentID]; ok { + segmentInfos = append(segmentInfos, segInfo) + } else { + segmentInfos = append(segmentInfos, &datapb.SegmentInfo{ + ID: segmentID, + }) + } } return &datapb.GetSegmentInfoResponse{ Status: &commonpb.Status{ @@ -974,6 +1016,16 @@ func (m *RootCoordFactory) ReportImport(ctx context.Context, req *rootcoordpb.Im return nil, fmt.Errorf("injected error") } } + if m.ReportImportErr { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, fmt.Errorf("mock error") + } + if m.ReportImportNotSuccess { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil diff --git a/internal/datanode/segment_replica.go b/internal/datanode/segment_replica.go index 97343800cd716..41050bd4b2435 100644 --- a/internal/datanode/segment_replica.go +++ b/internal/datanode/segment_replica.go @@ -60,11 +60,9 @@ type Replica interface { listAllSegmentIDs() []UniqueID listNotFlushedSegmentIDs() []UniqueID + addSegment(req addSegmentReq) error listPartitionSegments(partID UniqueID) []UniqueID - addNewSegment(segID, collID, partitionID UniqueID, channelName string, startPos, endPos *internalpb.MsgPosition) error - addNormalSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, statsBinlog []*datapb.FieldBinlog, cp *segmentCheckPoint, recoverTs Timestamp) error filterSegments(channelName string, partitionID UniqueID) []*Segment - addFlushedSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, statsBinlog []*datapb.FieldBinlog, recoverTs Timestamp) error listNewSegmentsStartPositions() []*datapb.SegmentStartPosition listSegmentsCheckPoints() map[UniqueID]segmentCheckPoint updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition) @@ -120,6 +118,18 @@ type SegmentReplica struct { chunkManager storage.ChunkManager } +type addSegmentReq struct { + segType datapb.SegmentType + segID, collID, partitionID UniqueID + channelName string + numOfRows int64 + startPos, endPos *internalpb.MsgPosition + statsBinLogs []*datapb.FieldBinlog + cp *segmentCheckPoint + recoverTs Timestamp + importing bool +} + func (s *Segment) updatePk(pk primaryKey) error { if s.minPK == nil { s.minPK = pk @@ -307,49 +317,76 @@ func (replica *SegmentReplica) initSegmentBloomFilter(s *Segment) error { return nil } -// addNewSegment adds a *New* and *NotFlushed* new segment. Before add, please make sure there's no -// such segment by `hasSegment` -func (replica *SegmentReplica) addNewSegment(segID, collID, partitionID UniqueID, channelName string, - startPos, endPos *internalpb.MsgPosition) error { - - log := log.With( - zap.Int64("segment ID", segID), - zap.Int64("collection ID", collID), - zap.Int64("partition ID", partitionID), - zap.String("channel name", channelName)) - - if collID != replica.collectionID { - log.Warn("Mismatch collection", - zap.Int64("expected collectionID", replica.collectionID)) - return fmt.Errorf("mismatch collection, ID=%d", collID) - } - - log.Info("Add new segment") - +// addSegment adds the segment to current replica. Segments can be added as *new*, *normal* or *flushed*. +// Make sure to verify `replica.hasSegment(segID)` == false before calling `replica.addSegment()`. 
+func (replica *SegmentReplica) addSegment(req addSegmentReq) error { + if req.collID != replica.collectionID { + log.Warn("collection mismatch", + zap.Int64("current collection ID", req.collID), + zap.Int64("expected collection ID", replica.collectionID)) + return fmt.Errorf("mismatch collection, ID=%d", req.collID) + } + log.Info("adding segment", + zap.String("segment type", req.segType.String()), + zap.Int64("segment ID", req.segID), + zap.Int64("collection ID", req.collID), + zap.Int64("partition ID", req.partitionID), + zap.String("channel name", req.channelName), + zap.Any("start position", req.startPos), + zap.Any("end position", req.endPos), + zap.Any("checkpoints", req.cp), + zap.Uint64("recover ts", req.recoverTs), + zap.Bool("importing", req.importing), + ) seg := &Segment{ - collectionID: collID, - partitionID: partitionID, - segmentID: segID, - channelName: channelName, - - checkPoint: segmentCheckPoint{0, *startPos}, - startPos: startPos, - endPos: endPos, + collectionID: req.collID, + partitionID: req.partitionID, + segmentID: req.segID, + channelName: req.channelName, + numRows: req.numOfRows, // 0 if segType == NEW + } + if req.importing || req.segType == datapb.SegmentType_New { + seg.checkPoint = segmentCheckPoint{0, *req.startPos} + seg.startPos = req.startPos + seg.endPos = req.endPos + } + if req.segType == datapb.SegmentType_Normal { + if req.cp != nil { + seg.checkPoint = *req.cp + seg.endPos = &req.cp.pos + } } - - err := replica.initSegmentBloomFilter(seg) + // Set up bloom filter. + err := replica.initPKBloomFilter(seg, req.statsBinLogs, req.recoverTs) if err != nil { - log.Warn("failed to addNewSegment, init segment bf returns error", zap.Error(err)) + log.Error("failed to init bloom filter", + zap.Int64("segment ID", req.segID), + zap.Error(err)) return err } - - seg.isNew.Store(true) - seg.isFlushed.Store(false) - + // Please ignore `isNew` and `isFlushed` as they are for debugging only. + if req.segType == datapb.SegmentType_New { + seg.isNew.Store(true) + } else { + seg.isNew.Store(false) + } + if req.segType == datapb.SegmentType_Flushed { + seg.isFlushed.Store(true) + } else { + seg.isFlushed.Store(false) + } replica.segMu.Lock() - defer replica.segMu.Unlock() - replica.newSegments[segID] = seg - metrics.DataNodeNumUnflushedSegments.WithLabelValues(fmt.Sprint(Params.DataNodeCfg.GetNodeID())).Inc() + if req.segType == datapb.SegmentType_New { + replica.newSegments[req.segID] = seg + } else if req.segType == datapb.SegmentType_Normal { + replica.normalSegments[req.segID] = seg + } else if req.segType == datapb.SegmentType_Flushed { + replica.flushedSegments[req.segID] = seg + } + replica.segMu.Unlock() + if req.segType == datapb.SegmentType_New || req.segType == datapb.SegmentType_Normal { + metrics.DataNodeNumUnflushedSegments.WithLabelValues(fmt.Sprint(Params.DataNodeCfg.GetNodeID())).Inc() + } return nil } @@ -394,92 +431,6 @@ func (replica *SegmentReplica) filterSegments(channelName string, partitionID Un return results } -// addNormalSegment adds a *NotNew* and *NotFlushed* segment. 
Before add, please make sure there's no -// such segment by `hasSegment` -func (replica *SegmentReplica) addNormalSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, statsBinlogs []*datapb.FieldBinlog, cp *segmentCheckPoint, recoverTs Timestamp) error { - log := log.With( - zap.Int64("segment ID", segID), - zap.Int64("collection ID", collID), - zap.Int64("partition ID", partitionID), - zap.String("channel name", channelName)) - - if collID != replica.collectionID { - log.Warn("Mismatch collection", - zap.Int64("expected collectionID", replica.collectionID)) - return fmt.Errorf("mismatch collection, ID=%d", collID) - } - - log.Info("Add Normal segment") - - seg := &Segment{ - collectionID: collID, - partitionID: partitionID, - segmentID: segID, - channelName: channelName, - numRows: numOfRows, - } - - if cp != nil { - seg.checkPoint = *cp - seg.endPos = &cp.pos - } - err := replica.initPKBloomFilter(seg, statsBinlogs, recoverTs) - if err != nil { - return err - } - - seg.isNew.Store(false) - seg.isFlushed.Store(false) - - replica.segMu.Lock() - replica.normalSegments[segID] = seg - replica.segMu.Unlock() - metrics.DataNodeNumUnflushedSegments.WithLabelValues(fmt.Sprint(Params.DataNodeCfg.GetNodeID())).Inc() - - return nil -} - -// addFlushedSegment adds a *Flushed* segment. Before add, please make sure there's no -// such segment by `hasSegment` -func (replica *SegmentReplica) addFlushedSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, statsBinlogs []*datapb.FieldBinlog, recoverTs Timestamp) error { - - log := log.With( - zap.Int64("segment ID", segID), - zap.Int64("collection ID", collID), - zap.Int64("partition ID", partitionID), - zap.String("channel name", channelName)) - - if collID != replica.collectionID { - log.Warn("Mismatch collection", - zap.Int64("expected collectionID", replica.collectionID)) - return fmt.Errorf("mismatch collection, ID=%d", collID) - } - - log.Info("Add Flushed segment") - - seg := &Segment{ - collectionID: collID, - partitionID: partitionID, - segmentID: segID, - channelName: channelName, - numRows: numOfRows, - } - - err := replica.initPKBloomFilter(seg, statsBinlogs, recoverTs) - if err != nil { - return err - } - - seg.isNew.Store(false) - seg.isFlushed.Store(true) - - replica.segMu.Lock() - replica.flushedSegments[segID] = seg - replica.segMu.Unlock() - - return nil -} - func (replica *SegmentReplica) initPKBloomFilter(s *Segment, statsBinlogs []*datapb.FieldBinlog, ts Timestamp) error { log := log.With(zap.Int64("segmentID", s.segmentID)) log.Info("begin to init pk bloom filter", zap.Int("stats bin logs", len(statsBinlogs))) @@ -745,6 +696,8 @@ func (replica *SegmentReplica) getCollectionSchema(collID UniqueID, ts Timestamp return nil, fmt.Errorf("not supported collection %v", collID) } + replica.segMu.Lock() + defer replica.segMu.Unlock() if replica.collSchema == nil { sch, err := replica.metaService.getCollectionSchema(context.Background(), collID, ts) if err != nil { diff --git a/internal/datanode/segment_replica_test.go b/internal/datanode/segment_replica_test.go index 4447dc4f2c8c4..d3988a399ce99 100644 --- a/internal/datanode/segment_replica_test.go +++ b/internal/datanode/segment_replica_test.go @@ -185,8 +185,16 @@ func TestSegmentReplica(t *testing.T) { t.Run("Test coll mot match", func(t *testing.T) { replica, err := newReplica(context.Background(), rc, cm, collID) assert.Nil(t, err) - - err = replica.addNewSegment(1, collID+1, 0, "", nil, nil) + err = replica.addSegment( + 
addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: collID + 1, + partitionID: 0, + channelName: "", + startPos: nil, + endPos: nil, + }) assert.NotNil(t, err) }) @@ -322,8 +330,16 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { sr, err := newReplica(context.Background(), rc, cm, test.replicaCollID) assert.Nil(t, err) require.False(t, sr.hasSegment(test.inSegID, true)) - err = sr.addNewSegment(test.inSegID, - test.inCollID, 1, "", test.instartPos, &internalpb.MsgPosition{}) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: test.inSegID, + collID: test.inCollID, + partitionID: 1, + channelName: "", + startPos: test.instartPos, + endPos: &internalpb.MsgPosition{}, + }) if test.isValidCase { assert.NoError(t, err) assert.True(t, sr.hasSegment(test.inSegID, true)) @@ -358,7 +374,18 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { sr, err := newReplica(context.Background(), rc, &mockDataCM{}, test.replicaCollID) assert.Nil(t, err) require.False(t, sr.hasSegment(test.inSegID, true)) - err = sr.addNormalSegment(test.inSegID, test.inCollID, 1, "", 0, []*datapb.FieldBinlog{getSimpleFieldBinlog()}, &segmentCheckPoint{}, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: test.inSegID, + collID: test.inCollID, + partitionID: 1, + channelName: "", + numOfRows: 0, + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: &segmentCheckPoint{}, + recoverTs: 0, + }) if test.isValidCase { assert.NoError(t, err) assert.True(t, sr.hasSegment(test.inSegID, true)) @@ -378,7 +405,18 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { segID := int64(101) require.False(t, sr.hasSegment(segID, true)) assert.NotPanics(t, func() { - err = sr.addNormalSegment(segID, 1, 10, "empty_dml_chan", 0, []*datapb.FieldBinlog{}, nil, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: segID, + collID: 1, + partitionID: 10, + channelName: "empty_dml_chan", + numOfRows: 0, + statsBinLogs: []*datapb.FieldBinlog{}, + cp: nil, + recoverTs: 0, + }) assert.NoError(t, err) }) }) @@ -617,9 +655,30 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} - err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(10), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: cp, + recoverTs: 0, + }) assert.NotNil(t, err) - err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(0), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + recoverTs: 0, + }) assert.NotNil(t, err) }) @@ -630,9 +689,30 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} - err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 1, + 
collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(10), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: cp, + recoverTs: 0, + }) assert.NotNil(t, err) - err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(0), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + recoverTs: 0, + }) assert.NotNil(t, err) }) @@ -643,9 +723,30 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} - err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(10), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: cp, + recoverTs: 0, + }) assert.NotNil(t, err) - err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, 0) + err = sr.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(0), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + recoverTs: 0, + }) assert.NotNil(t, err) }) @@ -698,7 +799,16 @@ func TestInnerFunctionSegment(t *testing.T) { startPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(100)} endPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(200)} - err = replica.addNewSegment(0, 1, 2, "insert-01", startPos, endPos) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 0, + collID: 1, + partitionID: 2, + channelName: "insert-01", + startPos: startPos, + endPos: endPos, + }) assert.NoError(t, err) assert.True(t, replica.hasSegment(0, true)) assert.Equal(t, 1, len(replica.newSegments)) @@ -723,7 +833,18 @@ func TestInnerFunctionSegment(t *testing.T) { cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} - err = replica.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(10), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: cp, + recoverTs: 0, + }) assert.NoError(t, err) assert.True(t, replica.hasSegment(1, true)) assert.Equal(t, 1, len(replica.normalSegments)) @@ -740,7 +861,18 @@ func TestInnerFunctionSegment(t *testing.T) { assert.False(t, seg.isNew.Load().(bool)) assert.False(t, seg.isFlushed.Load().(bool)) - err = replica.addNormalSegment(1, 100000, 2, "invalid", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, &segmentCheckPoint{}, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 1, + collID: 100000, + partitionID: 2, + channelName: "invalid", + numOfRows: int64(0), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: &segmentCheckPoint{}, + recoverTs: 0, + }) assert.Error(t, err) replica.updateStatistics(1, 10) 
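For reference, the hunks above fold the former addNewSegment, addNormalSegment and addFlushedSegment helpers into a single addSegment(addSegmentReq) entry point, so every call site now picks the behaviour through segType instead of choosing a method. The following sketch is not part of the patch: the helper name and its parameters are illustrative, and it assumes the caller already holds a Replica and a *datapb.SegmentInfo as in the data_sync_service.go recovery path.

// Illustrative sketch: registering a recovered flushed segment through the consolidated API.
// Per the new addSegment doc comment, verify the segment is not already present first.
func addRecoveredFlushedSegment(replica Replica, segment *datapb.SegmentInfo, recoverTs Timestamp) error {
	if replica.hasSegment(segment.GetID(), true) {
		// Segment is already registered in this replica, nothing to do.
		return nil
	}
	return replica.addSegment(addSegmentReq{
		segType:      datapb.SegmentType_Flushed,
		segID:        segment.GetID(),
		collID:       segment.GetCollectionID(),
		partitionID:  segment.GetPartitionID(),
		channelName:  segment.GetInsertChannel(),
		numOfRows:    segment.GetNumOfRows(),
		statsBinLogs: segment.GetStatslogs(),
		recoverTs:    recoverTs,
	})
}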
@@ -775,7 +907,17 @@ func TestInnerFunctionSegment(t *testing.T) { replica.updateSegmentCheckPoint(1) assert.Equal(t, int64(20), replica.normalSegments[UniqueID(1)].checkPoint.numRows) - err = replica.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 1, + collID: 1, + partitionID: 2, + channelName: "insert-01", + numOfRows: int64(0), + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + recoverTs: 0, + }) assert.Nil(t, err) totalSegments := replica.filterSegments("insert-01", common.InvalidPartitionID) @@ -871,9 +1013,29 @@ func TestReplica_UpdatePKRange(t *testing.T) { assert.Nil(t, err) replica.chunkManager = &mockDataCM{} - err = replica.addNewSegment(1, collID, partID, chanName, startPos, endPos) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: collID, + partitionID: partID, + channelName: chanName, + startPos: startPos, + endPos: endPos, + }) assert.Nil(t, err) - err = replica.addNormalSegment(2, collID, partID, chanName, 100, []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp, 0) + err = replica.addSegment( + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 2, + collID: collID, + partitionID: partID, + channelName: chanName, + numOfRows: 100, + statsBinLogs: []*datapb.FieldBinlog{getSimpleFieldBinlog()}, + cp: cp, + recoverTs: 0, + }) assert.Nil(t, err) segNew := replica.newSegments[1] diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index c5de5b2916335..7ee9428603fcc 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -569,13 +569,40 @@ func (c *Client) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegm return ret.(*commonpb.Status), err } -// AddSegment is the DataCoord client side code for AddSegment call. -func (c *Client) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +// SaveImportSegment is the DataCoord client side code for SaveImportSegment call. +func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).AddSegment(ctx, req) + return client.(datapb.DataCoordClient).SaveImportSegment(ctx, req) + }) + if err != nil || ret == nil { + return nil, err + } + return ret.(*commonpb.Status), err +} + +// CompleteBulkLoad is the DataCoord client side code for CompleteBulkLoad call. 
+func (c *Client) CompleteBulkLoad(ctx context.Context, req *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + return client.(datapb.DataCoordClient).CompleteBulkLoad(ctx, req) + }) + if err != nil || ret == nil { + return nil, err + } + return ret.(*commonpb.Status), err +} + +func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + return client.(datapb.DataCoordClient).UnsetIsImportingState(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index 164b9a824d789..12bd7b2bec03d 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -133,11 +133,17 @@ func Test_NewClient(t *testing.T) { r26, err := client.ReleaseSegmentLock(ctx, nil) retCheck(retNotNil, r26, err) - r27, err := client.AddSegment(ctx, nil) + r27, err := client.SaveImportSegment(ctx, nil) retCheck(retNotNil, r27, err) - r28, err := client.ShowConfigurations(ctx, nil) + r28, err := client.CompleteBulkLoad(ctx, nil) retCheck(retNotNil, r28, err) + + r29, err := client.UnsetIsImportingState(ctx, nil) + retCheck(retNotNil, r29, err) + + r30, err := client.ShowConfigurations(ctx, nil) + retCheck(retNotNil, r30, err) } client.grpcClient = &mock.GRPCClientBase{ diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 8759e50ba359f..2c9058f78e417 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -377,6 +377,18 @@ func (s *Server) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegm return s.dataCoord.ReleaseSegmentLock(ctx, req) } -func (s *Server) AddSegment(ctx context.Context, request *datapb.AddSegmentRequest) (*commonpb.Status, error) { - return s.dataCoord.AddSegment(ctx, request) +// SaveImportSegment saves the import segment binlog paths data and then looks for the right DataNode to add the +// segment to that DataNode. +func (s *Server) SaveImportSegment(ctx context.Context, request *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { + return s.dataCoord.SaveImportSegment(ctx, request) +} + +// CompleteBulkLoad is the distributed caller of CompleteBulkLoad. +func (s *Server) CompleteBulkLoad(ctx context.Context, request *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + return s.dataCoord.CompleteBulkLoad(ctx, request) +} + +// UnsetIsImportingState is the distributed caller of UnsetIsImportingState. 
+func (s *Server) UnsetIsImportingState(ctx context.Context, request *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { + return s.dataCoord.UnsetIsImportingState(ctx, request) } diff --git a/internal/distributed/datacoord/service_test.go b/internal/distributed/datacoord/service_test.go index c3ddd13188503..e3284add8cf68 100644 --- a/internal/distributed/datacoord/service_test.go +++ b/internal/distributed/datacoord/service_test.go @@ -33,37 +33,39 @@ import ( /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type MockDataCoord struct { - states *internalpb.ComponentStates - status *commonpb.Status - err error - initErr error - startErr error - stopErr error - regErr error - strResp *milvuspb.StringResponse - infoResp *datapb.GetSegmentInfoResponse - flushResp *datapb.FlushResponse - assignResp *datapb.AssignSegmentIDResponse - segStateResp *datapb.GetSegmentStatesResponse - binResp *datapb.GetInsertBinlogPathsResponse - colStatResp *datapb.GetCollectionStatisticsResponse - partStatResp *datapb.GetPartitionStatisticsResponse - recoverResp *datapb.GetRecoveryInfoResponse - flushSegResp *datapb.GetFlushedSegmentsResponse + states *internalpb.ComponentStates + status *commonpb.Status + err error + initErr error + startErr error + stopErr error + regErr error + strResp *milvuspb.StringResponse + infoResp *datapb.GetSegmentInfoResponse + flushResp *datapb.FlushResponse + assignResp *datapb.AssignSegmentIDResponse + segStateResp *datapb.GetSegmentStatesResponse + binResp *datapb.GetInsertBinlogPathsResponse + colStatResp *datapb.GetCollectionStatisticsResponse + partStatResp *datapb.GetPartitionStatisticsResponse + recoverResp *datapb.GetRecoveryInfoResponse + flushSegResp *datapb.GetFlushedSegmentsResponse configResp *internalpb.ShowConfigurationsResponse - metricResp *milvuspb.GetMetricsResponse - compactionStateResp *milvuspb.GetCompactionStateResponse - manualCompactionResp *milvuspb.ManualCompactionResponse - compactionPlansResp *milvuspb.GetCompactionPlansResponse - watchChannelsResp *datapb.WatchChannelsResponse - getFlushStateResp *milvuspb.GetFlushStateResponse - dropVChanResp *datapb.DropVirtualChannelResponse - setSegmentStateResp *datapb.SetSegmentStateResponse - importResp *datapb.ImportTaskResponse - updateSegStatResp *commonpb.Status - acquireSegLockResp *commonpb.Status - releaseSegLockResp *commonpb.Status - addSegmentResp *commonpb.Status + metricResp *milvuspb.GetMetricsResponse + compactionStateResp *milvuspb.GetCompactionStateResponse + manualCompactionResp *milvuspb.ManualCompactionResponse + compactionPlansResp *milvuspb.GetCompactionPlansResponse + watchChannelsResp *datapb.WatchChannelsResponse + getFlushStateResp *milvuspb.GetFlushStateResponse + dropVChanResp *datapb.DropVirtualChannelResponse + setSegmentStateResp *datapb.SetSegmentStateResponse + importResp *datapb.ImportTaskResponse + updateSegStatResp *commonpb.Status + acquireSegLockResp *commonpb.Status + releaseSegLockResp *commonpb.Status + addSegmentResp *commonpb.Status + completeBulkLoadResp *commonpb.Status + unsetIsImportingStateResp *commonpb.Status } func (m *MockDataCoord) Init() error { @@ -200,10 +202,18 @@ func (m *MockDataCoord) ReleaseSegmentLock(ctx context.Context, req *datapb.Rele return m.releaseSegLockResp, m.err } -func (m *MockDataCoord) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (m *MockDataCoord) SaveImportSegment(ctx context.Context, req 
*datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { return m.addSegmentResp, m.err } +func (m *MockDataCoord) CompleteBulkLoad(context.Context, *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + return m.completeBulkLoadResp, m.err +} + +func (m *MockDataCoord) UnsetIsImportingState(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { + return m.unsetIsImportingStateResp, m.err +} + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// func Test_NewServer(t *testing.T) { ctx := context.Background() @@ -471,13 +481,35 @@ func Test_NewServer(t *testing.T) { assert.NotNil(t, resp) }) - t.Run("add segment", func(t *testing.T) { + t.Run("save import segment", func(t *testing.T) { server.dataCoord = &MockDataCoord{ addSegmentResp: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, } - resp, err := server.AddSegment(ctx, nil) + resp, err := server.SaveImportSegment(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, resp) + }) + + t.Run("complete bulk load", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + completeBulkLoadResp: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + } + resp, err := server.CompleteBulkLoad(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, resp) + }) + + t.Run("unset isImporting state", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + unsetIsImportingStateResp: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + } + resp, err := server.UnsetIsImportingState(ctx, nil) assert.Nil(t, err) assert.NotNil(t, resp) }) diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 2ea6fcb2804c6..baa8e0736acf0 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -245,13 +245,13 @@ func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegme return ret.(*datapb.ResendSegmentStatsResponse), err } -// AddSegment is the DataNode client side code for AddSegment call. -func (c *Client) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +// AddImportSegment is the DataNode client side code for AddImportSegment call. 
+func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*commonpb.Status, error) { ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).AddSegment(ctx, req) + return client.(datapb.DataNodeClient).AddImportSegment(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go index 850509c3e9591..54aec65e8a296 100644 --- a/internal/distributed/datanode/client/client_test.go +++ b/internal/distributed/datanode/client/client_test.go @@ -83,7 +83,7 @@ func Test_NewClient(t *testing.T) { r8, err := client.ResendSegmentStats(ctx, nil) retCheck(retNotNil, r8, err) - r9, err := client.AddSegment(ctx, nil) + r9, err := client.AddImportSegment(ctx, nil) retCheck(retNotNil, r9, err) r10, err := client.ShowConfigurations(ctx, nil) diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 54700a7d72395..a53735efe1f89 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -378,6 +378,6 @@ func (s *Server) ResendSegmentStats(ctx context.Context, request *datapb.ResendS return s.datanode.ResendSegmentStats(ctx, request) } -func (s *Server) AddSegment(ctx context.Context, request *datapb.AddSegmentRequest) (*commonpb.Status, error) { - return s.datanode.AddSegment(ctx, request) +func (s *Server) AddImportSegment(ctx context.Context, request *datapb.AddImportSegmentRequest) (*commonpb.Status, error) { + return s.datanode.AddImportSegment(ctx, request) } diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go index ca8a7e224f18e..10eee1bfa8de1 100644 --- a/internal/distributed/datanode/service_test.go +++ b/internal/distributed/datanode/service_test.go @@ -130,7 +130,7 @@ func (m *MockDataNode) ResendSegmentStats(ctx context.Context, req *datapb.Resen return m.resendResp, m.err } -func (m *MockDataNode) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (m *MockDataNode) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*commonpb.Status, error) { return m.status, m.err } @@ -300,7 +300,7 @@ func Test_NewServer(t *testing.T) { server.datanode = &MockDataNode{ status: &commonpb.Status{}, } - resp, err := server.AddSegment(ctx, nil) + resp, err := server.AddImportSegment(ctx, nil) assert.Nil(t, err) assert.NotNil(t, resp) }) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 2c3dccd199e5e..e01c946c7dbbe 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -29,16 +29,9 @@ import ( "sync" "time" - "google.golang.org/grpc/credentials" - - grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/gin-gonic/gin" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "github.com/opentracing/opentracing-go" clientv3 "go.etcd.io/etcd/client/v3" @@ -47,8 +40,6 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" 
"google.golang.org/grpc/keepalive" - "github.com/milvus-io/milvus/api/commonpb" - "github.com/milvus-io/milvus/api/milvuspb" dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" icc "github.com/milvus-io/milvus/internal/distributed/indexcoord/client" "github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" @@ -66,6 +57,15 @@ import ( "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/opentracing/opentracing-go" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" ) var Params paramtable.GrpcServerConfig diff --git a/internal/distributed/proxy/service_test.go b/internal/distributed/proxy/service_test.go index e8cd41f75d6b5..f7b7b89f3e76f 100644 --- a/internal/distributed/proxy/service_test.go +++ b/internal/distributed/proxy/service_test.go @@ -262,6 +262,14 @@ func (m *MockRootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolic return nil, nil } +func (m *MockRootCoord) GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + return nil, nil +} + +func (m *MockRootCoord) CheckSegmentIndexReady(context.Context, *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + return nil, nil +} + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type MockIndexCoord struct { MockBase @@ -466,7 +474,15 @@ func (m *MockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest) (*d return nil, nil } -func (m *MockDataCoord) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (m *MockDataCoord) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { + return nil, nil +} + +func (m *MockDataCoord) CompleteBulkLoad(ctx context.Context, req *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + return nil, nil +} + +func (m *MockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { return nil, nil } diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 6d89979467b66..ebcca447f6771 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -652,3 +652,29 @@ func (c *Client) ListPolicy(ctx context.Context, req *internalpb.ListPolicyReque } return ret.(*internalpb.ListPolicyResponse), err } + +func (c *Client) GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + return client.(rootcoordpb.RootCoordClient).GetImportFailedSegmentIDs(ctx, req) + }) + if err != nil { + return nil, err + } + return ret.(*internalpb.GetImportFailedSegmentIDsResponse), err +} + +func (c *Client) CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + ret, err := 
c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + return client.(rootcoordpb.RootCoordClient).CheckSegmentIndexReady(ctx, req) + }) + if err != nil { + return nil, err + } + return ret.(*commonpb.Status), err +} diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index dc1b683d0fc81..8363c326ef70b 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -411,6 +411,12 @@ func Test_NewClient(t *testing.T) { retCheck(rTimeout, err) } + r37Timeout, err := client.GetImportFailedSegmentIDs(shortCtx, nil) + retCheck(r37Timeout, err) + + r38Timeout, err := client.CheckSegmentIndexReady(shortCtx, nil) + retCheck(r38Timeout, err) + // clean up err = client.Stop() assert.Nil(t, err) diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 5cd472217ac0a..48ec7dd552d8f 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -79,6 +79,14 @@ type Server struct { closer io.Closer } +func (s *Server) GetImportFailedSegmentIDs(ctx context.Context, request *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + return s.rootCoord.GetImportFailedSegmentIDs(ctx, request) +} + +func (s *Server) CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + return s.rootCoord.CheckSegmentIndexReady(ctx, req) +} + // CreateAlias creates an alias for specified collection. func (s *Server) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { return s.rootCoord.CreateAlias(ctx, request) diff --git a/internal/kv/mock_kv.go b/internal/kv/mock_kv.go index 90680c424ccd7..0a9a912dcf5c0 100644 --- a/internal/kv/mock_kv.go +++ b/internal/kv/mock_kv.go @@ -17,7 +17,9 @@ package kv import ( + "errors" "strings" + "sync" "github.com/milvus-io/milvus/internal/log" clientv3 "go.etcd.io/etcd/client/v3" @@ -25,12 +27,12 @@ import ( ) type MockBaseKV struct { - InMemKv map[string]string + InMemKv sync.Map } func (m *MockBaseKV) Load(key string) (string, error) { - if val, ok := m.InMemKv[key]; ok { - return val, nil + if val, ok := m.InMemKv.Load(key); ok { + return val.(string), nil } return "", nil } @@ -44,7 +46,9 @@ func (m *MockBaseKV) LoadWithPrefix(key string) ([]string, []string, error) { } func (m *MockBaseKV) Save(key string, value string) error { - panic("not implemented") // TODO: Implement + m.InMemKv.Store(key, value) + log.Debug("doing Save", zap.String("key", key)) + return nil } func (m *MockBaseKV) MultiSave(kvs map[string]string) error { @@ -52,7 +56,9 @@ func (m *MockBaseKV) MultiSave(kvs map[string]string) error { } func (m *MockBaseKV) Remove(key string) error { - panic("not implemented") // TODO: Implement + m.InMemKv.Delete(key) + log.Debug("doing Remove", zap.String("key", key)) + return nil } func (m *MockBaseKV) MultiRemove(keys []string) error { @@ -85,6 +91,9 @@ func (m *MockTxnKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, remova type MockMetaKV struct { MockTxnKV + + LoadWithPrefixMockErr bool + SaveMockErr bool } func (m *MockMetaKV) GetPath(key string) string { @@ -92,14 +101,18 @@ func (m *MockMetaKV) GetPath(key string) string { } func (m *MockMetaKV) LoadWithPrefix(prefix string) ([]string, []string, error) { - 
keys := make([]string, 0, len(m.InMemKv)) - values := make([]string, 0, len(m.InMemKv)) - for k, v := range m.InMemKv { - if strings.HasPrefix(k, prefix) { - keys = append(keys, k) - values = append(values, v) - } + if m.LoadWithPrefixMockErr { + return nil, nil, errors.New("mock err") } + keys := make([]string, 0) + values := make([]string, 0) + m.InMemKv.Range(func(key, value interface{}) bool { + if strings.HasPrefix(key.(string), prefix) { + keys = append(keys, key.(string)) + values = append(values, value.(string)) + } + return true + }) return keys, values, nil } @@ -128,13 +141,16 @@ func (m *MockMetaKV) WatchWithRevision(key string, revision int64) clientv3.Watc } func (m *MockMetaKV) SaveWithLease(key, value string, id clientv3.LeaseID) error { - m.InMemKv[key] = value + m.InMemKv.Store(key, value) log.Debug("Doing SaveWithLease", zap.String("key", key)) return nil } func (m *MockMetaKV) SaveWithIgnoreLease(key, value string) error { - m.InMemKv[key] = value + if m.SaveMockErr { + return errors.New("mock error") + } + m.InMemKv.Store(key, value) log.Debug("Doing SaveWithIgnoreLease", zap.String("key", key)) return nil } diff --git a/internal/kv/mock_kv_test.go b/internal/kv/mock_kv_test.go index a8cc00d9ad5e8..66ad2f9ca63c1 100644 --- a/internal/kv/mock_kv_test.go +++ b/internal/kv/mock_kv_test.go @@ -17,6 +17,7 @@ package kv import ( + "sync" "testing" "github.com/stretchr/testify/assert" @@ -28,7 +29,7 @@ const testValue = "value" func TestMockKV_MetaKV(t *testing.T) { mockKv := &MockMetaKV{} - mockKv.InMemKv = make(map[string]string) + mockKv.InMemKv = sync.Map{} var err error value, err := mockKv.Load(testKey) @@ -42,17 +43,13 @@ func TestMockKV_MetaKV(t *testing.T) { _, _, err = mockKv.LoadWithPrefix(testKey) assert.NoError(t, err) - assert.Panics(t, func() { - mockKv.Save(testKey, testValue) - }) + assert.NoError(t, mockKv.Save(testKey, testValue)) assert.Panics(t, func() { mockKv.MultiSave(map[string]string{testKey: testValue}) }) - assert.Panics(t, func() { - mockKv.Remove(testKey) - }) + assert.NoError(t, mockKv.Remove(testKey)) assert.Panics(t, func() { mockKv.MultiRemove([]string{testKey}) diff --git a/internal/proto/common.proto b/internal/proto/common.proto index a678013ce7ac5..81a91045a3e42 100644 --- a/internal/proto/common.proto +++ b/internal/proto/common.proto @@ -160,6 +160,7 @@ enum MsgType { HandoffSegments = 254; LoadBalanceSegments = 255; DescribeSegments = 256; + GetImportFailedSegmentIDs = 257; /* DEFINITION REQUESTS: INDEX */ CreateIndex = 300; @@ -269,14 +270,11 @@ enum ConsistencyLevel { } enum ImportState { - ImportPending = 0; - ImportFailed = 1; - ImportStarted = 2; - ImportDownloaded = 3; - ImportParsed = 4; - ImportPersisted = 5; - ImportCompleted = 6; - ImportAllocSegment = 10; + ImportPending = 0; // the task in in pending list of rootCoord, waiting to be executed + ImportFailed = 1; // the task failed for some reason, get detail reason from GetImportStateResponse.infos + ImportStarted = 2; // the task has been sent to datanode to execute + ImportPersisted = 5; // all data files have been parsed and data already persisted + ImportCompleted = 6; // all indexes are successfully built and segments are able to be compacted as normal. 
} enum ObjectType { diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index be11543c9b8fe..be473ad352527 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -12,6 +12,12 @@ import "schema.proto"; // TODO: import google/protobuf/empty.proto message Empty {} +enum SegmentType { + New = 0; + Normal = 1; + Flushed = 2; +} + service DataCoord { rpc GetComponentStates(internal.GetComponentStatesRequest) returns (internal.ComponentStates) {} rpc GetTimeTickChannel(internal.GetTimeTickChannelRequest) returns(milvus.StringResponse) {} @@ -53,7 +59,9 @@ service DataCoord { rpc AcquireSegmentLock(AcquireSegmentLockRequest) returns (common.Status) {} rpc ReleaseSegmentLock(ReleaseSegmentLockRequest) returns (common.Status) {} - rpc AddSegment(AddSegmentRequest) returns(common.Status) {} + rpc SaveImportSegment(SaveImportSegmentRequest) returns(common.Status) {} + rpc CompleteBulkLoad(CompleteBulkLoadRequest) returns(common.Status) {} + rpc UnsetIsImportingState(UnsetIsImportingStateRequest) returns(common.Status) {} } service DataNode { @@ -75,7 +83,7 @@ service DataNode { rpc ResendSegmentStats(ResendSegmentStatsRequest) returns(ResendSegmentStatsResponse) {} - rpc AddSegment(AddSegmentRequest) returns(common.Status) {} + rpc AddImportSegment(AddImportSegmentRequest) returns(common.Status) {} } message FlushRequest { @@ -520,8 +528,8 @@ message ImportTaskInfo { repeated string files = 9; // A list of files to import. int64 create_ts = 10; // Timestamp when the import task is created. ImportTaskState state = 11; // State of the import task. - bool data_queryable = 12; // A flag indicating whether import data are queryable (i.e. loaded in query nodes) - bool data_indexed = 13; // A flag indicating whether import data are indexed. + string collection_name = 12; // Collection name for the import task. + string partition_name = 13; // Partition name for the import task. } message ImportTaskResponse { @@ -549,13 +557,38 @@ message ResendSegmentStatsResponse { repeated int64 seg_resent = 2; } -message AddSegmentRequest { +message AddImportSegmentRequest { + common.MsgBase base = 1; + int64 segment_id = 2; + string channel_name = 3; + int64 collection_id = 4; + int64 partition_id = 5; + int64 row_num = 6; + repeated FieldBinlog stats_log = 7; + bytes dml_position_id = 8; +} + +message SaveImportSegmentRequest { common.MsgBase base = 1; int64 segment_id = 2; string channel_name = 3; int64 collection_id = 4; int64 partition_id = 5; int64 row_num = 6; + SaveBinlogPathsRequest save_binlog_path_req = 7; + bytes dml_position_id = 8; +} + +message CompleteBulkLoadRequest { + common.MsgBase base = 1; + int64 task_id = 2; // ID of the import task that needs to be completed. + int64 collection_id = 3; // Collection ID of the import task. + repeated int64 segment_ids = 4; // Segment IDs of the bulk load task. +} + +message UnsetIsImportingStateRequest { + common.MsgBase base = 1; + repeated int64 segment_ids = 2; // IDs of segments whose `isImport` states need to be unset. 
} message SegmentReferenceLock { diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index b860c96f0038e..6fdbb05d58e63 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -319,6 +319,22 @@ message ShowConfigurationsRequest { string pattern = 2; } +message GetImportFailedSegmentIDsRequest{ + common.MsgBase base = 1; +} + +message CheckSegmentIndexReadyRequest{ + common.MsgBase base = 1; + int64 taskID = 2; + int64 colID = 3; + repeated int64 segIDs = 4; +} + +message GetImportFailedSegmentIDsResponse{ + common.Status status = 1; + repeated int64 segmentIDs = 2; +} + message ShowConfigurationsResponse { common.Status status = 1; repeated common.KeyValuePair configuations = 2; diff --git a/internal/proto/milvus.proto b/internal/proto/milvus.proto index ccf4d616f49f6..1030b7d0d571c 100644 --- a/internal/proto/milvus.proto +++ b/internal/proto/milvus.proto @@ -968,11 +968,14 @@ message GetImportStateResponse { repeated int64 id_list = 4; // auto generated ids if the primary key is autoid repeated common.KeyValuePair infos = 5; // more information about the task, progress percent, file path, failed reason, etc. int64 id = 6; // id of an import task - bool data_queryable = 7; // A flag indicating whether import data are queryable (i.e. loaded in query nodes) - bool data_indexed = 8; // A flag indicating whether import data are indexed. + int64 collection_id = 7; // collection ID of the import task. + repeated int64 segment_ids = 8; // a list of segment IDs created by the import task. + int64 create_ts = 9; // timestamp when the import task is created. } message ListImportTasksRequest { + string collection_name = 1; // target collection, list all tasks if the name is empty + int64 limit = 2; // maximum number of tasks returned, list all tasks if the value is 0 } message ListImportTasksResponse { diff --git a/internal/proto/root_coord.proto b/internal/proto/root_coord.proto index 15f85b5e16b47..6d5e2d6b6995a 100644 --- a/internal/proto/root_coord.proto +++ b/internal/proto/root_coord.proto @@ -133,6 +133,10 @@ service RootCoord { rpc OperatePrivilege(milvus.OperatePrivilegeRequest) returns (common.Status) {} rpc SelectGrant(milvus.SelectGrantRequest) returns (milvus.SelectGrantResponse) {} rpc ListPolicy(internal.ListPolicyRequest) returns (internal.ListPolicyResponse) {} + + // TODO: move import manager to datacoord to remove this rpc + rpc GetImportFailedSegmentIDs(internal.GetImportFailedSegmentIDsRequest) returns (internal.GetImportFailedSegmentIDsResponse) {} + rpc CheckSegmentIndexReady(internal.CheckSegmentIndexReadyRequest) returns (common.Status) {} } message AllocTimestampRequest { diff --git a/internal/proxy/data_coord_mock_test.go b/internal/proxy/data_coord_mock_test.go index 3bf84db49afdd..874c89bd46e81 100644 --- a/internal/proxy/data_coord_mock_test.go +++ b/internal/proxy/data_coord_mock_test.go @@ -114,7 +114,15 @@ func (coord *DataCoordMock) Flush(ctx context.Context, req *datapb.FlushRequest) panic("implement me") } -func (coord *DataCoordMock) AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { + panic("implement me") +} + +func (coord *DataCoordMock) CompleteBulkLoad(context.Context, *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) { + panic("implement me") +} + +func (coord *DataCoordMock) UnsetIsImportingState(context.Context, 
*datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { panic("implement me") } diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 13fd7d2af1333..76094035dd1a4 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -26,32 +26,36 @@ import ( "github.com/milvus-io/milvus/internal/util/errorutil" "github.com/milvus-io/milvus/internal/util" + "time" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - + "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/api/commonpb" - "github.com/milvus-io/milvus/api/milvuspb" + "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util/crypto" "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) const moduleName = "Proxy" +var CheckTaskPersistedInterval = 5 * time.Second +var CheckTaskPersistedWaitLimit = 300 * time.Second + // UpdateStateCode updates the state code of Proxy. func (node *Proxy) UpdateStateCode(code internalpb.StateCode) { node.stateCode.Store(code) @@ -3834,10 +3838,35 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi resp.Status.Reason = err.Error() return resp, nil } + log.Info("import complete, now completing import", zap.Int64s("task IDs", respFromRC.GetTasks())) + for _, taskID := range respFromRC.GetTasks() { + go node.completeImport(taskID) + } return respFromRC, nil } -// GetImportState checks import task state from datanode +func (node *Proxy) completeImport(taskID int64) { + // First check if the import task has turned into persisted state. Returns an error status if not after retrying. + // This could take a few or tens of seconds. + getImportResp, err := node.checkImportTaskPersisted(taskID) + if err != nil { + log.Error("task not persisted yet after wait limit", + zap.Int64("wait limit (seconds)", int64(CheckTaskPersistedWaitLimit.Seconds())), + zap.Int64("task ID", taskID), + zap.Any("current task state", getImportResp.GetState())) + return + } + + // Start background process of complete bulk load operation. + // Ignoring complete bulk load return status. + node.dataCoord.CompleteBulkLoad(node.ctx, &datapb.CompleteBulkLoadRequest{ + TaskId: taskID, + CollectionId: getImportResp.GetCollectionId(), + SegmentIds: getImportResp.GetSegmentIds(), + }) +} + +// GetImportState checks import task state from RootCoord. 
func (node *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { log.Info("received get import state request", zap.Int64("taskID", req.GetTask())) resp := &milvuspb.GetImportStateResponse{} @@ -4387,6 +4416,48 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr }, nil } +// checkImportTaskPersisted starts a loop to periodically check if the import task becomes ImportState_ImportPersisted state. +// A non-nil error is returned if the import task was not in ImportState_ImportPersisted state. +func (node *Proxy) checkImportTaskPersisted(taskID int64) (*milvuspb.GetImportStateResponse, error) { + ticker := time.NewTicker(CheckTaskPersistedInterval) + defer ticker.Stop() + expireTicker := time.NewTicker(CheckTaskPersistedWaitLimit) + defer expireTicker.Stop() + var getImportResp *milvuspb.GetImportStateResponse + for { + select { + case <-node.ctx.Done(): + log.Info("(in check task persisted loop) context done, exiting CheckSegmentIndexReady loop") + return nil, errors.New("proxy node context done") + case <-ticker.C: + var err error + getImportResp, err = node.rootCoord.GetImportState(node.ctx, &milvuspb.GetImportStateRequest{Task: taskID}) + if err != nil { + log.Warn(fmt.Sprintf("an error occurred while completing bulk load %s", err.Error())) + return nil, err + } + if getImportResp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn(fmt.Sprintf("an error occurred while completing bulk load %s", getImportResp.GetStatus().GetReason())) + return nil, errors.New(getImportResp.GetStatus().GetReason()) + } + if getImportResp.GetState() == commonpb.ImportState_ImportPersisted || + getImportResp.GetState() == commonpb.ImportState_ImportCompleted { + log.Info("import task persisted or already complete", + zap.Int64("task ID", getImportResp.GetId()), + zap.Any("task state", getImportResp.GetState())) + return getImportResp, nil + } + log.Debug("checking import task state", + zap.Int64("task ID", getImportResp.GetId()), + zap.Any("current state", getImportResp.GetState())) + case <-expireTicker.C: + log.Warn("(in check task persisted loop) task still not persisted", + zap.Int64("task ID", taskID)) + return nil, errors.New("task still not persisted, please try again later") + } + } +} + // SetRates limits the rates of requests. func (node *Proxy) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) (*commonpb.Status, error) { log.Debug("SetRates", zap.String("role", typeutil.ProxyRole), zap.Any("rates", request.GetRates())) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 5f49057f1cfc9..7bdbfcd8a1216 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -421,6 +421,8 @@ func TestProxy(t *testing.T) { localMsg := true factory := dependency.NewDefaultFactory(localMsg) alias := "TestProxy" + CheckTaskPersistedInterval = 10 * time.Millisecond + CheckTaskPersistedWaitLimit = 100 * time.Millisecond Params.InitOnce() log.Info("Initialize parameter table of Proxy") @@ -1663,6 +1665,8 @@ func TestProxy(t *testing.T) { resp, err := proxy.Import(context.TODO(), req) assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) assert.Nil(t, err) + // Wait a bit for complete import to start. 
+ time.Sleep(2 * time.Second) }) wg.Add(1) diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go index 20f40af14fc8a..3a70e79f98daf 100644 --- a/internal/proxy/rootcoord_mock_test.go +++ b/internal/proxy/rootcoord_mock_test.go @@ -111,6 +111,16 @@ type RootCoordMock struct { lastTsMtx sync.Mutex } +func (coord *RootCoordMock) CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + //TODO implement me + panic("implement me") +} + +func (coord *RootCoordMock) GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + //TODO implement me + panic("implement me") +} + func (coord *RootCoordMock) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { code := coord.state.Load().(internalpb.StateCode) if code != internalpb.StateCode_Healthy { @@ -1129,6 +1139,7 @@ type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvusp type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) type GetGetCredentialFunc func(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) +type CheckSegmentIndexReadyFunc func(ctx context.Context, request *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) type mockRootCoord struct { types.RootCoord @@ -1140,6 +1151,7 @@ type mockRootCoord struct { ImportFunc DropCollectionFunc GetGetCredentialFunc + CheckSegmentIndexReadyFunc } func (m *mockRootCoord) GetCredential(ctx context.Context, request *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { @@ -1189,6 +1201,13 @@ func (m *mockRootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolic return &internalpb.ListPolicyResponse{}, nil } +func (m *mockRootCoord) CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + if m.CheckSegmentIndexReadyFunc != nil { + return m.CheckSegmentIndexReadyFunc(ctx, req) + } + return nil, errors.New("mock") +} + func newMockRootCoord() *mockRootCoord { return &mockRootCoord{} } diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go index 6fcec6eb7abd2..c263cfe31190e 100644 --- a/internal/rootcoord/import_manager.go +++ b/internal/rootcoord/import_manager.go @@ -33,7 +33,6 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" - "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" ) @@ -82,29 +81,32 @@ type importManager struct { startOnce sync.Once - idAllocator func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) - callImportService func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse - getCollectionName func(collID, partitionID typeutil.UniqueID) (string, string, error) + idAllocator func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) + callImportService func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) + getCollectionName func(collID, partitionID typeutil.UniqueID) (string, string, error) + callUnsetIsImportState func(taskID int64) error } // newImportManager helper function to create a importManager func 
newImportManager(ctx context.Context, client kv.MetaKv,
 	idAlloc func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error),
-	importService func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse,
+	importService func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error),
+	unsetIsImportState func(taskID int64) error,
 	getCollectionName func(collID, partitionID typeutil.UniqueID) (string, string, error)) *importManager {
 	mgr := &importManager{
-		ctx:               ctx,
-		taskStore:         client,
-		pendingTasks:      make([]*datapb.ImportTaskInfo, 0, MaxPendingCount), // currently task queue max size is 32
-		workingTasks:      make(map[int64]*datapb.ImportTaskInfo),
-		busyNodes:         make(map[int64]bool),
-		pendingLock:       sync.RWMutex{},
-		workingLock:       sync.RWMutex{},
-		busyNodesLock:     sync.RWMutex{},
-		lastReqID:         0,
-		idAllocator:       idAlloc,
-		callImportService: importService,
-		getCollectionName: getCollectionName,
+		ctx:                    ctx,
+		taskStore:              client,
+		pendingTasks:           make([]*datapb.ImportTaskInfo, 0, MaxPendingCount), // currently task queue max size is 32
+		workingTasks:           make(map[int64]*datapb.ImportTaskInfo),
+		busyNodes:              make(map[int64]bool),
+		pendingLock:            sync.RWMutex{},
+		workingLock:            sync.RWMutex{},
+		busyNodesLock:          sync.RWMutex{},
+		lastReqID:              0,
+		idAllocator:            idAlloc,
+		callImportService:      importService,
+		callUnsetIsImportState: unsetIsImportState,
+		getCollectionName:      getCollectionName,
 	}
 	return mgr
 }
@@ -112,9 +114,14 @@ func newImportManager(ctx context.Context, client kv.MetaKv,
 func (m *importManager) init(ctx context.Context) {
 	m.startOnce.Do(func() {
 		// Read tasks from Etcd and save them as pending tasks or working tasks.
-		m.loadFromTaskStore()
+		if err := m.loadFromTaskStore(); err != nil {
+			log.Error("importManager init failed, read tasks from Etcd failed, about to panic")
+			panic(err)
+		}
 		// Send out tasks to dataCoord.
-		m.sendOutTasks(ctx)
+		if err := m.sendOutTasks(ctx); err != nil {
+			log.Error("importManager init failed, send out tasks to dataCoord failed")
+		}
 	})
 }
@@ -129,13 +136,19 @@ func (m *importManager) sendOutTasksLoop(wg *sync.WaitGroup) {
 			log.Debug("import manager context done, exit check sendOutTasksLoop")
 			return
 		case <-ticker.C:
-			m.sendOutTasks(m.ctx)
+			if err := m.sendOutTasks(m.ctx); err != nil {
+				log.Error("importManager sendOutTasksLoop failed to send out tasks")
+			}
 		}
 	}
 }
 
-// expireOldTasksLoop starts a loop that checks and expires old tasks every `ImportTaskExpiration` seconds.
-func (m *importManager) expireOldTasksLoop(wg *sync.WaitGroup, releaseLockFunc func(context.Context, int64, []int64) error) {
+// expireOldTasksLoop starts a loop that periodically checks for and expires old tasks (every `expireOldTasksInterval` milliseconds).
+// There are two kinds of tasks to clean up:
+// (1) pending or working tasks that have existed for over `ImportTaskExpiration` seconds; these tasks are removed from memory.
+// (2) any import task created more than `ImportTaskRetention` seconds ago; these tasks are removed from Etcd.
+func (m *importManager) expireOldTasksLoop(wg *sync.WaitGroup) { defer wg.Done() ticker := time.NewTicker(time.Duration(expireOldTasksInterval) * time.Millisecond) defer ticker.Stop() @@ -145,9 +158,8 @@ func (m *importManager) expireOldTasksLoop(wg *sync.WaitGroup, releaseLockFunc f log.Info("(in loop) import manager context done, exit expireOldTasksLoop") return case <-ticker.C: - log.Debug("(in loop) starting expiring old tasks...", - zap.Duration("cleaning up interval", time.Duration(expireOldTasksInterval)*time.Millisecond)) - m.expireOldTasks(releaseLockFunc) + m.expireOldTasksFromMem() + m.expireOldTasksFromEtcd() } } } @@ -161,6 +173,7 @@ func (m *importManager) sendOutTasks(ctx context.Context) error { // Trigger Import() action to DataCoord. for len(m.pendingTasks) > 0 { + log.Debug("try to send out pending tasks", zap.Int("task_number", len(m.pendingTasks))) task := m.pendingTasks[0] // TODO: Use ImportTaskInfo directly. it := &datapb.ImportTask{ @@ -185,7 +198,7 @@ func (m *importManager) sendOutTasks(ctx context.Context) error { } // Send import task to dataCoord, which will then distribute the import task to dataNode. - resp := m.callImportService(ctx, &datapb.ImportTaskRequest{ + resp, err := m.callImportService(ctx, &datapb.ImportTaskRequest{ ImportTask: it, WorkingNodes: busyNodeList, }) @@ -196,6 +209,10 @@ func (m *importManager) sendOutTasks(ctx context.Context) error { zap.String("cause", resp.GetStatus().GetReason())) break } + if err != nil { + log.Error("import task get error", zap.Error(err)) + break + } // Successfully assigned dataNode for the import task. Add task to working task list and update task store. task.DatanodeId = resp.GetDatanodeId() @@ -204,18 +221,25 @@ func (m *importManager) sendOutTasks(ctx context.Context) error { zap.Int64("dataNode ID", task.GetDatanodeId())) // Add new working dataNode to busyNodes. m.busyNodes[resp.GetDatanodeId()] = true - - func() { + err = func() error { m.workingLock.Lock() defer m.workingLock.Unlock() - log.Debug("import task added as working task", zap.Int64("task ID", it.TaskId)) - task.State.StateCode = commonpb.ImportState_ImportPending + task.State.StateCode = commonpb.ImportState_ImportStarted + // first update the import task into meta store and then put it into working tasks + if err := m.persistTaskInfo(task); err != nil { + log.Error("failed to update import task", + zap.Int64("task ID", task.GetId()), + zap.Error(err)) + return err + } m.workingTasks[task.GetId()] = task - m.updateImportTaskStore(task) + return nil }() - - // Erase this task from head of pending list. + if err != nil { + return err + } + // Remove this task from head of pending list. m.pendingTasks = append(m.pendingTasks[:0], m.pendingTasks[1:]...) } @@ -254,7 +278,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque zap.String("collection name", req.GetCollectionName()), zap.Int64("collection ID", cID), zap.Int64("partition ID", pID)) - err := func() (err error) { + err := func() error { m.pendingLock.Lock() defer m.pendingLock.Unlock() @@ -268,7 +292,7 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque // task queue size has a limit, return error if import request contains too many data files, and skip entire job if capacity-length < taskCount { - err = fmt.Errorf("import task queue max size is %v, currently there are %v tasks is pending. 
Not able to execute this request with %v tasks", capacity, length, taskCount)
+				err := fmt.Errorf("import task queue max size is %v, currently there are %v tasks pending. Not able to execute this request with %v tasks", capacity, length, taskCount)
 				log.Error(err.Error())
 				return err
 			}
@@ -302,16 +326,26 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
 					State: &datapb.ImportTaskState{
 						StateCode: commonpb.ImportState_ImportPending,
 					},
-					DataQueryable: false,
-					DataIndexed:   false,
 				}
+
+				// No need to check the error returned by setCollectionPartitionName() here,
+				// since the task list is always returned to the client even if the names are missing.
+				// setCollectionPartitionName() returns an error only so that every code branch
+				// can be covered by unit tests.
+				m.setCollectionPartitionName(cID, pID, newTask)
 				resp.Tasks = append(resp.Tasks, newTask.GetId())
 				taskList[i] = newTask.GetId()
-				log.Info("new task created as pending task", zap.Int64("task ID", newTask.GetId()))
+				log.Info("new task created as pending task",
+					zap.Int64("task ID", newTask.GetId()))
+				if err := m.persistTaskInfo(newTask); err != nil {
+					log.Error("failed to update import task",
+						zap.Int64("task ID", newTask.GetId()),
+						zap.Error(err))
+					return err
+				}
 				m.pendingTasks = append(m.pendingTasks, newTask)
-				m.storeImportTask(newTask)
 			}
-			log.Info("row-based import request processed", zap.Any("taskIDs", taskList))
+			log.Info("row-based import request processed", zap.Any("task IDs", taskList))
 		} else {
 			// TODO: Merge duplicated code :(
 			// for column-based, all files is a task
@@ -331,14 +365,24 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
 				State: &datapb.ImportTaskState{
 					StateCode: commonpb.ImportState_ImportPending,
 				},
-				DataQueryable: false,
-				DataIndexed:   false,
 			}
+			// No need to check the error returned by setCollectionPartitionName() here,
+			// since the task list is always returned to the client even if the names are missing.
+			// setCollectionPartitionName() returns an error only so that every code branch
+			// can be covered by unit tests.
+			m.setCollectionPartitionName(cID, pID, newTask)
 			resp.Tasks = append(resp.Tasks, newTask.GetId())
-			log.Info("new task created as pending task", zap.Int64("task ID", newTask.GetId()))
+			log.Info("new task created as pending task",
+				zap.Int64("task ID", newTask.GetId()))
+			if err := m.persistTaskInfo(newTask); err != nil {
+				log.Error("failed to update import task",
+					zap.Int64("task ID", newTask.GetId()),
+					zap.Error(err))
+				return err
+			}
 			m.pendingTasks = append(m.pendingTasks, newTask)
-			m.storeImportTask(newTask)
-			log.Info("column-based import request processed", zap.Int64("taskID", newTask.GetId()))
+			log.Info("column-based import request processed",
+				zap.Int64("task ID", newTask.GetId()))
 		}
 		return nil
 	}()
@@ -350,30 +394,10 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque
 			},
 		}
 	}
-	m.sendOutTasks(ctx)
-	return resp
-}
-
-// setTaskDataQueryable sets task's DataQueryable flag to true.
-func (m *importManager) setTaskDataQueryable(taskID int64) {
-	m.workingLock.Lock()
-	defer m.workingLock.Unlock()
-	if v, ok := m.workingTasks[taskID]; ok {
-		v.DataQueryable = true
-	} else {
-		log.Error("task ID not found", zap.Int64("task ID", taskID))
-	}
-}
-
-// setTaskDataIndexed sets task's DataIndexed flag to true.
-func (m *importManager) setTaskDataIndexed(taskID int64) { - m.workingLock.Lock() - defer m.workingLock.Unlock() - if v, ok := m.workingTasks[taskID]; ok { - v.DataIndexed = true - } else { - log.Error("task ID not found", zap.Int64("task ID", taskID)) + if sendOutTasksErr := m.sendOutTasks(ctx); sendOutTasksErr != nil { + log.Error("fail to send out tasks", zap.Error(sendOutTasksErr)) } + return resp } // updateTaskState updates the task's state in in-memory working tasks list and in task store, given ImportResult @@ -396,18 +420,26 @@ func (m *importManager) updateTaskState(ir *rootcoordpb.ImportResult) (*datapb.I return nil, errors.New("trying to update an already failed task " + strconv.FormatInt(ir.GetTaskId(), 10)) } found = true - v.State.StateCode = ir.GetState() - v.State.Segments = ir.GetSegments() - v.State.RowCount = ir.GetRowCount() - v.State.RowIds = ir.AutoIds + // Meta persist should be done before memory objs change. + toPersistImportTaskInfo := cloneImportTaskInfo(v) + toPersistImportTaskInfo.State.StateCode = ir.GetState() + toPersistImportTaskInfo.State.Segments = ir.GetSegments() + toPersistImportTaskInfo.State.RowCount = ir.GetRowCount() + toPersistImportTaskInfo.State.RowIds = ir.AutoIds for _, kv := range ir.GetInfos() { if kv.GetKey() == FailedReason { - v.State.ErrorMessage = kv.GetValue() + toPersistImportTaskInfo.State.ErrorMessage = kv.GetValue() break } } // Update task in task store. - m.updateImportTaskStore(v) + if err := m.persistTaskInfo(toPersistImportTaskInfo); err != nil { + log.Error("failed to update import task", + zap.Int64("task ID", v.GetId()), + zap.Error(err)) + return nil, err + } + m.workingTasks[ir.GetTaskId()] = toPersistImportTaskInfo } if !found { @@ -417,39 +449,81 @@ func (m *importManager) updateTaskState(ir *rootcoordpb.ImportResult) (*datapb.I return v, nil } -func (m *importManager) getCollectionPartitionName(task *datapb.ImportTaskInfo, resp *milvuspb.GetImportStateResponse) { - if m.getCollectionName != nil { - colName, partName, err := m.getCollectionName(task.GetCollectionId(), task.GetPartitionId()) - if err == nil { - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: CollectionName, Value: colName}) - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: PartitionName, Value: partName}) - } - } -} - -// appendTaskSegments updates the task's segment lists by adding `segIDs` to it. -func (m *importManager) appendTaskSegments(taskID int64, segIDs []int64) error { - log.Debug("import manager appending task segments", - zap.Int64("task ID", taskID), - zap.Int64s("segment ID", segIDs)) +// setCompleteImportState set the task state as `ImportState_ImportCompleted`. +func (m *importManager) setCompleteImportState(taskID int64) error { + log.Debug("trying to set import task as ImportState_ImportCompleted", zap.Int64("taskID", taskID)) + found := false var v *datapb.ImportTaskInfo m.workingLock.Lock() + defer m.workingLock.Unlock() ok := false if v, ok = m.workingTasks[taskID]; ok { - v.State.Segments = append(v.GetState().GetSegments(), segIDs...) + // If the task has already been marked failed. Prevent further state updating and return an error. + if v.GetState().GetStateCode() == commonpb.ImportState_ImportFailed { + return errors.New("trying to complete an already failed task " + strconv.FormatInt(taskID, 10)) + } + found = true + // Meta persist should be done before memory objs change. 
+		toPersistImportTaskInfo := cloneImportTaskInfo(v)
+		toPersistImportTaskInfo.State.StateCode = commonpb.ImportState_ImportCompleted
 		// Update task in task store.
-		m.updateImportTaskStore(v)
+		if err := m.persistTaskInfo(toPersistImportTaskInfo); err != nil {
+			return err
+		}
+		m.workingTasks[taskID] = toPersistImportTaskInfo
 	}
-	m.workingLock.Unlock()
-	if !ok {
-		log.Debug("import manager appending task segments failed", zap.Int64("task ID", taskID))
-		return errors.New("failed to update import task, ID not found: " + strconv.FormatInt(taskID, 10))
+	if !found {
+		return errors.New("failed to complete import task, ID not found: " + strconv.FormatInt(taskID, 10))
 	}
 	return nil
 }
 
+func (m *importManager) setCollectionPartitionName(colID, partID int64, task *datapb.ImportTaskInfo) error {
+	if m.getCollectionName != nil {
+		colName, partName, err := m.getCollectionName(colID, partID)
+		if err == nil {
+			task.CollectionName = colName
+			task.PartitionName = partName
+			return nil
+		} else {
+			log.Error("failed to setCollectionPartitionName",
+				zap.Int64("collection ID", colID),
+				zap.Int64("partition ID", partID),
+				zap.Error(err))
+		}
+	}
+	return errors.New("failed to setCollectionPartitionName for import task")
+}
+
+func (m *importManager) copyTaskInfo(input *datapb.ImportTaskInfo, output *milvuspb.GetImportStateResponse) error {
+	if input == nil || output == nil {
+		log.Error("ImportTaskInfo or GetImportStateResponse object should not be nil")
+		return errors.New("ImportTaskInfo or GetImportStateResponse object should not be nil")
+	}
+
+	output.Status = &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	output.Id = input.GetId()
+	output.CollectionId = input.GetCollectionId()
+	output.State = input.GetState().GetStateCode()
+	output.RowCount = input.GetState().GetRowCount()
+	output.IdList = input.GetState().GetRowIds()
+	output.SegmentIds = input.GetState().GetSegments()
+	output.CreateTs = input.GetCreateTs()
+	output.Infos = append(output.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(input.GetFiles(), ",")})
+	output.Infos = append(output.Infos, &commonpb.KeyValuePair{Key: CollectionName, Value: input.GetCollectionName()})
+	output.Infos = append(output.Infos, &commonpb.KeyValuePair{Key: PartitionName, Value: input.GetPartitionName()})
+	output.Infos = append(output.Infos, &commonpb.KeyValuePair{
+		Key:   FailedReason,
+		Value: input.GetState().GetErrorMessage(),
+	})
+
+	return nil
+}
+
 // getTaskState looks for task with the given ID and returns its import state.
 func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse {
 	resp := &milvuspb.GetImportStateResponse{
@@ -464,22 +538,21 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse
 	found := false
 	func() {
 		m.pendingLock.Lock()
-		defer m.pendingLock.Unlock()
 		for _, t := range m.pendingTasks {
 			if tID == t.Id {
-				resp.Status = &commonpb.Status{
-					ErrorCode: commonpb.ErrorCode_Success,
-				}
-				resp.Id = tID
-				resp.State = commonpb.ImportState_ImportPending
-				resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(t.GetFiles(), ",")})
-				resp.DataQueryable = t.GetDataQueryable()
-				resp.DataIndexed = t.GetDataIndexed()
-				m.getCollectionPartitionName(t, resp)
+				m.copyTaskInfo(t, resp)
+
+				// Release lock early to prevent deadlock.
+				m.pendingLock.Unlock()
+				found = true
 				break
 			}
 		}
+		if !found {
+			// Release the lock.
+ m.pendingLock.Unlock() + } }() if found { return resp @@ -487,24 +560,15 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse func() { m.workingLock.Lock() - defer m.workingLock.Unlock() if v, ok := m.workingTasks[tID]; ok { found = true - resp.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - resp.Id = tID - resp.State = v.GetState().GetStateCode() - resp.RowCount = v.GetState().GetRowCount() - resp.IdList = v.GetState().GetRowIds() - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(v.GetFiles(), ",")}) - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{ - Key: FailedReason, - Value: v.GetState().GetErrorMessage(), - }) - resp.DataQueryable = v.GetDataQueryable() - resp.DataIndexed = v.GetDataIndexed() - m.getCollectionPartitionName(v, resp) + m.copyTaskInfo(v, resp) + + // Release lock early to prevent deadlock. + m.workingLock.Unlock() + } + if !found { + m.workingLock.Unlock() } }() if found { @@ -545,71 +609,57 @@ func (m *importManager) loadFromTaskStore() error { return nil } -// storeImportTask signs a lease and saves import task info into Etcd with this lease. -func (m *importManager) storeImportTask(task *datapb.ImportTaskInfo) error { - log.Debug("saving import task to Etcd", zap.Int64("task ID", task.GetId())) - // Sign a lease. Tasks will be stored for at least `ImportTaskRetention` seconds. - leaseID, err := m.taskStore.Grant(int64(Params.RootCoordCfg.ImportTaskRetention)) - if err != nil { - log.Error("failed to grant lease from Etcd for data import", - zap.Int64("task ID", task.GetId()), +// persistTaskInfo stores or updates the import task info in Etcd. +func (m *importManager) persistTaskInfo(ti *datapb.ImportTaskInfo) error { + log.Info("updating import task info in Etcd", zap.Int64("task ID", ti.GetId())) + if taskInfo, err := proto.Marshal(ti); err != nil { + log.Error("failed to marshall task info proto", + zap.Int64("task ID", ti.GetId()), zap.Error(err)) return err - } - log.Debug("lease granted for task", zap.Int64("task ID", task.GetId())) - var taskInfo []byte - if taskInfo, err = proto.Marshal(task); err != nil { - log.Error("failed to marshall task proto", zap.Int64("task ID", task.GetId()), zap.Error(err)) - return err - } else if err = m.taskStore.SaveWithLease(BuildImportTaskKey(task.GetId()), string(taskInfo), leaseID); err != nil { - log.Error("failed to save import task info into Etcd", - zap.Int64("task ID", task.GetId()), + } else if err = m.taskStore.Save(BuildImportTaskKey(ti.GetId()), string(taskInfo)); err != nil { + log.Error("failed to update import task info in Etcd", + zap.Int64("task ID", ti.GetId()), zap.Error(err)) return err } - log.Debug("task info successfully saved", zap.Int64("task ID", task.GetId())) return nil } -// updateImportTaskStore updates the task info in Etcd according to task ID. It won't change the lease on the key. -func (m *importManager) updateImportTaskStore(ti *datapb.ImportTaskInfo) error { - log.Debug("updating import task info in Etcd", zap.Int64("Task ID", ti.GetId())) - if taskInfo, err := proto.Marshal(ti); err != nil { - log.Error("failed to marshall task info proto", zap.Int64("Task ID", ti.GetId()), zap.Error(err)) - return err - } else if err = m.taskStore.SaveWithIgnoreLease(BuildImportTaskKey(ti.GetId()), string(taskInfo)); err != nil { - log.Error("failed to update import task info info in Etcd", zap.Int64("Task ID", ti.GetId()), zap.Error(err)) +// yieldTaskInfo removes the task info from Etcd. 
+func (m *importManager) yieldTaskInfo(tID int64) error { + log.Info("removing import task info from Etcd", + zap.Int64("task ID", tID)) + if err := m.taskStore.Remove(BuildImportTaskKey(tID)); err != nil { + log.Error("failed to update import task info in Etcd", + zap.Int64("task ID", tID), + zap.Error(err)) return err } - log.Debug("task info successfully updated in Etcd", zap.Int64("Task ID", ti.GetId())) return nil } -// expireOldTasks marks expires tasks as failed. -func (m *importManager) expireOldTasks(releaseLockFunc func(context.Context, int64, []int64) error) { +// expireOldTasks removes expired tasks from memory. +func (m *importManager) expireOldTasksFromMem() { // Expire old pending tasks, if any. func() { m.pendingLock.Lock() defer m.pendingLock.Unlock() + index := 0 for _, t := range m.pendingTasks { if taskExpired(t) { - // Mark this expired task as failed. log.Info("a pending task has expired", zap.Int64("task ID", t.GetId())) - t.State.StateCode = commonpb.ImportState_ImportFailed - t.State.ErrorMessage = taskExpiredMsgPrefix + - (time.Duration(Params.RootCoordCfg.ImportTaskExpiration*1000) * time.Millisecond).String() - log.Info("releasing seg ref locks on expired import task", - zap.Int64s("segment IDs", t.GetState().GetSegments())) - err := retry.Do(m.ctx, func() error { - return releaseLockFunc(m.ctx, t.GetId(), t.GetState().GetSegments()) - }, retry.Attempts(100)) - if err != nil { - log.Error("failed to release lock, about to panic!") - panic(err) - } - m.updateImportTaskStore(t) + } else { + // Only keep non-expired tasks in memory. + m.pendingTasks[index] = t + index++ } } + // To prevent memory leak. + for i := index; i < len(m.pendingTasks); i++ { + m.pendingTasks[i] = nil + } + m.pendingTasks = m.pendingTasks[:index] }() // Expire old working tasks. func() { @@ -617,26 +667,51 @@ func (m *importManager) expireOldTasks(releaseLockFunc func(context.Context, int defer m.workingLock.Unlock() for _, v := range m.workingTasks { if taskExpired(v) { - // Mark this expired task as failed. log.Info("a working task has expired", zap.Int64("task ID", v.GetId())) - v.State.StateCode = commonpb.ImportState_ImportFailed - v.State.ErrorMessage = taskExpiredMsgPrefix + - (time.Duration(Params.RootCoordCfg.ImportTaskExpiration*1000) * time.Millisecond).String() - log.Info("releasing seg ref locks on expired import task", - zap.Int64s("segment IDs", v.GetState().GetSegments())) - err := retry.Do(m.ctx, func() error { - return releaseLockFunc(m.ctx, v.GetId(), v.GetState().GetSegments()) - }, retry.Attempts(100)) - if err != nil { - log.Error("failed to release lock, about to panic!") - panic(err) - } - m.updateImportTaskStore(v) + // Unset `isImport` flag of the bulk load segments. + taskID := v.GetId() + m.workingLock.Unlock() + m.callUnsetIsImportState(taskID) + // Re-lock. + m.workingLock.Lock() + // Remove this task from memory. + delete(m.workingTasks, v.GetId()) } } }() } +// expireOldTasksFromEtcd removes tasks from Etcd that are over `ImportTaskRetention` seconds old. +func (m *importManager) expireOldTasksFromEtcd() { + var vs []string + var err error + // Collect all import task records. + if _, vs, err = m.taskStore.LoadWithPrefix(Params.RootCoordCfg.ImportTaskSubPath); err != nil { + log.Error("failed to load import tasks from Etcd during task cleanup") + return + } + // Loop through all import tasks in Etcd and look for the ones that have passed retention period. 
+ for _, val := range vs { + ti := &datapb.ImportTaskInfo{} + if err := proto.Unmarshal([]byte(val), ti); err != nil { + log.Error("failed to unmarshal proto", zap.String("taskInfo", val), zap.Error(err)) + // Ignore bad protos. This is just a cleanup task, so we are not panicking. + continue + } + if taskPastRetention(ti) { + log.Info("an import task has passed retention period and will be removed from Etcd", + zap.Int64("task ID", ti.GetId())) + // Unset `isImport` flag of the bulk load segments. + m.callUnsetIsImportState(ti.GetId()) + if err = m.yieldTaskInfo(ti.GetId()); err != nil { + log.Error("failed to remove import task from Etcd", + zap.Int64("task ID", ti.GetId()), + zap.Error(err)) + } + } + } +} + func rearrangeTasks(tasks []*milvuspb.GetImportStateResponse) { sort.Slice(tasks, func(i, j int) bool { return tasks[i].GetId() < tasks[j].GetId() @@ -650,18 +725,14 @@ func (m *importManager) listAllTasks() []*milvuspb.GetImportStateResponse { m.pendingLock.Lock() defer m.pendingLock.Unlock() for _, t := range m.pendingTasks { - resp := &milvuspb.GetImportStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Infos: make([]*commonpb.KeyValuePair, 0), - Id: t.GetId(), - State: commonpb.ImportState_ImportPending, - DataQueryable: t.GetDataQueryable(), - DataIndexed: t.GetDataIndexed(), - } - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(t.GetFiles(), ",")}) - m.getCollectionPartitionName(t, resp) + resp := &milvuspb.GetImportStateResponse{} + m.copyTaskInfo(t, resp) + + // Release lock early. + m.pendingLock.Unlock() + + // Re-lock. + m.pendingLock.Lock() tasks = append(tasks, resp) } log.Info("tasks in pending list", zap.Int("count", len(m.pendingTasks))) @@ -671,24 +742,14 @@ func (m *importManager) listAllTasks() []*milvuspb.GetImportStateResponse { m.workingLock.Lock() defer m.workingLock.Unlock() for _, v := range m.workingTasks { - resp := &milvuspb.GetImportStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Infos: make([]*commonpb.KeyValuePair, 0), - Id: v.GetId(), - State: v.GetState().GetStateCode(), - RowCount: v.GetState().GetRowCount(), - IdList: v.GetState().GetRowIds(), - DataQueryable: v.GetDataQueryable(), - DataIndexed: v.GetDataIndexed(), - } - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{Key: Files, Value: strings.Join(v.GetFiles(), ",")}) - resp.Infos = append(resp.Infos, &commonpb.KeyValuePair{ - Key: FailedReason, - Value: v.GetState().GetErrorMessage(), - }) - m.getCollectionPartitionName(v, resp) + resp := &milvuspb.GetImportStateResponse{} + m.copyTaskInfo(v, resp) + + // Release lock early. + m.workingLock.Unlock() + + // Re-lock. + m.workingLock.Lock() tasks = append(tasks, resp) } log.Info("tasks in working list", zap.Int("count", len(m.workingTasks))) @@ -703,10 +764,49 @@ func BuildImportTaskKey(taskID int64) string { return fmt.Sprintf("%s%s%d", Params.RootCoordCfg.ImportTaskSubPath, delimiter, taskID) } -// taskExpired returns true if the task is considered expired. +// taskExpired returns true if the in-mem task is considered expired. 
func taskExpired(ti *datapb.ImportTaskInfo) bool { - return ti.GetState().GetStateCode() != commonpb.ImportState_ImportFailed && - ti.GetState().GetStateCode() != commonpb.ImportState_ImportPersisted && - ti.GetState().GetStateCode() != commonpb.ImportState_ImportCompleted && - Params.RootCoordCfg.ImportTaskExpiration <= float64(time.Now().Unix()-ti.GetCreateTs()) + return Params.RootCoordCfg.ImportTaskExpiration <= float64(time.Now().Unix()-ti.GetCreateTs()) +} + +// taskPastRetention returns true if the task is considered expired in Etcd. +func taskPastRetention(ti *datapb.ImportTaskInfo) bool { + return Params.RootCoordCfg.ImportTaskRetention <= float64(time.Now().Unix()-ti.GetCreateTs()) +} + +func (m *importManager) GetImportFailedSegmentIDs() ([]int64, error) { + ret := make([]int64, 0) + m.pendingLock.RLock() + for _, importTaskInfo := range m.pendingTasks { + if importTaskInfo.State.StateCode == commonpb.ImportState_ImportFailed { + ret = append(ret, importTaskInfo.State.Segments...) + } + } + m.pendingLock.RUnlock() + m.workingLock.RLock() + for _, importTaskInfo := range m.workingTasks { + if importTaskInfo.State.StateCode == commonpb.ImportState_ImportFailed { + ret = append(ret, importTaskInfo.State.Segments...) + } + } + m.workingLock.RUnlock() + return ret, nil +} + +func cloneImportTaskInfo(taskInfo *datapb.ImportTaskInfo) *datapb.ImportTaskInfo { + cloned := &datapb.ImportTaskInfo{ + Id: taskInfo.GetId(), + DatanodeId: taskInfo.GetDatanodeId(), + CollectionId: taskInfo.GetCollectionId(), + PartitionId: taskInfo.GetPartitionId(), + ChannelNames: taskInfo.GetChannelNames(), + Bucket: taskInfo.GetBucket(), + RowBased: taskInfo.GetRowBased(), + Files: taskInfo.GetFiles(), + CreateTs: taskInfo.GetCreateTs(), + State: taskInfo.GetState(), + CollectionName: taskInfo.GetCollectionName(), + PartitionName: taskInfo.GetPartitionName(), + } + return cloned } diff --git a/internal/rootcoord/import_manager_test.go b/internal/rootcoord/import_manager_test.go index 77857abdf1ce6..bce9e7fbb1d4f 100644 --- a/internal/rootcoord/import_manager_test.go +++ b/internal/rootcoord/import_manager_test.go @@ -49,54 +49,66 @@ func TestImportManager_NewImportManager(t *testing.T) { return globalCount, 0, nil } Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" - Params.RootCoordCfg.ImportTaskExpiration = 100 + Params.RootCoordCfg.ImportTaskExpiration = 50 + Params.RootCoordCfg.ImportTaskRetention = 200 checkPendingTasksInterval = 100 expireOldTasksInterval = 100 mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) + mockKv.InMemKv = sync.Map{} ti1 := &datapb.ImportTaskInfo{ Id: 100, State: &datapb.ImportTaskState{ - StateCode: commonpb.ImportState_ImportPending, + StateCode: commonpb.ImportState_ImportStarted, }, + CreateTs: time.Now().Unix() - 100, } ti2 := &datapb.ImportTaskInfo{ Id: 200, State: &datapb.ImportTaskState{ StateCode: commonpb.ImportState_ImportPersisted, }, + CreateTs: time.Now().Unix() - 100, } taskInfo1, err := proto.Marshal(ti1) assert.NoError(t, err) taskInfo2, err := proto.Marshal(ti2) assert.NoError(t, err) - mockKv.SaveWithLease(BuildImportTaskKey(1), "value", 1) - mockKv.SaveWithLease(BuildImportTaskKey(2), string(taskInfo1), 2) - mockKv.SaveWithLease(BuildImportTaskKey(3), string(taskInfo2), 3) - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + mockKv.Save(BuildImportTaskKey(1), "value") + mockKv.Save(BuildImportTaskKey(100), string(taskInfo1)) + mockKv.Save(BuildImportTaskKey(200), string(taskInfo2)) + + 
mockCallImportServiceErr := false + callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + if mockCallImportServiceErr { + return &datapb.ImportTaskResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + }, errors.New("mock err") + } return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil + } + callUnsetImportingState := func(taskID int64) error { + return nil } - - time.Sleep(1 * time.Second) - var wg sync.WaitGroup wg.Add(1) t.Run("working task expired", func(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - mgr := newImportManager(ctx, mockKv, idAlloc, fn, nil) + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) assert.NotNil(t, mgr) - mgr.init(ctx) + assert.NoError(t, mgr.loadFromTaskStore()) var wgLoop sync.WaitGroup wgLoop.Add(2) - mgr.expireOldTasksLoop(&wgLoop, func(ctx context.Context, int64 int64, int64s []int64) error { - return nil - }) + assert.Equal(t, 2, len(mgr.workingTasks)) + mgr.expireOldTasksLoop(&wgLoop) + assert.Equal(t, 0, len(mgr.workingTasks)) mgr.sendOutTasksLoop(&wgLoop) wgLoop.Wait() }) @@ -106,38 +118,98 @@ func TestImportManager_NewImportManager(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - mgr := newImportManager(ctx, mockKv, idAlloc, fn, nil) + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) assert.NotNil(t, mgr) mgr.init(context.TODO()) var wgLoop sync.WaitGroup wgLoop.Add(2) - mgr.expireOldTasksLoop(&wgLoop, func(ctx context.Context, int64 int64, int64s []int64) error { - return nil - }) + mgr.expireOldTasksLoop(&wgLoop) mgr.sendOutTasksLoop(&wgLoop) wgLoop.Wait() }) + wg.Add(1) + t.Run("importManager init fail because of loadFromTaskStore fail", func(t *testing.T) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) + mockKv.LoadWithPrefixMockErr = true + defer func() { + mockKv.LoadWithPrefixMockErr = false + }() + assert.NotNil(t, mgr) + assert.Panics(t, func() { + mgr.init(context.TODO()) + }) + }) + + wg.Add(1) + t.Run("sendOutTasks fail", func(t *testing.T) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) + mockKv.SaveMockErr = true + defer func() { + mockKv.SaveMockErr = false + }() + assert.NotNil(t, mgr) + mgr.init(context.TODO()) + }) + + wg.Add(1) + t.Run("sendOutTasks fail", func(t *testing.T) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) + assert.NotNil(t, mgr) + mgr.init(context.TODO()) + func() { + mockKv.SaveMockErr = true + defer func() { + mockKv.SaveMockErr = false + }() + mgr.sendOutTasks(context.TODO()) + }() + + func() { + mockCallImportServiceErr = true + defer func() { + mockKv.SaveMockErr = false + }() + mgr.sendOutTasks(context.TODO()) + }() + }) + wg.Add(1) t.Run("pending task expired", func(t *testing.T) { defer wg.Done() ctx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - mgr := newImportManager(ctx, mockKv, idAlloc, fn, nil) + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) assert.NotNil(t, mgr) mgr.pendingTasks = append(mgr.pendingTasks, &datapb.ImportTaskInfo{ Id: 300, State: &datapb.ImportTaskState{ StateCode: commonpb.ImportState_ImportPending, }, - CreateTs: time.Now().Unix() - 10, + CreateTs: time.Now().Unix() + 1, }) - mgr.loadFromTaskStore() + mgr.pendingTasks = append(mgr.pendingTasks, &datapb.ImportTaskInfo{ + Id: 400, + State: &datapb.ImportTaskState{ + StateCode: commonpb.ImportState_ImportPending, + }, + CreateTs: time.Now().Unix() - 100, + }) + assert.NoError(t, mgr.loadFromTaskStore()) var wgLoop sync.WaitGroup wgLoop.Add(2) - mgr.expireOldTasksLoop(&wgLoop, func(ctx context.Context, int64 int64, int64s []int64) error { - return nil - }) + assert.Equal(t, 2, len(mgr.pendingTasks)) + mgr.expireOldTasksLoop(&wgLoop) + assert.Equal(t, 1, len(mgr.pendingTasks)) mgr.sendOutTasksLoop(&wgLoop) wgLoop.Wait() }) @@ -147,22 +219,104 @@ func TestImportManager_NewImportManager(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - mgr := newImportManager(ctx, mockKv, idAlloc, fn, nil) + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) assert.NotNil(t, mgr) mgr.init(ctx) var wgLoop sync.WaitGroup wgLoop.Add(2) - mgr.expireOldTasksLoop(&wgLoop, func(ctx context.Context, int64 int64, int64s []int64) error { - return nil - }) + mgr.expireOldTasksLoop(&wgLoop) mgr.sendOutTasksLoop(&wgLoop) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) wgLoop.Wait() }) wg.Wait() } +func TestImportManager_TestEtcdCleanUp(t *testing.T) { + var countLock sync.RWMutex + var globalCount = typeutil.UniqueID(0) + + var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + countLock.Lock() + defer countLock.Unlock() + globalCount++ + return globalCount, 0, nil + } + Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" + Params.RootCoordCfg.ImportTaskExpiration = 50 + Params.RootCoordCfg.ImportTaskRetention = 200 + checkPendingTasksInterval = 100 + expireOldTasksInterval = 100 + mockKv := &kv.MockMetaKV{} + mockKv.InMemKv = sync.Map{} + ti1 := &datapb.ImportTaskInfo{ + Id: 100, + State: &datapb.ImportTaskState{ + StateCode: commonpb.ImportState_ImportPending, + }, + CreateTs: time.Now().Unix() - 500, + } + ti2 := &datapb.ImportTaskInfo{ + Id: 200, + State: &datapb.ImportTaskState{ + StateCode: commonpb.ImportState_ImportPersisted, + }, + CreateTs: time.Now().Unix() - 500, + } + ti3 := &datapb.ImportTaskInfo{ + Id: 300, + State: &datapb.ImportTaskState{ + StateCode: commonpb.ImportState_ImportPersisted, + }, + CreateTs: time.Now().Unix() - 100, + } + taskInfo3, err := proto.Marshal(ti3) + assert.NoError(t, err) + taskInfo1, err := proto.Marshal(ti1) + assert.NoError(t, err) + taskInfo2, err := proto.Marshal(ti2) + assert.NoError(t, err) + mockKv.Save(BuildImportTaskKey(100), string(taskInfo1)) + mockKv.Save(BuildImportTaskKey(200), string(taskInfo2)) + mockKv.Save(BuildImportTaskKey(300), string(taskInfo3)) + + mockCallImportServiceErr := false + callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + if mockCallImportServiceErr { + return &datapb.ImportTaskResponse{ + Status: &commonpb.Status{ + ErrorCode: 
commonpb.ErrorCode_Success, + }, + }, errors.New("mock err") + } + return &datapb.ImportTaskResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + }, nil + } + + callUnsetImportingState := func(taskID int64) error { + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + mgr := newImportManager(ctx, mockKv, idAlloc, callImportServiceFn, callUnsetImportingState, nil) + assert.NotNil(t, mgr) + assert.NoError(t, mgr.loadFromTaskStore()) + var wgLoop sync.WaitGroup + wgLoop.Add(2) + keys, _, _ := mockKv.LoadWithPrefix("") + // All 3 tasks are stored in Etcd. + assert.Equal(t, 3, len(keys)) + mgr.expireOldTasksLoop(&wgLoop) + keys, _, _ = mockKv.LoadWithPrefix("") + // task 1 and task 2 have passed retention period. + assert.Equal(t, 1, len(keys)) + mgr.sendOutTasksLoop(&wgLoop) +} + func TestImportManager_ImportJob(t *testing.T) { var countLock sync.RWMutex var globalCount = typeutil.UniqueID(0) @@ -176,8 +330,11 @@ func TestImportManager_ImportJob(t *testing.T) { Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" colID := int64(100) mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) - mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, nil) + mockKv.InMemKv = sync.Map{} + callUnsetImportingState := func(taskID int64) error { + return nil + } + mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, callUnsetImportingState, nil) resp := mgr.importJob(context.TODO(), nil, colID, 0) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) @@ -204,60 +361,60 @@ func TestImportManager_ImportJob(t *testing.T) { }, } - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, - } + }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) assert.Equal(t, 1, len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) - fn = func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks)) - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) count := 0 - fn = 
func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + fn = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if count >= 2 { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, - } + }, nil } count++ return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil } - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, len(rowReq.Files)-2, len(mgr.pendingTasks)) assert.Equal(t, 2, len(mgr.workingTasks)) @@ -267,6 +424,10 @@ func TestImportManager_ImportJob(t *testing.T) { } resp = mgr.importJob(context.TODO(), rowReq, colID, 0) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + + segIDs, err := mgr.GetImportFailedSegmentIDs() + assert.True(t, len(segIDs) >= 0) + assert.Nil(t, err) } func TestImportManager_AllDataNodesBusy(t *testing.T) { @@ -282,7 +443,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" colID := int64(100) mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) + mockKv.InMemKv = sync.Map{} rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", PartitionName: "p1", @@ -304,7 +465,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { dnList := []int64{1, 2, 3} count := 0 - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if count < len(dnList) { count++ return &datapb.ImportTaskResponse{ @@ -312,28 +473,31 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { ErrorCode: commonpb.ErrorCode_Success, }, DatanodeId: dnList[count-1], - } + }, nil } return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, - } + }, nil } - mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + callUnsetImportingState := func(taskID int64) error { + return nil + } + mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks)) - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) mgr.importJob(context.TODO(), rowReq, colID, 0) assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) // Reset count. 
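Editor's note: the mocked callImportService in these tests now returns an error alongside the response, so a transport-level failure can be simulated independently of the status carried in the reply. A minimal, self-contained sketch of that callback shape follows; stubResponse and importServiceFn are illustrative stand-ins, not the real datapb/commonpb types.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// stubResponse stands in for the real import-task response message;
// it only carries a success flag for this illustration.
type stubResponse struct {
	success bool
}

// importServiceFn mirrors the (response, error) callback shape used in the
// tests above: the call itself can fail even when a response body exists.
type importServiceFn func(ctx context.Context, taskID int64) (*stubResponse, error)

func main() {
	failRPC := false
	call := importServiceFn(func(ctx context.Context, taskID int64) (*stubResponse, error) {
		if failRPC {
			// Simulate a transport-level failure; the response may still be non-nil.
			return &stubResponse{success: true}, errors.New("mock err")
		}
		return &stubResponse{success: true}, nil
	})

	if resp, err := call(context.Background(), 1); err == nil {
		fmt.Println("task accepted:", resp.success)
	}

	failRPC = true
	if _, err := call(context.Background(), 1); err != nil {
		fmt.Println("call failed:", err) // the caller can now distinguish RPC errors from bad statuses
	}
}
```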
count = 0 - mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + mgr = newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) mgr.importJob(context.TODO(), colReq, colID, 0) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) @@ -364,13 +528,13 @@ func TestImportManager_TaskState(t *testing.T) { Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" colID := int64(100) mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + mockKv.InMemKv = sync.Map{} + fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil } rowReq := &milvuspb.ImportRequest{ @@ -380,7 +544,11 @@ func TestImportManager_TaskState(t *testing.T) { Files: []string{"f1", "f2", "f3"}, } - mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + callUnsetImportingState := func(taskID int64) error { + return nil + } + + mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) mgr.importJob(context.TODO(), rowReq, colID, 0) state := &rootcoordpb.ImportResult{ @@ -392,7 +560,7 @@ func TestImportManager_TaskState(t *testing.T) { state = &rootcoordpb.ImportResult{ TaskId: 2, RowCount: 1000, - State: commonpb.ImportState_ImportCompleted, + State: commonpb.ImportState_ImportPersisted, Infos: []*commonpb.KeyValuePair{ { Key: "key1", @@ -412,7 +580,7 @@ func TestImportManager_TaskState(t *testing.T) { assert.Equal(t, int64(0), ti.GetPartitionId()) assert.Equal(t, true, ti.GetRowBased()) assert.Equal(t, []string{"f2"}, ti.GetFiles()) - assert.Equal(t, commonpb.ImportState_ImportCompleted, ti.GetState().GetStateCode()) + assert.Equal(t, commonpb.ImportState_ImportPersisted, ti.GetState().GetStateCode()) assert.Equal(t, int64(1000), ti.GetState().GetRowCount()) resp := mgr.getTaskState(10000) @@ -420,11 +588,11 @@ func TestImportManager_TaskState(t *testing.T) { resp = mgr.getTaskState(2) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - assert.Equal(t, commonpb.ImportState_ImportCompleted, resp.State) + assert.Equal(t, commonpb.ImportState_ImportPersisted, resp.State) resp = mgr.getTaskState(1) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - assert.Equal(t, commonpb.ImportState_ImportPending, resp.State) + assert.Equal(t, commonpb.ImportState_ImportStarted, resp.State) } func TestImportManager_AllocFail(t *testing.T) { @@ -434,13 +602,13 @@ func TestImportManager_AllocFail(t *testing.T) { Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" colID := int64(100) mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + mockKv.InMemKv = sync.Map{} + fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil } rowReq := &milvuspb.ImportRequest{ @@ -450,7 +618,10 @@ func TestImportManager_AllocFail(t *testing.T) { Files: []string{"f1", "f2", "f3"}, } - mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + callUnsetImportingState := func(taskID int64) error { + return nil + } + mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, 
callUnsetImportingState, nil) mgr.importJob(context.TODO(), rowReq, colID, 0) } @@ -468,15 +639,15 @@ func TestImportManager_ListAllTasks(t *testing.T) { Params.RootCoordCfg.ImportTaskSubPath = "test_import_task" colID := int64(100) mockKv := &kv.MockMetaKV{} - mockKv.InMemKv = make(map[string]string) + mockKv.InMemKv = sync.Map{} // reject some tasks so there are 3 tasks left in pending list - fn := func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + fn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, - } + }, nil } rowReq := &milvuspb.ImportRequest{ @@ -485,8 +656,10 @@ func TestImportManager_ListAllTasks(t *testing.T) { RowBased: true, Files: []string{"f1", "f2", "f3"}, } - - mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, nil) + callUnsetImportingState := func(taskID int64) error { + return nil + } + mgr := newImportManager(context.TODO(), mockKv, idAlloc, fn, callUnsetImportingState, nil) mgr.importJob(context.TODO(), rowReq, colID, 0) tasks := mgr.listAllTasks() @@ -498,12 +671,12 @@ func TestImportManager_ListAllTasks(t *testing.T) { assert.Equal(t, int64(1), resp.Id) // accept tasks to working list - mgr.callImportService = func(ctx context.Context, req *datapb.ImportTaskRequest) *datapb.ImportTaskResponse { + mgr.callImportService = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - } + }, nil } mgr.importJob(context.TODO(), rowReq, colID, 0) @@ -521,23 +694,31 @@ func TestImportManager_ListAllTasks(t *testing.T) { assert.Equal(t, 0, len(ids)) } -func TestImportManager_getCollectionPartitionName(t *testing.T) { +func TestImportManager_setCollectionPartitionName(t *testing.T) { mgr := &importManager{ getCollectionName: func(collID, partitionID typeutil.UniqueID) (string, string, error) { - return "c1", "p1", nil + if collID == 1 && partitionID == 2 { + return "c1", "p1", nil + } else { + return "", "", errors.New("Error") + } }, } - task := &datapb.ImportTaskInfo{ - CollectionId: 1, - PartitionId: 2, - } - resp := &milvuspb.GetImportStateResponse{ - Infos: make([]*commonpb.KeyValuePair, 0), + info := &datapb.ImportTaskInfo{ + Id: 100, + State: &datapb.ImportTaskState{ + StateCode: commonpb.ImportState_ImportStarted, + }, + CreateTs: time.Now().Unix() - 100, } - mgr.getCollectionPartitionName(task, resp) - assert.Equal(t, "c1", resp.Infos[0].Value) - assert.Equal(t, "p1", resp.Infos[1].Value) + err := mgr.setCollectionPartitionName(1, 2, info) + assert.Nil(t, err) + assert.Equal(t, "c1", info.GetCollectionName()) + assert.Equal(t, "p1", info.GetPartitionName()) + + err = mgr.setCollectionPartitionName(0, 0, info) + assert.Error(t, err) } func TestImportManager_rearrangeTasks(t *testing.T) { diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index b574e5f3ea7ce..effd2a0579b39 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -223,6 +223,183 @@ func (c *Core) tsLoop() { } } +func (c *Core) checkFlushedSegmentsLoop() { + defer c.wg.Done() + ticker := time.NewTicker(10 * time.Minute) + for { + select { + case <-c.ctx.Done(): + log.Debug("RootCoord context done, exit check FlushedSegmentsLoop") + return + case <-ticker.C: + log.Debug("check flushed segments") + 
c.checkFlushedSegments(c.ctx) + } + } +} + +func (c *Core) recycleDroppedIndex() { + defer c.wg.Done() + ticker := time.NewTicker(3 * time.Second) + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + droppedIndex := c.MetaTable.GetDroppedIndex() + for collID, fieldIndexes := range droppedIndex { + for _, fieldIndex := range fieldIndexes { + indexID := fieldIndex.GetIndexID() + fieldID := fieldIndex.GetFiledID() + if err := c.CallDropIndexService(c.ctx, indexID); err != nil { + log.Warn("Notify IndexCoord to drop index failed, wait to retry", zap.Int64("collID", collID), + zap.Int64("fieldID", fieldID), zap.Int64("indexID", indexID)) + } + } + } + err := c.MetaTable.RecycleDroppedIndex() + if err != nil { + log.Warn("Remove index meta failed, wait to retry", zap.Error(err)) + } + } + } +} + +func (c *Core) createIndexForSegment(ctx context.Context, collID, partID, segID UniqueID, numRows int64, binlogs []*datapb.FieldBinlog) error { + log.Info("creating index for segment", + zap.Int64("collection ID", collID), + zap.Int64("partition ID", partID), + zap.Int64("segment ID", segID), + zap.Int64("# of rows", numRows)) + collID2Meta, _, indexID2Meta := c.MetaTable.dupMeta() + collMeta, ok := collID2Meta[collID] + if !ok { + log.Error("collection meta is not exist", zap.Int64("collID", collID)) + return fmt.Errorf("collection meta is not exist with ID = %d", collID) + } + if len(collMeta.FieldIndexes) == 0 { + log.Info("collection has no index, no need to build index on segment", zap.Int64("collID", collID), + zap.Int64("segID", segID)) + return nil + } + for _, fieldIndex := range collMeta.FieldIndexes { + indexMeta, ok := indexID2Meta[fieldIndex.IndexID] + if !ok { + log.Warn("index has no meta", zap.Int64("collID", collID), zap.Int64("indexID", fieldIndex.IndexID)) + return fmt.Errorf("index has no meta with ID = %d in collection %d", fieldIndex.IndexID, collID) + } + if indexMeta.Deleted { + log.Info("index has been deleted, no need to build index on segment") + continue + } + + field, err := GetFieldSchemaByID(&collMeta, fieldIndex.FiledID) + if err != nil { + log.Error("GetFieldSchemaByID failed", + zap.Int64("collectionID", collID), + zap.Int64("fieldID", fieldIndex.FiledID), zap.Error(err)) + return err + } + if c.MetaTable.IsSegmentIndexed(segID, field, indexMeta.IndexParams) { + continue + } + createTS, err := c.TSOAllocator(1) + if err != nil { + log.Error("RootCoord alloc timestamp failed", zap.Int64("collectionID", collID), zap.Error(err)) + return err + } + + segIndexInfo := etcdpb.SegmentIndexInfo{ + CollectionID: collMeta.ID, + PartitionID: partID, + SegmentID: segID, + FieldID: fieldIndex.FiledID, + IndexID: fieldIndex.IndexID, + EnableIndex: false, + CreateTime: createTS, + } + buildID, err := c.BuildIndex(ctx, segID, numRows, binlogs, field, &indexMeta, false) + if err != nil { + log.Debug("build index failed", + zap.Int64("segmentID", segID), + zap.Int64("fieldID", field.FieldID), + zap.Int64("indexID", indexMeta.IndexID)) + return err + } + // if buildID == 0, means it's no need to build index. 
+ if buildID != 0 { + log.Debug("index build triggered for segment", zap.Int64("segment ID", segID)) + segIndexInfo.BuildID = buildID + segIndexInfo.EnableIndex = true + } + + if err := c.MetaTable.AddIndex(&segIndexInfo); err != nil { + log.Error("Add index into meta table failed, need to remove index with buildID", + zap.Int64("collectionID", collID), zap.Int64("indexID", fieldIndex.IndexID), + zap.Int64("buildID", buildID), zap.Error(err)) + if err = retry.Do(ctx, func() error { + return c.CallRemoveIndexService(ctx, []UniqueID{buildID}) + }); err != nil { + log.Error("remove index failed, need to be resolved manually", zap.Int64("collectionID", collID), + zap.Int64("indexID", fieldIndex.IndexID), zap.Int64("buildID", buildID), zap.Error(err)) + return err + } + return err + } + } + log.Info("successfully built index on segment", + zap.Int64("collection ID", collID), + zap.Int64("partition ID", partID), + zap.Int64("segment ID", segID), + zap.Int64("# of rows", numRows)) + return nil +} + +func (c *Core) checkFlushedSegments(ctx context.Context) { + collID2Meta := c.MetaTable.dupCollectionMeta() + for collID, collMeta := range collID2Meta { + if len(collMeta.FieldIndexes) == 0 { + continue + } + for _, partID := range collMeta.PartitionIDs { + segBinlogs, err := c.CallGetRecoveryInfoService(ctx, collMeta.ID, partID) + if err != nil { + log.Debug("failed to get flushed segments from dataCoord", + zap.Int64("collection ID", collMeta.GetID()), + zap.Int64("partition ID", partID), + zap.Error(err)) + continue + } + segIDs := make(map[UniqueID]struct{}) + for _, segBinlog := range segBinlogs { + segIDs[segBinlog.GetSegmentID()] = struct{}{} + err = c.createIndexForSegment(ctx, collID, partID, segBinlog.GetSegmentID(), segBinlog.GetNumOfRows(), segBinlog.GetFieldBinlogs()) + if err != nil { + log.Error("createIndexForSegment failed, wait to retry", zap.Int64("collID", collID), + zap.Int64("partID", partID), zap.Int64("segID", segBinlog.GetSegmentID()), zap.Error(err)) + continue + } + } + recycledSegIDs, recycledBuildIDs := c.MetaTable.AlignSegmentsMeta(collID, partID, segIDs) + log.Info("indexes with these buildIDs should be removed", zap.Int64s("buildIDs", recycledBuildIDs)) + if len(recycledBuildIDs) > 0 { + if err := c.CallRemoveIndexService(ctx, recycledBuildIDs); err != nil { + log.Error("CallRemoveIndexService remove indexes on segments failed", + zap.Int64s("need dropped buildIDs", recycledBuildIDs), zap.Error(err)) + continue + } + } + + if err := c.MetaTable.RemoveSegments(collID, partID, recycledSegIDs); err != nil { + log.Warn("remove segments failed, wait to retry", zap.Int64("collID", collID), zap.Int64("partID", partID), + zap.Int64s("segIDs", recycledSegIDs), zap.Error(err)) + continue + } + } + } +} + func (c *Core) SetNewProxyClient(f func(sess *sessionutil.Session) (types.Proxy, error)) { c.proxyCreator = f } @@ -611,11 +788,13 @@ func (c *Core) startInternal() error { } c.wg.Add(5) - go c.tsLoop() go c.startTimeTickLoop() + go c.tsLoop() go c.chanTimeTick.startWatch(&c.wg) - go c.importManager.expireOldTasksLoop(&c.wg, c.broker.ReleaseSegRefLock) + go c.importManager.expireOldTasksLoop(&c.wg) go c.importManager.sendOutTasksLoop(&c.wg) + Params.RootCoordCfg.CreatedTime = time.Now() + Params.RootCoordCfg.UpdatedTime = time.Now() if Params.QuotaConfig.EnableQuotaAndLimits { go c.quotaCenter.run() @@ -1586,21 +1765,16 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( if code, ok := c.checkHealthy(); !ok { return
failStatus(commonpb.ErrorCode_UnexpectedError, "StateCode="+internalpb.StateCode_name[int32(code)]), nil } - // Special case for ImportState_ImportAllocSegment state, where we shall only add segment ref lock and do no other - // operations. - // TODO: This is inelegant and must get re-structured. - if ir.GetState() == commonpb.ImportState_ImportAllocSegment { - // Lock the segments, so we don't lose track of them when compaction happens. - // Note that these locks will be unlocked in c.postImportPersistLoop() -> checkSegmentLoadedLoop(). - if err := c.broker.AddSegRefLock(ctx, ir.GetTaskId(), ir.GetSegments()); err != nil { - log.Error("failed to acquire segment ref lock", zap.Error(err)) + // If setting ImportState_ImportCompleted, simply update the state and return directly. + if ir.GetState() == commonpb.ImportState_ImportCompleted { + if err := c.importManager.setCompleteImportState(ir.GetTaskId()); err != nil { + errMsg := "failed to set import task as ImportState_ImportCompleted" + log.Error(errMsg, zap.Error(err)) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: fmt.Sprintf("failed to acquire segment ref lock %s", err.Error()), + Reason: fmt.Sprintf("%s %s", errMsg, err.Error()), }, nil } - // Update task store with new segments. - c.importManager.appendTaskSegments(ir.GetTaskId(), ir.GetSegments()) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil @@ -1625,20 +1799,15 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( zap.Int64("task ID", ir.GetTaskId())) }() - c.importManager.sendOutTasks(c.importManager.ctx) + err := c.importManager.sendOutTasks(c.importManager.ctx) + if err != nil { + log.Error("fail to send out import task to datanodes") + } } // If task failed, send task to idle datanode if ir.GetState() == commonpb.ImportState_ImportFailed { - // Release segments when task fails. - log.Info("task failed, release segment ref locks") - err := retry.Do(ctx, func() error { - return c.broker.ReleaseSegRefLock(ctx, ir.GetTaskId(), ir.GetSegments()) - }, retry.Attempts(100)) - if err != nil { - log.Error("failed to release lock, about to panic!") - panic(err) - } + log.Info("task failed, resending import task") resendTaskFunc() } @@ -1652,168 +1821,17 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( }, nil } - // Look up collection name on collection ID. - var colName string - var colMeta *model.Collection - if colMeta, err = c.meta.GetCollectionByID(ctx, ti.GetCollectionId(), typeutil.MaxTimestamp); err != nil { - log.Error("failed to get collection name", - zap.Int64("collection ID", ti.GetCollectionId()), - zap.Error(err)) - // In some unexpected cases, user drop collection when bulkload task still in pending list, the datanode become idle. - // If we directly return, the pending tasks will remain in pending list. So we call resendTaskFunc() to push next pending task to idle datanode. - resendTaskFunc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, - Reason: "failed to get collection name for collection ID" + strconv.FormatInt(ti.GetCollectionId(), 10), - }, nil - } - colName = colMeta.Name - // When DataNode has done its thing, remove it from the busy node list. And send import task again resendTaskFunc() // Flush all import data segments. c.broker.Flush(ctx, ti.GetCollectionId(), ir.GetSegments()) - // Check if data are "queryable" and if indices are built on all segments. 
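Editor's note: the ReportImport changes above reduce the handler to a small dispatch on the reported state: a completed task is simply marked complete, a failed task is resent to an idle DataNode, and a persisted task is resent and its segments flushed. A rough standalone sketch of that control flow follows; the state constants and callbacks are hypothetical stand-ins, not the real rootcoordpb/commonpb definitions.

```go
package main

import (
	"errors"
	"fmt"
)

// importState is a stand-in for the real import-state enum.
type importState int

const (
	stateCompleted importState = iota
	stateFailed
	statePersisted
)

// handleReport mirrors the simplified flow: mark complete and return early,
// resend on failure, or resend plus flush once data is persisted.
func handleReport(state importState,
	markComplete func() error,
	resend func(),
	flush func()) error {
	switch state {
	case stateCompleted:
		if err := markComplete(); err != nil {
			return fmt.Errorf("failed to mark task complete: %w", err)
		}
		return nil
	case stateFailed:
		resend()
		return nil
	case statePersisted:
		resend()
		flush()
		return nil
	default:
		return errors.New("unexpected import state")
	}
}

func main() {
	err := handleReport(statePersisted,
		func() error { return nil },
		func() { fmt.Println("resend pending task to an idle DataNode") },
		func() { fmt.Println("flush import segments") })
	fmt.Println("err:", err)
}
```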
- go c.postImportPersistLoop(c.ctx, ir.GetTaskId(), ti.GetCollectionId(), colName, ir.GetSegments()) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil } -// CountCompleteIndex checks indexing status of the given segments. -// It returns an error if error occurs. It also returns a boolean indicating whether indexing is done (or if no index -// is needed). -func (c *Core) CountCompleteIndex(ctx context.Context, collectionName string, collectionID UniqueID, - allSegmentIDs []UniqueID) (bool, error) { - // Note: Index name is always Params.CommonCfg.DefaultIndexName in current Milvus designs as of today. - indexName := Params.CommonCfg.DefaultIndexName - - states, err := c.broker.GetSegmentIndexState(ctx, collectionID, indexName, allSegmentIDs) - if err != nil { - log.Error("failed to get index state in checkSegmentIndexStates", zap.Error(err)) - return false, err - } - - // Count the # of segments with finished index. - ct := 0 - for _, s := range states { - if s.State == commonpb.IndexState_Finished { - ct++ - } - } - log.Info("segment indexing state checked", - //zap.Int64s("segments checked", seg2Check), - //zap.Int("# of checked segment", len(seg2Check)), - zap.Int("# of segments with complete index", ct), - zap.String("collection name", collectionName), - zap.Int64("collection ID", collectionID), - ) - return len(allSegmentIDs) == ct, nil -} - -func (c *Core) postImportPersistLoop(ctx context.Context, taskID int64, colID int64, colName string, segIDs []UniqueID) { - // Loop and check if segments are loaded in queryNodes. - c.wg.Add(1) - go c.checkSegmentLoadedLoop(ctx, taskID, colID, segIDs) - // Check if collection has any indexed fields. If so, start a loop to check segments' index states. - if _, err := c.meta.GetCollectionByID(ctx, colID, typeutil.MaxTimestamp); err != nil { - log.Error("failed to find meta for collection", - zap.Int64("collection ID", colID), - zap.Error(err)) - } else { - log.Info("start checking index state", zap.Int64("collection ID", colID)) - c.wg.Add(1) - go c.checkCompleteIndexLoop(ctx, taskID, colID, colName, segIDs) - } -} - -// checkSegmentLoadedLoop loops and checks if all segments in `segIDs` are loaded in queryNodes. 
-func (c *Core) checkSegmentLoadedLoop(ctx context.Context, taskID int64, colID int64, segIDs []UniqueID) { - defer c.wg.Done() - ticker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportSegmentStateCheckInterval*1000) * time.Millisecond) - defer ticker.Stop() - expireTicker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportSegmentStateWaitLimit*1000) * time.Millisecond) - defer expireTicker.Stop() - defer func() { - log.Info("we are done checking segment loading state, release segment ref locks") - err := retry.Do(ctx, func() error { - return c.broker.ReleaseSegRefLock(ctx, taskID, segIDs) - }, retry.Attempts(100)) - if err != nil { - log.Error("failed to release lock, about to panic!") - panic(err) - } - }() - for { - select { - case <-c.ctx.Done(): - log.Info("(in check segment loaded loop) context done, exiting checkSegmentLoadedLoop") - return - case <-ticker.C: - resp, err := c.broker.GetQuerySegmentInfo(ctx, colID, segIDs) - log.Debug("(in check segment loaded loop)", - zap.Int64("task ID", taskID), - zap.Int64("collection ID", colID), - zap.Int64s("segment IDs expected", segIDs), - zap.Int("# of segments found", len(resp.GetInfos()))) - if err != nil { - log.Warn("(in check segment loaded loop) failed to call get segment info on queryCoord", - zap.Int64("task ID", taskID), - zap.Int64("collection ID", colID), - zap.Int64s("segment IDs", segIDs)) - } else if len(resp.GetInfos()) == len(segIDs) { - // Check if all segment info are loaded in queryNodes. - log.Info("(in check segment loaded loop) all import data segments loaded in queryNodes", - zap.Int64("task ID", taskID), - zap.Int64("collection ID", colID), - zap.Int64s("segment IDs", segIDs)) - c.importManager.setTaskDataQueryable(taskID) - return - } - case <-expireTicker.C: - log.Warn("(in check segment loaded loop) segments still not loaded after max wait time", - zap.Int64("task ID", taskID), - zap.Int64("collection ID", colID), - zap.Int64s("segment IDs", segIDs)) - return - } - } -} - -// checkCompleteIndexLoop loops and checks if all indices are built for an import task's segments. 
-func (c *Core) checkCompleteIndexLoop(ctx context.Context, taskID int64, colID int64, colName string, segIDs []UniqueID) { - defer c.wg.Done() - ticker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportIndexCheckInterval*1000) * time.Millisecond) - defer ticker.Stop() - expireTicker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportIndexWaitLimit*1000) * time.Millisecond) - defer expireTicker.Stop() - for { - select { - case <-c.ctx.Done(): - log.Info("(in check complete index loop) context done, exiting checkCompleteIndexLoop") - return - case <-ticker.C: - if done, err := c.CountCompleteIndex(ctx, colName, colID, segIDs); err == nil && done { - log.Info("(in check complete index loop) indices are built or no index needed", - zap.Int64("task ID", taskID)) - c.importManager.setTaskDataIndexed(taskID) - return - } else if err != nil { - log.Error("(in check complete index loop) an error occurs", - zap.Error(err)) - } - case <-expireTicker.C: - log.Warn("(in check complete index loop) indexing is taken too long", - zap.Int64("task ID", taskID), - zap.Int64("collection ID", colID), - zap.Int64s("segment IDs", segIDs)) - return - } - } -} - // ExpireCredCache will call invalidate credential cache func (c *Core) ExpireCredCache(ctx context.Context, username string) error { req := proxypb.InvalidateCredCacheRequest{ @@ -1992,6 +2010,196 @@ func (c *Core) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequ }, nil } +func (c *Core) GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) { + segmentIds, err := c.importManager.GetImportFailedSegmentIDs() + if err != nil { + return &internalpb.GetImportFailedSegmentIDsResponse{ + Status: failStatus(commonpb.ErrorCode_GetImportFailedSegmentsFailure, "GetImportFailedSegmentsIDFailed"+err.Error()), + }, err + } + return &internalpb.GetImportFailedSegmentIDsResponse{ + Status: succStatus(), + SegmentIDs: segmentIds, + }, nil +} + +func (c *Core) CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) { + log.Info("start checking segments index ready states", + zap.Int64("task ID", req.GetTaskID()), + zap.Int64("col ID", req.GetColID()), + zap.Int64s("segment IDs", req.GetSegIDs())) + // Look up collection name on collection ID. + var colName string + var colMeta *etcdpb.CollectionInfo + var err error + if colMeta, err = c.MetaTable.GetCollectionByID(req.GetColID(), 0); err != nil { + log.Error("failed to get collection name", + zap.Int64("collection ID", req.GetColID()), + zap.Error(err)) + // In some unexpected cases, user drop collection when bulk load task still in pending list, the datanode become idle. + // If we directly return, the pending tasks will remain in pending list. So we call resendTaskFunc() to push next pending task to idle datanode. + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, + Reason: "failed to get collection name for collection ID" + strconv.FormatInt(req.GetColID(), 10), + }, nil + } + colName = colMeta.GetSchema().GetName() + // Check if collection has any indexed fields. If so, start a loop to check segments' index states. 
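Editor's note: the loop that follows reuses the two-ticker pattern of the removed checkCompleteIndexLoop, polling every ImportIndexCheckInterval and giving up after ImportIndexWaitLimit. A standalone sketch of that pattern, with illustrative durations rather than the real config values:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// pollUntilReady checks checkFn every checkInterval and stops when it reports
// done, the context is cancelled, or waitLimit elapses. It mirrors the
// ticker/expireTicker structure used in CheckSegmentIndexReady.
func pollUntilReady(ctx context.Context, checkInterval, waitLimit time.Duration,
	checkFn func() (bool, error)) error {
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
	expire := time.NewTicker(waitLimit)
	defer expire.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			done, err := checkFn()
			if err != nil {
				// Log and keep polling; only the wait limit ends the loop on errors.
				fmt.Println("check failed, retrying:", err)
				continue
			}
			if done {
				return nil
			}
		case <-expire.C:
			return errors.New("wait limit exceeded before the check succeeded")
		}
	}
}

func main() {
	attempts := 0
	err := pollUntilReady(context.Background(), 10*time.Millisecond, time.Second,
		func() (bool, error) {
			attempts++
			return attempts >= 3, nil // pretend index building finishes on the 3rd check
		})
	fmt.Println("result:", err, "after", attempts, "checks")
}
```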
+ if len(colMeta.GetFieldIndexes()) == 0 { + log.Info("no index field found for collection", zap.Int64("collection ID", req.GetColID())) + } else { + log.Info("start checking index state", zap.Int64("collection ID", req.GetColID())) + ticker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportIndexCheckInterval*1000) * time.Millisecond) + defer ticker.Stop() + expireTicker := time.NewTicker(time.Duration(Params.RootCoordCfg.ImportIndexWaitLimit*1000) * time.Millisecond) + defer expireTicker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("(in check segment index ready loop) context done, exiting CheckSegmentIndexReady loop") + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "check complete index context done", + }, nil + case <-ticker.C: + if done, err := c.countCompleteIndex(ctx, colName, req.GetColID(), req.GetSegIDs()); err == nil && done { + log.Info("(in check segment index ready loop) indexes are built or no index needed", + zap.Int64("task ID", req.GetTaskID())) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil + } else if err != nil { + log.Error("(in check segment index ready loop) an error occurs", + zap.Error(err)) + } + case <-expireTicker.C: + log.Warn("(in check segment index ready loop) indexing is taken too long", + zap.Int64("task ID", req.GetTaskID()), + zap.Int64("collection ID", req.GetColID()), + zap.Int64s("segment IDs", req.GetSegIDs())) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "index building is taking too long", + }, nil + } + } + } + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unexpected return when checking segment index ready states", + }, nil +} + +// CountCompleteIndex checks indexing status of the given segments. +// It returns an error if error occurs. It also returns a boolean indicating whether all indexes are built on the given +// segments. +func (c *Core) countCompleteIndex(ctx context.Context, collectionName string, collectionID UniqueID, + allSegmentIDs []UniqueID) (bool, error) { + // Note: Index name is always Params.CommonCfg.DefaultIndexName in current Milvus designs as of today. + indexName := Params.CommonCfg.DefaultIndexName + + // Retrieve index status and detailed index information. + describeIndexReq := &milvuspb.DescribeIndexRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DescribeIndex, + }, + CollectionName: collectionName, + IndexName: indexName, + } + indexDescriptionResp, err := c.DescribeIndex(ctx, describeIndexReq) + if err != nil { + return false, err + } + if len(indexDescriptionResp.GetIndexDescriptions()) == 0 { + log.Info("no index needed for collection, consider indexing done", + zap.Int64("collection ID", collectionID)) + return true, nil + } + log.Debug("got index description", + zap.Any("index description", indexDescriptionResp)) + + // Check if the target index name exists. + matchIndexID := int64(-1) + foundIndexID := false + for _, desc := range indexDescriptionResp.GetIndexDescriptions() { + if desc.GetIndexName() == indexName { + matchIndexID = desc.GetIndexID() + foundIndexID = true + break + } + } + if !foundIndexID { + return false, fmt.Errorf("no index is created") + } + log.Debug("found match index ID", + zap.Int64("match index ID", matchIndexID)) + + getIndexStatesRequest := &indexpb.GetIndexStatesRequest{ + IndexBuildIDs: make([]UniqueID, 0), + } + + // Fetch index build IDs from segments. 
+ var seg2Check []UniqueID + for _, segmentID := range allSegmentIDs { + describeSegmentRequest := &milvuspb.DescribeSegmentRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DescribeSegment, + }, + CollectionID: collectionID, + SegmentID: segmentID, + } + segmentDesc, err := c.DescribeSegment(ctx, describeSegmentRequest) + if err != nil { + log.Error("failed to describe segment", zap.Error(err)) + return false, nil + } + if segmentDesc.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Error("failed to describe segment", + zap.Int64("collection ID", collectionID), + zap.Int64("segment ID", segmentID), + zap.String("error", segmentDesc.GetStatus().GetReason())) + return false, nil + } + if segmentDesc.GetIndexID() == matchIndexID { + if segmentDesc.GetEnableIndex() { + seg2Check = append(seg2Check, segmentID) + getIndexStatesRequest.IndexBuildIDs = append(getIndexStatesRequest.GetIndexBuildIDs(), segmentDesc.GetBuildID()) + } + } + } + if len(getIndexStatesRequest.GetIndexBuildIDs()) == 0 { + log.Info("no index build IDs returned, perhaps no index is needed", + zap.String("collection name", collectionName), + zap.Int64("collection ID", collectionID)) + return true, nil + } + + log.Debug("working on GetIndexState", + zap.Int("# of IndexBuildIDs", len(getIndexStatesRequest.GetIndexBuildIDs()))) + + states, err := c.CallGetIndexStatesService(ctx, getIndexStatesRequest.GetIndexBuildIDs()) + if err != nil { + log.Error("failed to get index state in checkSegmentIndexStates", zap.Error(err)) + return false, err + } + + // Count the # of segments with finished index. + ct := 0 + for _, s := range states { + if s.State == commonpb.IndexState_Finished { + ct++ + } + } + log.Info("segment indexing state checked", + zap.Int64s("segments checked", seg2Check), + zap.Int("# of checked segment", len(seg2Check)), + zap.Int("# of segments with complete index", ct), + zap.String("collection name", collectionName), + zap.Int64("collection ID", collectionID), + ) + return len(seg2Check) == ct, nil +} + // CreateRole create role // - check the node health // - check if the role is existed diff --git a/internal/storage/minio_chunk_manager.go b/internal/storage/minio_chunk_manager.go index 74f3e3a60523a..8952943f07613 100644 --- a/internal/storage/minio_chunk_manager.go +++ b/internal/storage/minio_chunk_manager.go @@ -36,6 +36,8 @@ import ( "golang.org/x/exp/mmap" ) +var CheckBucketRetryAttempts uint = 20 + // MinioChunkManager is responsible for read and write data stored in minio. type MinioChunkManager struct { *minio.Client @@ -95,7 +97,7 @@ func newMinioChunkManagerWithConfig(ctx context.Context, c *config) (*MinioChunk } return nil } - err = retry.Do(ctx, checkBucketFn, retry.Attempts(20)) + err = retry.Do(ctx, checkBucketFn, retry.Attempts(CheckBucketRetryAttempts)) if err != nil { return nil, err } diff --git a/internal/types/types.go b/internal/types/types.go index 8af8d5fa91d61..dcfb9df660b46 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -94,8 +94,8 @@ type DataNode interface { // It returns a list of segments to be sent. ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) - // AddSegment puts the given segment to current DataNode's flow graph. - AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) + // AddImportSegment puts the given import segment to current DataNode's flow graph. 
+ AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*commonpb.Status, error) } // DataNodeComponent is used by grpc server of DataNode @@ -304,9 +304,15 @@ type DataCoord interface { AcquireSegmentLock(ctx context.Context, req *datapb.AcquireSegmentLockRequest) (*commonpb.Status, error) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegmentLockRequest) (*commonpb.Status, error) - // AddSegment looks for the right DataNode given channel name, and triggers AddSegment call on that DataNode to - // add the segment into this DataNode. - AddSegment(ctx context.Context, req *datapb.AddSegmentRequest) (*commonpb.Status, error) + // SaveImportSegment saves the import segment binlog paths data and then looks for the right DataNode to add the + // segment to that DataNode. + SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) + + // CompleteBulkLoad is the DataCoord side work for a complete bulk load operation. + CompleteBulkLoad(ctx context.Context, req *datapb.CompleteBulkLoadRequest) (*commonpb.Status, error) + + // UnsetIsImportingState unsets the `isImport` state of the given segments so that they can get compacted normally. + UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) } // DataCoordComponent defines the interface of DataCoord component. @@ -732,6 +738,10 @@ type RootCoord interface { OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) + // GetImportFailedSegmentIDs get import failed segment IDs + GetImportFailedSegmentIDs(ctx context.Context, req *internalpb.GetImportFailedSegmentIDsRequest) (*internalpb.GetImportFailedSegmentIDsResponse, error) + // CheckSegmentIndexReady checks if indexes have been successfully built on the given segments. 
+ CheckSegmentIndexReady(ctx context.Context, req *internalpb.CheckSegmentIndexReadyRequest) (*commonpb.Status, error) } // RootCoordComponent is used by grpc server of RootCoord diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index eadb8f51d9d00..4351f5e14b5b5 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -19,6 +19,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/typeutil" ) @@ -171,57 +172,13 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b log.Info("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) if fileType == JSONFileExt { - err := func() error { - tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath) - - // for minio storage, chunkManager will download file into local memory - // for local storage, chunkManager open the file directly - file, err := p.chunkManager.Reader(filePath) - if err != nil { - return err - } - defer file.Close() - tr.Record("open reader") - - // report file process state - p.importResult.State = commonpb.ImportState_ImportDownloaded - p.reportFunc(p.importResult) - - // parse file - reader := bufio.NewReader(file) - parser := NewJSONParser(p.ctx, p.collectionSchema) - var consumer *JSONRowConsumer - if !onlyValidate { - flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { - p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths) - return p.callFlushFunc(fields, shardNum) - } - consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc) - } - validator := NewJSONRowValidator(p.collectionSchema, consumer) - err = parser.ParseRows(reader, validator) - if err != nil { - return err - } - - // for row-based files, auto-id is generated within JSONRowConsumer - if consumer != nil { - p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) 
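Editor's note: the import_wrapper.go hunks around here replace the long inline closures in Import() with per-format helpers (parseRowBasedJSON, parseColumnBasedJSON, parseColumnBasedNumpy). A skeletal sketch of that dispatch-by-extension structure, using hypothetical names and omitting the real parsing logic:

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// importFiles dispatches each file to a format-specific parser, mirroring the
// refactor in Import(); parseJSON and parseNumpy are placeholders, not the
// real ImportWrapper methods.
func importFiles(paths []string, parseJSON, parseNumpy func(path string) error) error {
	for _, path := range paths {
		ext := strings.ToLower(filepath.Ext(path))
		switch ext {
		case ".json":
			if err := parseJSON(path); err != nil {
				return fmt.Errorf("import error for %s: %w", path, err)
			}
		case ".npy":
			if err := parseNumpy(path); err != nil {
				return fmt.Errorf("import error for %s: %w", path, err)
			}
		default:
			// Unsupported extensions are assumed to have been rejected by an
			// earlier validation pass, so they are simply skipped here.
		}
	}
	return nil
}

func main() {
	files := []string{"rows_1.json", "field_float_vector.npy", "readme.txt"}
	err := importFiles(files,
		func(p string) error { fmt.Println("parse JSON:", p); return nil },
		func(p string) error { fmt.Println("parse numpy:", p); return nil })
	fmt.Println("err:", err)
}
```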
- } - - // report file process state - p.importResult.State = commonpb.ImportState_ImportParsed - p.reportFunc(p.importResult) - - tr.Record("parsed") - return nil - }() - + err = p.parseRowBasedJSON(filePath, onlyValidate) if err != nil { log.Error("import error: "+err.Error(), zap.String("filePath", filePath)) return err } } + // no need to check else, since the fileValidation() already do this } } else { // parse and consume column-based files @@ -269,103 +226,24 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b // parse/validate/consume data for i := 0; i < len(filePaths); i++ { filePath := filePaths[i] - fileName, fileType := getFileNameAndExt(filePath) + _, fileType := getFileNameAndExt(filePath) log.Info("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType)) if fileType == JSONFileExt { - err := func() error { - tr := timerecord.NewTimeRecorder("json column-based parser: " + filePath) - - // for minio storage, chunkManager will download file into local memory - // for local storage, chunkManager open the file directly - file, err := p.chunkManager.Reader(filePath) - if err != nil { - return err - } - defer file.Close() - tr.Record("open reader") - - // report file process state - p.importResult.State = commonpb.ImportState_ImportDownloaded - p.reportFunc(p.importResult) - - // parse file - reader := bufio.NewReader(file) - parser := NewJSONParser(p.ctx, p.collectionSchema) - var consumer *JSONColumnConsumer - if !onlyValidate { - consumer = NewJSONColumnConsumer(p.collectionSchema, combineFunc) - } - validator := NewJSONColumnValidator(p.collectionSchema, consumer) - - err = parser.ParseColumns(reader, validator) - if err != nil { - return err - } - - // report file process state - p.importResult.State = commonpb.ImportState_ImportParsed - p.reportFunc(p.importResult) - - tr.Record("parsed") - return nil - }() - + err = p.parseColumnBasedJSON(filePath, onlyValidate, combineFunc) if err != nil { log.Error("import error: "+err.Error(), zap.String("filePath", filePath)) return err } } else if fileType == NumpyFileExt { - err := func() error { - tr := timerecord.NewTimeRecorder("numpy parser: " + filePath) - - // for minio storage, chunkManager will download file into local memory - // for local storage, chunkManager open the file directly - file, err := p.chunkManager.Reader(filePath) - if err != nil { - return err - } - defer file.Close() - tr.Record("open reader") - - // report file process state - p.importResult.State = commonpb.ImportState_ImportDownloaded - p.reportFunc(p.importResult) - - var id storage.FieldID - for _, field := range p.collectionSchema.Fields { - if field.GetName() == fileName { - id = field.GetFieldID() - } - } - - // the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine - flushFunc := func(field storage.FieldData) error { - fields := make(map[storage.FieldID]storage.FieldData) - fields[id] = field - return combineFunc(fields) - } - - // for numpy file, we say the file name(without extension) is the filed name - parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc) - err = parser.Parse(file, fileName, onlyValidate) - if err != nil { - return err - } - - // report file process state - p.importResult.State = commonpb.ImportState_ImportParsed - p.reportFunc(p.importResult) - - tr.Record("parsed") - return nil - }() + err = p.parseColumnBasedNumpy(filePath, onlyValidate, combineFunc) if err != nil { log.Error("import error: 
"+err.Error(), zap.String("filePath", filePath)) return err } } + // no need to check else, since the fileValidation() already do this } // split fields data into segments @@ -379,7 +257,130 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b debug.FreeOSMemory() // report file process state p.importResult.State = commonpb.ImportState_ImportPersisted - return p.reportFunc(p.importResult) + // persist state task is valuable, retry more times in case fail this task only because of network error + reportErr := retry.Do(p.ctx, func() error { + return p.reportFunc(p.importResult) + }, retry.Attempts(10)) + if reportErr != nil { + log.Warn("fail to report import state to root coord", zap.Error(err)) + return reportErr + } + return nil +} + +func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) error { + tr := timerecord.NewTimeRecorder("json row-based parser: " + filePath) + + // for minio storage, chunkManager will download file into local memory + // for local storage, chunkManager open the file directly + file, err := p.chunkManager.Reader(filePath) + if err != nil { + return err + } + defer file.Close() + + // parse file + reader := bufio.NewReader(file) + parser := NewJSONParser(p.ctx, p.collectionSchema) + var consumer *JSONRowConsumer + if !onlyValidate { + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardID int) error { + var filePaths = []string{filePath} + p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths) + return p.callFlushFunc(fields, shardID) + } + consumer = NewJSONRowConsumer(p.collectionSchema, p.rowIDAllocator, p.shardNum, p.segmentSize, flushFunc) + } + validator := NewJSONRowValidator(p.collectionSchema, consumer) + err = parser.ParseRows(reader, validator) + if err != nil { + return err + } + + // for row-based files, auto-id is generated within JSONRowConsumer + if consumer != nil { + p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) 
+ } + + tr.Elapse("parsed") + return nil +} + +func (p *ImportWrapper) parseColumnBasedJSON(filePath string, onlyValidate bool, + combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error { + tr := timerecord.NewTimeRecorder("json column-based parser: " + filePath) + + // for minio storage, chunkManager will download file into local memory + // for local storage, chunkManager open the file directly + file, err := p.chunkManager.Reader(filePath) + if err != nil { + return err + } + defer file.Close() + + // parse file + reader := bufio.NewReader(file) + parser := NewJSONParser(p.ctx, p.collectionSchema) + var consumer *JSONColumnConsumer + if !onlyValidate { + consumer = NewJSONColumnConsumer(p.collectionSchema, combineFunc) + } + validator := NewJSONColumnValidator(p.collectionSchema, consumer) + + err = parser.ParseColumns(reader, validator) + if err != nil { + return err + } + + tr.Elapse("parsed") + return nil +} + +func (p *ImportWrapper) parseColumnBasedNumpy(filePath string, onlyValidate bool, + combineFunc func(fields map[storage.FieldID]storage.FieldData) error) error { + tr := timerecord.NewTimeRecorder("numpy parser: " + filePath) + + fileName, _ := getFileNameAndExt(filePath) + + // for minio storage, chunkManager will download file into local memory + // for local storage, chunkManager open the file directly + file, err := p.chunkManager.Reader(filePath) + if err != nil { + return err + } + defer file.Close() + + var id storage.FieldID + var found = false + for _, field := range p.collectionSchema.Fields { + if field.GetName() == fileName { + id = field.GetFieldID() + found = true + break + } + } + + // if the numpy file name is not mapping to a field name, ignore it + if !found { + return nil + } + + // the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine + flushFunc := func(field storage.FieldData) error { + fields := make(map[storage.FieldID]storage.FieldData) + fields[id] = field + return combineFunc(fields) + } + + // for numpy file, we say the file name(without extension) is the filed name + parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc) + err = parser.Parse(file, fileName, onlyValidate) + if err != nil { + return err + } + + tr.Elapse("parsed") + return nil } func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { @@ -544,11 +545,11 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.F appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) for i := 0; i < len(p.collectionSchema.Fields); i++ { schema := p.collectionSchema.Fields[i] - appendFunc := p.appendFunc(schema) - if appendFunc == nil { + appendFuncErr := p.appendFunc(schema) + if appendFuncErr == nil { return errors.New("import error: unsupported field data type") } - appendFunctions[schema.GetName()] = appendFunc + appendFunctions[schema.GetName()] = appendFuncErr } // split data into segments diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index a85659bdef373..fa283f6bea15f 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "strconv" "testing" "time" @@ -473,6 +474,7 @@ func Test_ImportColumnBased_numpy(t *testing.T) { err = cm.Write(filePath, content) assert.NoError(t, err) + 
importResult.State = commonpb.ImportState_ImportStarted wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) files = make([]string, 0) files = append(files, filePath) @@ -802,3 +804,238 @@ func Test_FileValidation(t *testing.T) { err = wrapper.fileValidation(files[:], false) assert.NotNil(t, err) } + +func Test_ReportImportFailRowBased(t *testing.T) { + f := dependency.NewDefaultFactory(true) + ctx := context.Background() + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 10001, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]}, + {"field_bool": false, "field_int8": 11, "field_int16": 102, "field_int32": 1002, "field_int64": 10002, "field_float": 3.15, "field_double": 2.56, "field_string": "hello world", "field_binary_vector": [253, 0], "field_float_vector": [2.1, 2.2, 2.3, 2.4]}, + {"field_bool": true, "field_int8": 12, "field_int16": 103, "field_int32": 1003, "field_int64": 10003, "field_float": 3.16, "field_double": 3.56, "field_string": "hello world", "field_binary_vector": [252, 0], "field_float_vector": [3.1, 3.2, 3.3, 3.4]}, + {"field_bool": false, "field_int8": 13, "field_int16": 104, "field_int32": 1004, "field_int64": 10004, "field_float": 3.17, "field_double": 4.56, "field_string": "hello world", "field_binary_vector": [251, 0], "field_float_vector": [4.1, 4.2, 4.3, 4.4]}, + {"field_bool": true, "field_int8": 14, "field_int16": 105, "field_int32": 1005, "field_int64": 10005, "field_float": 3.18, "field_double": 5.56, "field_string": "hello world", "field_binary_vector": [250, 0], "field_float_vector": [5.1, 5.2, 5.3, 5.4]} + ] + }`) + + filePath := TempFilesPath + "rows_1.json" + err = cm.Write(filePath, content) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") + + rowCount := 0 + flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error { + count := 0 + for _, data := range fields { + assert.Less(t, 0, data.RowNum()) + if count == 0 { + count = data.RowNum() + } else { + assert.Equal(t, count, data.RowNum()) + } + } + rowCount += count + return nil + } + + // success case + importResult := &rootcoordpb.ImportResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + TaskId: 1, + DatanodeId: 1, + State: commonpb.ImportState_ImportStarted, + Segments: make([]int64, 0), + AutoIds: make([]int64, 0), + RowCount: 0, + } + reportFunc := func(res *rootcoordpb.ImportResult) error { + return nil + } + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc) + files := make([]string, 0) + files = append(files, filePath) + + wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error { + return errors.New("mock error") + } + err = wrapper.Import(files, true, false) + assert.NotNil(t, err) + assert.Equal(t, 5, rowCount) + assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) +} + +func Test_ReportImportFailColumnBased_json(t *testing.T) { + f := dependency.NewDefaultFactory(true) + ctx := context.Background() + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") + + idAllocator := newIDAllocator(ctx, t) + + content := []byte(`{ + "field_bool": [true, false, 
true, true, true],
+		"field_int8": [10, 11, 12, 13, 14],
+		"field_int16": [100, 101, 102, 103, 104],
+		"field_int32": [1000, 1001, 1002, 1003, 1004],
+		"field_int64": [10000, 10001, 10002, 10003, 10004],
+		"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
+		"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
+		"field_string": ["a", "b", "c", "d", "e"],
+		"field_binary_vector": [
+			[254, 1],
+			[253, 2],
+			[252, 3],
+			[251, 4],
+			[250, 5]
+		],
+		"field_float_vector": [
+			[1.1, 1.2, 1.3, 1.4],
+			[2.1, 2.2, 2.3, 2.4],
+			[3.1, 3.2, 3.3, 3.4],
+			[4.1, 4.2, 4.3, 4.4],
+			[5.1, 5.2, 5.3, 5.4]
+		]
+	}`)
+
+	filePath := TempFilesPath + "columns_1.json"
+	err = cm.Write(filePath, content)
+	assert.NoError(t, err)
+
+	rowCount := 0
+	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
+		count := 0
+		for _, data := range fields {
+			assert.Less(t, 0, data.RowNum())
+			if count == 0 {
+				count = data.RowNum()
+			} else {
+				assert.Equal(t, count, data.RowNum())
+			}
+		}
+		rowCount += count
+		return nil
+	}
+
+	// success case
+	importResult := &rootcoordpb.ImportResult{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
+		TaskId:     1,
+		DatanodeId: 1,
+		State:      commonpb.ImportState_ImportStarted,
+		Segments:   make([]int64, 0),
+		AutoIds:    make([]int64, 0),
+		RowCount:   0,
+	}
+	reportFunc := func(res *rootcoordpb.ImportResult) error {
+		return nil
+	}
+	wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc)
+	files := make([]string, 0)
+	files = append(files, filePath)
+
+	wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error {
+		return errors.New("mock error")
+	}
+	err = wrapper.Import(files, false, false)
+	assert.NotNil(t, err)
+	assert.Equal(t, 5, rowCount)
+	assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State)
+}
+
+func Test_ReportImportFailColumnBased_numpy(t *testing.T) {
+	f := dependency.NewDefaultFactory(true)
+	ctx := context.Background()
+	cm, err := f.NewVectorStorageChunkManager(ctx)
+	assert.NoError(t, err)
+	defer cm.RemoveWithPrefix("")
+
+	idAllocator := newIDAllocator(ctx, t)
+
+	content := []byte(`{
+		"field_bool": [true, false, true, true, true],
+		"field_int8": [10, 11, 12, 13, 14],
+		"field_int16": [100, 101, 102, 103, 104],
+		"field_int32": [1000, 1001, 1002, 1003, 1004],
+		"field_int64": [10000, 10001, 10002, 10003, 10004],
+		"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
+		"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
+		"field_string": ["a", "b", "c", "d", "e"]
+	}`)
+
+	files := make([]string, 0)
+
+	filePath := TempFilesPath + "scalar_fields.json"
+	err = cm.Write(filePath, content)
+	assert.NoError(t, err)
+	files = append(files, filePath)
+
+	filePath = TempFilesPath + "field_binary_vector.npy"
+	bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}
+	content, err = CreateNumpyData(bin)
+	assert.Nil(t, err)
+	log.Debug("content", zap.Any("c", content))
+	err = cm.Write(filePath, content)
+	assert.NoError(t, err)
+	files = append(files, filePath)
+
+	filePath = TempFilesPath + "field_float_vector.npy"
+	flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}
+	content, err = CreateNumpyData(flo)
+	assert.Nil(t, err)
+	log.Debug("content", zap.Any("c", content))
+	err = cm.Write(filePath, content)
+	assert.NoError(t, err)
+	files = append(files, filePath)
+
+	rowCount := 0
+	flushFunc := func(fields map[storage.FieldID]storage.FieldData, shardNum int) error {
+		count := 0
+		for _, data := range fields {
+			assert.Less(t, 0, data.RowNum())
+			if count == 0 {
+				count = data.RowNum()
+			} else {
+				assert.Equal(t, count, data.RowNum())
+			}
+		}
+		rowCount += count
+		return nil
+	}
+
+	// success case
+	importResult := &rootcoordpb.ImportResult{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
+		TaskId:     1,
+		DatanodeId: 1,
+		State:      commonpb.ImportState_ImportStarted,
+		Segments:   make([]int64, 0),
+		AutoIds:    make([]int64, 0),
+		RowCount:   0,
+	}
+	reportFunc := func(res *rootcoordpb.ImportResult) error {
+		return nil
+	}
+	schema := sampleSchema()
+	schema.Fields[4].AutoID = true
+	wrapper := NewImportWrapper(ctx, schema, 2, 1, idAllocator, cm, flushFunc, importResult, reportFunc)
+
+	wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error {
+		return errors.New("mock error")
+	}
+	err = wrapper.Import(files, false, false)
+	assert.NotNil(t, err)
+	assert.Equal(t, 5, rowCount)
+	assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State)
+}
diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go
index 3036ed18250d8..93c94fa5314b9 100644
--- a/internal/util/importutil/json_handler.go
+++ b/internal/util/importutil/json_handler.go
@@ -268,8 +268,10 @@ func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream
 		downstream: downstream,
 		rowCounter: 0,
 	}
-	initValidators(collectionSchema, v.validators)
-
+	err := initValidators(collectionSchema, v.validators)
+	if err != nil {
+		log.Error("fail to initialize validator", zap.Error(err))
+	}
 	return v
 }
@@ -332,8 +334,10 @@ func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONCo
 		downstream: downstream,
 		rowCounter: make(map[string]int64),
 	}
-	initValidators(schema, v.validators)
-
+	err := initValidators(schema, v.validators)
+	if err != nil {
+		log.Error("fail to initialize validator", zap.Error(err))
+	}
 	return v
 }
@@ -390,7 +394,7 @@ func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{})
 	return nil
 }
 
-type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardNum int) error
+type ImportFlushFunc func(fields map[storage.FieldID]storage.FieldData, shardID int) error
 
 // row-based json format consumer class
 type JSONRowConsumer struct {
@@ -501,7 +505,10 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
 		callFlushFunc: flushFunc,
 	}
 
-	initValidators(collectionSchema, v.validators)
+	err := initValidators(collectionSchema, v.validators)
+	if err != nil {
+		log.Error("fail to initialize validator", zap.Error(err))
+	}
 
 	v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
 	for i := 0; i < int(shardNum); i++ {
@@ -682,7 +689,10 @@ func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema, flushFun
 		validators:    make(map[storage.FieldID]*Validator),
 		callFlushFunc: flushFunc,
 	}
-	initValidators(collectionSchema, v.validators)
+	err := initValidators(collectionSchema, v.validators)
+	if err != nil {
+		log.Error("fail to initialize validator", zap.Error(err))
+	}
 	v.fieldsData = initSegmentData(collectionSchema)
 	for i := 0; i < len(collectionSchema.Fields); i++ {
diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go
index f91eb04385f31..4497260ab79a4 100644
--- a/internal/util/importutil/numpy_adapter_test.go
+++ b/internal/util/importutil/numpy_adapter_test.go
@@ -182,7 +182,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "bool.npy"
 		data := []bool{true, false, true, false}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -241,7 +242,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "uint8.npy"
 		data := []uint8{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -276,7 +278,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "int8.npy"
 		data := []int8{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -306,7 +309,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "int16.npy"
 		data := []int16{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -336,7 +340,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "int32.npy"
 		data := []int32{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -366,7 +371,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "int64.npy"
 		data := []int64{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -396,7 +402,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "float.npy"
 		data := []float32{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -426,7 +433,8 @@ func Test_Read(t *testing.T) {
 	{
 		filePath := TempFilesPath + "double.npy"
 		data := []float64{1, 2, 3, 4, 5, 6}
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		file, err := os.Open(filePath)
 		assert.Nil(t, err)
diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go
index 62f0ec8405a74..87fbfdda691e8 100644
--- a/internal/util/importutil/numpy_parser_test.go
+++ b/internal/util/importutil/numpy_parser_test.go
@@ -83,7 +83,8 @@ func Test_Validate(t *testing.T) {
 	func() {
 		filePath := TempFilesPath + "scalar_1.npy"
 		data1 := []float64{0, 1, 2, 3, 4, 5}
-		CreateNumpyFile(filePath, data1)
+		err := CreateNumpyFile(filePath, data1)
+		assert.Nil(t, err)
 
 		file1, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -102,7 +103,8 @@ func Test_Validate(t *testing.T) {
 		// data type mismatch
 		filePath = TempFilesPath + "scalar_2.npy"
 		data2 := []int64{0, 1, 2, 3, 4, 5}
-		CreateNumpyFile(filePath, data2)
+		err = CreateNumpyFile(filePath, data2)
+		assert.Nil(t, err)
 
 		file2, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -117,7 +119,8 @@ func Test_Validate(t *testing.T) {
 		// shape mismatch
 		filePath = TempFilesPath + "scalar_2.npy"
 		data3 := [][2]float64{{1, 1}}
-		CreateNumpyFile(filePath, data3)
+		err = CreateNumpyFile(filePath, data3)
+		assert.Nil(t, err)
 
 		file3, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -134,7 +137,8 @@ func Test_Validate(t *testing.T) {
 	func() {
 		filePath := TempFilesPath + "binary_vector_1.npy"
 		data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
-		CreateNumpyFile(filePath, data1)
+		err := CreateNumpyFile(filePath, data1)
+		assert.Nil(t, err)
 
 		file1, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -150,7 +154,8 @@ func Test_Validate(t *testing.T) {
 		// data type mismatch
 		filePath = TempFilesPath + "binary_vector_2.npy"
 		data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
-		CreateNumpyFile(filePath, data2)
+		err = CreateNumpyFile(filePath, data2)
+		assert.Nil(t, err)
 
 		file2, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -165,7 +170,8 @@ func Test_Validate(t *testing.T) {
 		// shape mismatch
 		filePath = TempFilesPath + "binary_vector_3.npy"
 		data3 := []uint8{1, 2, 3}
-		CreateNumpyFile(filePath, data3)
+		err = CreateNumpyFile(filePath, data3)
+		assert.Nil(t, err)
 
 		file3, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -180,7 +186,8 @@ func Test_Validate(t *testing.T) {
 		// shape[1] mismatch
 		filePath = TempFilesPath + "binary_vector_4.npy"
 		data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
-		CreateNumpyFile(filePath, data4)
+		err = CreateNumpyFile(filePath, data4)
+		assert.Nil(t, err)
 
 		file4, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -211,7 +218,8 @@ func Test_Validate(t *testing.T) {
 	func() {
 		filePath := TempFilesPath + "float_vector.npy"
 		data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
-		CreateNumpyFile(filePath, data1)
+		err := CreateNumpyFile(filePath, data1)
+		assert.Nil(t, err)
 
 		file1, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -227,7 +235,8 @@ func Test_Validate(t *testing.T) {
 		// data type mismatch
 		filePath = TempFilesPath + "float_vector_2.npy"
 		data2 := [][4]int32{{0, 1, 2, 3}}
-		CreateNumpyFile(filePath, data2)
+		err = CreateNumpyFile(filePath, data2)
+		assert.Nil(t, err)
 
 		file2, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -242,7 +251,8 @@ func Test_Validate(t *testing.T) {
 		// shape mismatch
 		filePath = TempFilesPath + "float_vector_3.npy"
 		data3 := []float32{1, 2, 3}
-		CreateNumpyFile(filePath, data3)
+		err = CreateNumpyFile(filePath, data3)
+		assert.Nil(t, err)
 
 		file3, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -257,7 +267,8 @@ func Test_Validate(t *testing.T) {
 		// shape[1] mismatch
 		filePath = TempFilesPath + "float_vector_4.npy"
 		data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}}
-		CreateNumpyFile(filePath, data4)
+		err = CreateNumpyFile(filePath, data4)
+		assert.Nil(t, err)
 
 		file4, err := os.Open(filePath)
 		assert.Nil(t, err)
@@ -296,7 +307,8 @@ func Test_Parse(t *testing.T) {
 
 	checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) {
 		filePath := TempFilesPath + fieldName + ".npy"
-		CreateNumpyFile(filePath, data)
+		err := CreateNumpyFile(filePath, data)
+		assert.Nil(t, err)
 
 		func() {
 			file, err := os.Open(filePath)
@@ -510,7 +522,8 @@ func Test_Parse_perf(t *testing.T) {
 	}
 
 	filePath := TempFilesPath + "perf.npy"
-	CreateNumpyFile(filePath, data)
+	err = CreateNumpyFile(filePath, data)
+	assert.Nil(t, err)
 
 	tr.Record("generate large numpy file " + filePath)
diff --git a/internal/util/mock/grpc_datacoord_client.go b/internal/util/mock/grpc_datacoord_client.go
index fd8967c2d166a..d5cd5bc8684b4 100644
--- a/internal/util/mock/grpc_datacoord_client.go
+++ b/internal/util/mock/grpc_datacoord_client.go
@@ -145,6 +145,14 @@ func (m *GrpcDataCoordClient) ReleaseSegmentLock(ctx context.Context, req *datap
 	return &commonpb.Status{}, m.Err
 }
 
-func (m *GrpcDataCoordClient) AddSegment(ctx context.Context, in *datapb.AddSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
+func (m *GrpcDataCoordClient) SaveImportSegment(ctx context.Context, in *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
+	return &commonpb.Status{}, m.Err
+}
+
+func (m *GrpcDataCoordClient) CompleteBulkLoad(context.Context, *datapb.CompleteBulkLoadRequest, ...grpc.CallOption) (*commonpb.Status, error) {
+	return &commonpb.Status{}, m.Err
+}
+
+func (m *GrpcDataCoordClient) UnsetIsImportingState(context.Context, *datapb.UnsetIsImportingStateRequest, ...grpc.CallOption) (*commonpb.Status, error) {
 	return &commonpb.Status{}, m.Err
 }
diff --git a/internal/util/mock/grpc_datanode_client.go b/internal/util/mock/grpc_datanode_client.go
index ece1f4017b73d..dee60fa9640c9 100644
--- a/internal/util/mock/grpc_datanode_client.go
+++ b/internal/util/mock/grpc_datanode_client.go
@@ -73,6 +73,6 @@ func (m *GrpcDataNodeClient) ResendSegmentStats(ctx context.Context, req *datapb
 	return &datapb.ResendSegmentStatsResponse{}, m.Err
 }
 
-func (m *GrpcDataNodeClient) AddSegment(ctx context.Context, in *datapb.AddSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
+func (m *GrpcDataNodeClient) AddImportSegment(ctx context.Context, in *datapb.AddImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
 	return &commonpb.Status{}, m.Err
 }
diff --git a/internal/util/mock/grpc_rootcoord_client.go b/internal/util/mock/grpc_rootcoord_client.go
index 05aab2f16844c..9db75e530a0a0 100644
--- a/internal/util/mock/grpc_rootcoord_client.go
+++ b/internal/util/mock/grpc_rootcoord_client.go
@@ -35,6 +35,14 @@ type GrpcRootCoordClient struct {
 	Err error
 }
 
+func (m *GrpcRootCoordClient) GetImportFailedSegmentIDs(ctx context.Context, in *internalpb.GetImportFailedSegmentIDsRequest, opts ...grpc.CallOption) (*internalpb.GetImportFailedSegmentIDsResponse, error) {
+	return &internalpb.GetImportFailedSegmentIDsResponse{}, m.Err
+}
+
+func (m *GrpcRootCoordClient) CheckSegmentIndexReady(ctx context.Context, in *internalpb.CheckSegmentIndexReadyRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
+	return &commonpb.Status{}, m.Err
+}
+
 func (m *GrpcRootCoordClient) CreateRole(ctx context.Context, in *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
 	return &commonpb.Status{}, m.Err
 }
diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go
index 6d5abe02e32ca..9609b81fd6c8d 100644
--- a/internal/util/paramtable/component_param.go
+++ b/internal/util/paramtable/component_param.go
@@ -390,15 +390,13 @@ type rootCoordConfig struct {
 	Address string
 	Port    int
 
-	DmlChannelNum                   int64
-	MaxPartitionNum                 int64
-	MinSegmentSizeToEnableIndex     int64
-	ImportTaskExpiration            float64
-	ImportTaskRetention             float64
-	ImportSegmentStateCheckInterval float64
-	ImportSegmentStateWaitLimit     float64
-	ImportIndexCheckInterval        float64
-	ImportIndexWaitLimit            float64
+	DmlChannelNum               int64
+	MaxPartitionNum             int64
+	MinSegmentSizeToEnableIndex int64
+	ImportTaskExpiration        float64
+	ImportTaskRetention         float64
+	ImportIndexCheckInterval    float64
+	ImportIndexWaitLimit        float64
 
 	// --- ETCD Path ---
 	ImportTaskSubPath string
@@ -414,11 +412,9 @@ func (p *rootCoordConfig) init(base *BaseTable) {
 	p.MinSegmentSizeToEnableIndex = p.Base.ParseInt64WithDefault("rootCoord.minSegmentSizeToEnableIndex", 1024)
 	p.ImportTaskExpiration = p.Base.ParseFloatWithDefault("rootCoord.importTaskExpiration", 15*60)
 	p.ImportTaskRetention = p.Base.ParseFloatWithDefault("rootCoord.importTaskRetention", 24*60*60)
-	p.ImportSegmentStateCheckInterval = p.Base.ParseFloatWithDefault("rootCoord.importSegmentStateCheckInterval", 10)
-	p.ImportSegmentStateWaitLimit = p.Base.ParseFloatWithDefault("rootCoord.importSegmentStateWaitLimit", 60)
+	p.ImportTaskSubPath = "importtask"
 	p.ImportIndexCheckInterval = p.Base.ParseFloatWithDefault("rootCoord.importIndexCheckInterval", 10)
 	p.ImportIndexWaitLimit = p.Base.ParseFloatWithDefault("rootCoord.importIndexWaitLimit", 10*60)
-	p.ImportTaskSubPath = "importtask"
 }
 
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/internal/util/paramtable/component_param_test.go b/internal/util/paramtable/component_param_test.go
index 213db544de92a..e369bfc50c034 100644
--- a/internal/util/paramtable/component_param_test.go
+++ b/internal/util/paramtable/component_param_test.go
@@ -124,8 +124,6 @@ func TestComponentParam(t *testing.T) {
 		assert.NotEqual(t, Params.MinSegmentSizeToEnableIndex, 0)
 		t.Logf("master MinSegmentSizeToEnableIndex = %d", Params.MinSegmentSizeToEnableIndex)
 		assert.NotEqual(t, Params.ImportTaskExpiration, 0)
-		t.Logf("master ImportTaskExpiration = %f", Params.ImportTaskExpiration)
-		assert.NotEqual(t, Params.ImportTaskRetention, 0)
 		t.Logf("master ImportTaskRetention = %f", Params.ImportTaskRetention)
 		assert.NotEqual(t, Params.ImportIndexCheckInterval, 0)
 		t.Logf("master ImportIndexCheckInterval = %f", Params.ImportIndexCheckInterval)