diff --git a/internal/datanode/compaction/clustering_compactor.go b/internal/datanode/compaction/clustering_compactor.go index 830800809b1ad..7c04ab4b01169 100644 --- a/internal/datanode/compaction/clustering_compactor.go +++ b/internal/datanode/compaction/clustering_compactor.go @@ -526,17 +526,6 @@ func (t *clusteringCompactionTask) mappingSegment( remained int64 = 0 ) - isDeletedValue := func(v *storage.Value) bool { - ts, ok := delta[v.PK.GetValue()] - // insert task and delete task has the same ts when upsert - // here should be < instead of <= - // to avoid the upsert data to be deleted after compact - if ok && uint64(v.Timestamp) < ts { - return true - } - return false - } - mappingStats := &clusteringpb.ClusteringCentroidIdMappingStats{} if t.isVectorClusteringKey { offSetPath := t.segmentIDOffsetMapping[segment.SegmentID] @@ -603,7 +592,7 @@ func (t *clusteringCompactionTask) mappingSegment( offset++ // Filtering deleted entity - if isDeletedValue(v) { + if isDeletedEntity(v, delta) { deleted++ continue } diff --git a/internal/datanode/compaction/compactor_common.go b/internal/datanode/compaction/compactor_common.go index 23e28357c7d8e..f3de55d27b666 100644 --- a/internal/datanode/compaction/compactor_common.go +++ b/internal/datanode/compaction/compactor_common.go @@ -47,6 +47,17 @@ func isExpiredEntity(ttl int64, now, ts typeutil.Timestamp) bool { return expireTime.Before(pnow) } +func isDeletedEntity(v *storage.Value, delta map[interface{}]typeutil.Timestamp) bool { + ts, ok := delta[v.PK.GetValue()] + // insert task and delete task has the same ts when upsert + // here should be < instead of <= + // to avoid the upsert data to be deleted after compact + if ok && uint64(v.Timestamp) < ts { + return true + } + return false +} + func mergeDeltalogs(ctx context.Context, io io.BinlogIO, dpaths map[typeutil.UniqueID][]string) (map[interface{}]typeutil.Timestamp, error) { pk2ts := make(map[interface{}]typeutil.Timestamp) diff --git a/internal/datanode/compaction/merge_sort.go b/internal/datanode/compaction/merge_sort.go index c0c3da817ec4e..b34ac9ae83db7 100644 --- a/internal/datanode/compaction/merge_sort.go +++ b/internal/datanode/compaction/merge_sort.go @@ -40,7 +40,7 @@ func mergeSortMultipleSegments(ctx context.Context, segIDAlloc := allocator.NewLocalAllocator(plan.GetPreAllocatedSegmentIDs().GetBegin(), plan.GetPreAllocatedSegmentIDs().GetEnd()) logIDAlloc := allocator.NewLocalAllocator(plan.GetBeginLogID(), math.MaxInt64) compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc) - mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID, bm25FieldIds) + mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan.GetSchema(), plan.GetChannel(), plan.GetMaxSize(), maxRows, partitionID, collectionID, bm25FieldIds, false) var ( expiredRowCount int64 // the number of expired entities diff --git a/internal/datanode/compaction/mix_compactor.go b/internal/datanode/compaction/mix_compactor.go index 3949ee27a2820..56ca147cfc160 100644 --- a/internal/datanode/compaction/mix_compactor.go +++ b/internal/datanode/compaction/mix_compactor.go @@ -143,7 +143,7 @@ func (t *mixCompactionTask) mergeSplit( segIDAlloc := allocator.NewLocalAllocator(t.plan.GetPreAllocatedSegmentIDs().GetBegin(), t.plan.GetPreAllocatedSegmentIDs().GetEnd()) logIDAlloc := allocator.NewLocalAllocator(t.plan.GetBeginLogID(), math.MaxInt64) compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc) - mWriter := NewMultiSegmentWriter(t.binlogIO, compAlloc, t.plan, t.maxRows, 
t.partitionID, t.collectionID, t.bm25FieldIDs) + mWriter := NewMultiSegmentWriter(t.binlogIO, compAlloc, t.plan.GetSchema(), t.plan.GetChannel(), t.plan.GetMaxSize(), t.maxRows, t.partitionID, t.collectionID, t.bm25FieldIDs, false) deletedRowCount := int64(0) expiredRowCount := int64(0) diff --git a/internal/datanode/compaction/segment_writer.go b/internal/datanode/compaction/segment_writer.go index 6ae6a6aa6a75d..805e6689191ba 100644 --- a/internal/datanode/compaction/segment_writer.go +++ b/internal/datanode/compaction/segment_writer.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "math" + "sync" "github.com/samber/lo" "go.uber.org/atomic" @@ -25,7 +26,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// Not concurrent safe. +// concurrent safe type MultiSegmentWriter struct { binlogIO io.BinlogIO allocator *compactionAlloactor @@ -48,6 +49,12 @@ type MultiSegmentWriter struct { // segID -> fieldID -> binlogs res []*datapb.CompactionSegment + + supportConcurrent bool + writeLock sync.Mutex + + rowCount *atomic.Int64 + // DONOT leave it empty of all segments are deleted, just return a segment with zero meta for datacoord bm25Fields []int64 } @@ -72,7 +79,8 @@ func (alloc *compactionAlloactor) getLogIDAllocator() allocator.Interface { return alloc.logIDAlloc } -func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor, plan *datapb.CompactionPlan, maxRows int64, partitionID, collectionID int64, bm25Fields []int64) *MultiSegmentWriter { +func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor, schema *schemapb.CollectionSchema, channel string, segmentSize int64, + maxRows int64, partitionID, collectionID int64, bm25Fields []int64, supportConcurrent bool) *MultiSegmentWriter { return &MultiSegmentWriter{ binlogIO: binlogIO, allocator: allocator, @@ -81,16 +89,18 @@ func NewMultiSegmentWriter(binlogIO io.BinlogIO, allocator *compactionAlloactor, current: -1, maxRows: maxRows, // For bloomfilter only - segmentSize: plan.GetMaxSize(), + segmentSize: segmentSize, - schema: plan.GetSchema(), + schema: schema, partitionID: partitionID, collectionID: collectionID, - channel: plan.GetChannel(), + channel: channel, - cachedMeta: make(map[typeutil.UniqueID]map[typeutil.UniqueID]*datapb.FieldBinlog), - res: make([]*datapb.CompactionSegment, 0), - bm25Fields: bm25Fields, + cachedMeta: make(map[typeutil.UniqueID]map[typeutil.UniqueID]*datapb.FieldBinlog), + res: make([]*datapb.CompactionSegment, 0), + bm25Fields: bm25Fields, + supportConcurrent: supportConcurrent, + rowCount: atomic.NewInt64(0), } } @@ -184,6 +194,10 @@ func (w *MultiSegmentWriter) getWriter() (*SegmentWriter, error) { } func (w *MultiSegmentWriter) Write(v *storage.Value) error { + if w.supportConcurrent { + w.writeLock.Lock() + defer w.writeLock.Unlock() + } writer, err := w.getWriter() if err != nil { return err @@ -194,20 +208,32 @@ func (w *MultiSegmentWriter) Write(v *storage.Value) error { if _, ok := w.cachedMeta[writer.segmentID]; !ok { w.cachedMeta[writer.segmentID] = make(map[typeutil.UniqueID]*datapb.FieldBinlog) } - - kvs, partialBinlogs, err := serializeWrite(context.TODO(), w.allocator.getLogIDAllocator(), writer) + err = w.flushBinlog(writer) if err != nil { return err } + } - if err := w.binlogIO.Upload(context.TODO(), kvs); err != nil { - return err - } + err = writer.Write(v) + if err != nil { + return err + } + w.rowCount.Inc() + return nil +} - mergeFieldBinlogs(w.cachedMeta[writer.segmentID], partialBinlogs) +func (w *MultiSegmentWriter) flushBinlog(writer 
*SegmentWriter) error { + kvs, partialBinlogs, err := serializeWrite(context.TODO(), w.allocator.getLogIDAllocator(), writer) + if err != nil { + return err } - return writer.Write(v) + if err := w.binlogIO.Upload(context.TODO(), kvs); err != nil { + return err + } + + mergeFieldBinlogs(w.cachedMeta[writer.segmentID], partialBinlogs) + return nil } func (w *MultiSegmentWriter) appendEmptySegment() error { @@ -215,7 +241,6 @@ func (w *MultiSegmentWriter) appendEmptySegment() error { if err != nil { return err } - w.res = append(w.res, &datapb.CompactionSegment{ SegmentID: writer.GetSegmentID(), NumOfRows: 0, @@ -243,6 +268,40 @@ func (w *MultiSegmentWriter) Finish() ([]*datapb.CompactionSegment, error) { return w.res, nil } +func (w *MultiSegmentWriter) Flush() error { + if w.current == -1 { + return nil + } + if w.supportConcurrent { + w.writeLock.Lock() + defer w.writeLock.Unlock() + } + writer, err := w.getWriter() + if err != nil { + return err + } + // init segment fieldBinlogs if it is not exist + if _, ok := w.cachedMeta[writer.segmentID]; !ok { + w.cachedMeta[writer.segmentID] = make(map[typeutil.UniqueID]*datapb.FieldBinlog) + } + return w.flushBinlog(writer) +} + +func (w *MultiSegmentWriter) GetRowNum() int64 { + return w.rowCount.Load() +} + +func (w *MultiSegmentWriter) WrittenMemorySize() uint64 { + if w.current == -1 { + return 0 + } + writer, err := w.getWriter() + if err != nil { + return 0 + } + return writer.WrittenMemorySize() +} + func NewSegmentDeltaWriter(segmentID, partitionID, collectionID int64) *SegmentDeltaWriter { return &SegmentDeltaWriter{ deleteData: &storage.DeleteData{}, diff --git a/internal/datanode/compaction/segment_writer_test.go b/internal/datanode/compaction/segment_writer_test.go index c93062ab214b1..d0cc43711f396 100644 --- a/internal/datanode/compaction/segment_writer_test.go +++ b/internal/datanode/compaction/segment_writer_test.go @@ -17,16 +17,174 @@ package compaction import ( + "math" "testing" + "time" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/flushcommon/io" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) func TestSegmentWriterSuite(t *testing.T) { - suite.Run(t, new(SegmentWriteSuite)) + suite.Run(t, new(SegmentWriterSuite)) +} + +type SegmentWriterSuite struct { + suite.Suite + mockBinlogIO *io.MockBinlogIO + allocator allocator.Interface + meta *etcdpb.CollectionMeta +} + +func (s *SegmentWriterSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (s *SegmentWriterSuite) SetupTest() { + s.mockBinlogIO = io.NewMockBinlogIO(s.T()) + s.meta = genTestCollectionMeta() + s.allocator = allocator.NewLocalAllocator(time.Now().UnixMilli(), math.MaxInt64) + + paramtable.Get().Save(paramtable.Get().CommonCfg.EntityExpirationTTL.Key, "0") +} + +func (s *SegmentWriterSuite) SetupSubTest() { + s.SetupTest() +} + +func (s *SegmentWriterSuite) TearDownTest() { + paramtable.Get().Reset(paramtable.Get().CommonCfg.EntityExpirationTTL.Key) +} + +func (s *SegmentWriterSuite) TestFlushBinlog() { + s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil) + + allocator := NewCompactionAllocator(s.allocator, s.allocator) + 
writer := NewMultiSegmentWriter(s.mockBinlogIO, allocator, s.meta.GetSchema(), "ch-1", 1000000, 1000, 2, 1, nil, true)
+
+	for i := int64(0); i < 500; i++ {
+		err := writer.Write(generateInt64PKEntitiy(i))
+		s.NoError(err)
+	}
+	writer.Flush()
+
+	for i := int64(500); i < 1000; i++ {
+		err := writer.Write(generateInt64PKEntitiy(i))
+		s.NoError(err)
+	}
+
+	compactionSegments, err := writer.Finish()
+	s.NoError(err)
+	s.Equal(2, len(compactionSegments[0].InsertLogs[0].Binlogs))
+	s.Equal(int64(1000), compactionSegments[0].GetNumOfRows())
+}
+
+func (s *SegmentWriterSuite) TestGetRowNum() {
+	s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil)
+
+	allocator := NewCompactionAllocator(s.allocator, s.allocator)
+	writer := NewMultiSegmentWriter(s.mockBinlogIO, allocator, s.meta.GetSchema(), "ch-1", 200000, 1000, 2, 1, nil, true)
+
+	for i := int64(0); i < 500; i++ {
+		err := writer.Write(generateInt64PKEntitiy(i))
+		s.NoError(err)
+	}
+	s.Equal(int64(500), writer.GetRowNum())
+	s.Equal(int64(65000), int64(writer.WrittenMemorySize()))
+
+	writer.Flush()
+	s.Equal(int64(500), writer.GetRowNum())
+	s.Equal(int64(0), int64(writer.WrittenMemorySize()))
+
+	for i := int64(500); i < 1000; i++ {
+		err := writer.Write(generateInt64PKEntitiy(i))
+		s.NoError(err)
+	}
+	s.Equal(int64(1000), writer.GetRowNum())
+	s.Equal(int64(65000), int64(writer.WrittenMemorySize()))
+
+	compactionSegments, err := writer.Finish()
+	s.NoError(err)
+
+	s.Equal(int64(1000), writer.GetRowNum())
+	s.Equal(int64(0), int64(writer.WrittenMemorySize()))
+
+	s.Equal(2, len(compactionSegments[0].InsertLogs[0].Binlogs))
+	s.Equal(int64(1000), compactionSegments[0].GetNumOfRows())
+}
+
+func (s *SegmentWriterSuite) TestWriteMultiSegments() {
+	s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil)
+
+	allocator := NewCompactionAllocator(s.allocator, s.allocator)
+	writer := NewMultiSegmentWriter(s.mockBinlogIO, allocator, s.meta.GetSchema(), "ch-1", 100000, 500, 2, 1, nil, true)
+
+	for i := int64(0); i < 1000; i++ {
+		err := writer.Write(generateInt64PKEntitiy(i))
+		s.NoError(err)
+	}
+
+	compactionSegments, err := writer.Finish()
+	s.NoError(err)
+	s.Equal(2, len(compactionSegments))
+}
+
+func (s *SegmentWriterSuite) TestConcurrentWriteMultiSegments() {
+	s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil)
+
+	allocator := NewCompactionAllocator(s.allocator, s.allocator)
+	writer := NewMultiSegmentWriter(s.mockBinlogIO, allocator, s.meta.GetSchema(), "ch-1", 100000, 500, 2, 1, nil, true)
+
+	pool := conc.NewPool[any](10)
+	futures := make([]*conc.Future[any], 0)
+	for i := 0; i < 10; i++ {
+		j := int64(i * 1000)
+		future := pool.Submit(func() (any, error) {
+			for i := j + int64(0); i < j+1000; i++ {
+				err := writer.Write(generateInt64PKEntitiy(i))
+				if i == j+100 {
+					writer.Flush()
+				}
+				s.NoError(err)
+			}
+			return struct{}{}, nil
+		})
+		futures = append(futures, future)
+	}
+	err := conc.AwaitAll(futures...)
+ s.NoError(err) + + compactionSegments, err := writer.Finish() + s.NoError(err) + + totalRows := lo.Reduce(lo.Map(compactionSegments, func(segment *datapb.CompactionSegment, i int) int64 { + return segment.GetNumOfRows() + }), func(i int64, j int64, x int) int64 { + return i + j + }, 0) + s.Equal(int64(10000), totalRows) + s.Equal(13, len(compactionSegments)) + + s.Equal(1, len(compactionSegments[0].GetField2StatslogPaths())) + s.Equal(1, len(compactionSegments[1].GetField2StatslogPaths())) +} + +func generateInt64PKEntitiy(magic int64) *storage.Value { + return &storage.Value{ + PK: storage.NewInt64PrimaryKey(magic), + Timestamp: int64(tsoutil.ComposeTSByTime(getMilvusBirthday(), 0)), + Value: getRow(magic), + } } type SegmentWriteSuite struct { diff --git a/internal/datanode/compaction/split_cluster_writer.go b/internal/datanode/compaction/split_cluster_writer.go new file mode 100644 index 0000000000000..0f30a5be80ce8 --- /dev/null +++ b/internal/datanode/compaction/split_cluster_writer.go @@ -0,0 +1,277 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package compaction + +import ( + "fmt" + "sort" + "sync" + + "github.com/cockroachdb/errors" + "go.uber.org/atomic" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/flushcommon/io" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/conc" +) + +type SplitClusterWriter struct { + // construct field + binlogIO io.BinlogIO + allocator *compactionAlloactor + collectionID int64 + partitionID int64 + channel string + schema *schemapb.CollectionSchema + segmentMaxRowCount int64 + segmentMaxSize int64 + clusterNum int32 + mappingFunc func(*storage.Value) (string, error) + memoryBufferSize int64 + workerPoolSize int + + // inner field + clusterWriters map[string]*MultiSegmentWriter + writtenRowNum *atomic.Int64 + flushPool *conc.Pool[any] + flushMutex sync.Mutex +} + +func (c *SplitClusterWriter) Write(value *storage.Value) error { + clusterKey, err := c.mappingFunc(value) + if err != nil { + return err + } + // c.clusterLocks.Lock(clusterKey) + // defer c.clusterLocks.Unlock(clusterKey) + _, exist := c.clusterWriters[clusterKey] + if !exist { + return errors.New(fmt.Sprintf("cluster key=%s not exist", clusterKey)) + } + err = c.clusterWriters[clusterKey].Write(value) + if err != nil { + return err + } + c.writtenRowNum.Inc() + return nil +} + +func (c *SplitClusterWriter) Finish() (map[string][]*datapb.CompactionSegment, error) { + resultSegments := make(map[string][]*datapb.CompactionSegment, 0) + for id, writer := range c.clusterWriters { + log.Info("Finish", zap.String("id", id), zap.Int("current", writer.current), zap.Int64("current", writer.GetRowNum())) + segments, err := writer.Finish() + if err != nil { + return nil, err + } + //for _, segment := range segments { + // segment.VshardId = id + //} + resultSegments[id] = segments + } + return resultSegments, nil +} + +func (c *SplitClusterWriter) GetRowNum() int64 { + return c.writtenRowNum.Load() +} + +func (c *SplitClusterWriter) FlushLargest() error { + // only one flushLargest or flushAll should do at the same time + getLock := c.flushMutex.TryLock() + if !getLock { + return nil + } + defer c.flushMutex.Unlock() + currentMemorySize := c.getTotalUsedMemorySize() + if currentMemorySize <= c.getMemoryBufferLowWatermark() { + log.Info("memory low water mark", zap.Int64("memoryBufferSize", c.getTotalUsedMemorySize())) + return nil + } + bufferIDs := make([]string, 0) + bufferRowNums := make([]int64, 0) + for id, writer := range c.clusterWriters { + bufferIDs = append(bufferIDs, id) + // c.clusterLocks.RLock(id) + bufferRowNums = append(bufferRowNums, writer.GetRowNum()) + // c.clusterLocks.RUnlock(id) + } + sort.Slice(bufferIDs, func(i, j int) bool { + return bufferRowNums[i] > bufferRowNums[j] + }) + log.Info("start flushLargestBuffers", zap.Strings("bufferIDs", bufferIDs), zap.Int64("currentMemorySize", currentMemorySize)) + + futures := make([]*conc.Future[any], 0) + for _, bufferId := range bufferIDs { + writer := c.clusterWriters[bufferId] + log.Info("currentMemorySize after flush writer binlog", + zap.Int64("currentMemorySize", currentMemorySize), + zap.String("bufferID", bufferId), + zap.Uint64("writtenMemorySize", writer.WrittenMemorySize()), + zap.Int64("RowNum", writer.GetRowNum())) + future := c.flushPool.Submit(func() (any, error) { + err := writer.Flush() + if err != nil { + return nil, err + } + return struct{}{}, nil + }) + 
futures = append(futures, future) + + if currentMemorySize <= c.getMemoryBufferLowWatermark() { + log.Info("reach memory low water mark", zap.Int64("memoryBufferSize", c.getTotalUsedMemorySize())) + break + } + } + if err := conc.AwaitAll(futures...); err != nil { + return err + } + return nil +} + +func (c *SplitClusterWriter) getTotalUsedMemorySize() int64 { + var totalBufferSize int64 = 0 + for _, writer := range c.clusterWriters { + totalBufferSize = totalBufferSize + int64(writer.WrittenMemorySize()) + } + return totalBufferSize +} + +func (c *SplitClusterWriter) getMemoryBufferLowWatermark() int64 { + return int64(float64(c.memoryBufferSize) * 0.3) +} + +func (c *SplitClusterWriter) getMemoryBufferHighWatermark() int64 { + return int64(float64(c.memoryBufferSize) * 0.7) +} + +// Builder for SplitClusterWriter +type SplitClusterWriterBuilder struct { + binlogIO io.BinlogIO + allocator *compactionAlloactor + collectionID int64 + partitionID int64 + channel string + schema *schemapb.CollectionSchema + segmentMaxRowCount int64 + segmentMaxSize int64 + splitKeys []string + mappingFunc func(*storage.Value) (string, error) + memoryBufferSize int64 + workerPoolSize int +} + +// NewSplitClusterWriterBuilder creates a new builder instance +func NewSplitClusterWriterBuilder() *SplitClusterWriterBuilder { + return &SplitClusterWriterBuilder{} +} + +func (b *SplitClusterWriterBuilder) SetBinlogIO(binlogIO io.BinlogIO) *SplitClusterWriterBuilder { + b.binlogIO = binlogIO + return b +} + +func (b *SplitClusterWriterBuilder) SetAllocator(allocator *compactionAlloactor) *SplitClusterWriterBuilder { + b.allocator = allocator + return b +} + +// SetCollectionID sets the collectionID field +func (b *SplitClusterWriterBuilder) SetCollectionID(collectionID int64) *SplitClusterWriterBuilder { + b.collectionID = collectionID + return b +} + +// SetPartitionID sets the partitionID field +func (b *SplitClusterWriterBuilder) SetPartitionID(partitionID int64) *SplitClusterWriterBuilder { + b.partitionID = partitionID + return b +} + +func (b *SplitClusterWriterBuilder) SetChannel(channel string) *SplitClusterWriterBuilder { + b.channel = channel + return b +} + +// SetSchema sets the schema field +func (b *SplitClusterWriterBuilder) SetSchema(schema *schemapb.CollectionSchema) *SplitClusterWriterBuilder { + b.schema = schema + return b +} + +func (b *SplitClusterWriterBuilder) SetSegmentMaxSize(segmentMaxSize int64) *SplitClusterWriterBuilder { + b.segmentMaxSize = segmentMaxSize + return b +} + +// SetSegmentMaxRowCount sets the segmentMaxRowCount field +func (b *SplitClusterWriterBuilder) SetSegmentMaxRowCount(segmentMaxRowCount int64) *SplitClusterWriterBuilder { + b.segmentMaxRowCount = segmentMaxRowCount + return b +} + +// SetSplitKeys sets the splitKeys field +func (b *SplitClusterWriterBuilder) SetSplitKeys(keys []string) *SplitClusterWriterBuilder { + b.splitKeys = keys + return b +} + +// SetMappingFunc sets the mappingFunc field +func (b *SplitClusterWriterBuilder) SetMappingFunc(mappingFunc func(*storage.Value) (string, error)) *SplitClusterWriterBuilder { + b.mappingFunc = mappingFunc + return b +} + +func (b *SplitClusterWriterBuilder) SetMemoryBufferSize(memoryBufferSize int64) *SplitClusterWriterBuilder { + b.memoryBufferSize = memoryBufferSize + return b +} + +func (b *SplitClusterWriterBuilder) SetWorkerPoolSize(workerPoolSize int) *SplitClusterWriterBuilder { + b.workerPoolSize = workerPoolSize + return b +} + +// Build creates the final SplitClusterWriter instance +func (b 
*SplitClusterWriterBuilder) Build() (*SplitClusterWriter, error) { + writer := &SplitClusterWriter{ + binlogIO: b.binlogIO, + allocator: b.allocator, + collectionID: b.collectionID, + partitionID: b.partitionID, + schema: b.schema, + channel: b.channel, + segmentMaxSize: b.segmentMaxSize, + segmentMaxRowCount: b.segmentMaxRowCount, + clusterWriters: make(map[string]*MultiSegmentWriter, len(b.splitKeys)), + mappingFunc: b.mappingFunc, + workerPoolSize: b.workerPoolSize, + flushPool: conc.NewPool[any](b.workerPoolSize), + memoryBufferSize: b.memoryBufferSize, + writtenRowNum: atomic.NewInt64(0), + } + bm25FieldIds := GetBM25FieldIDs(b.schema) + for _, key := range b.splitKeys { + writer.clusterWriters[key] = NewMultiSegmentWriter(writer.binlogIO, writer.allocator, writer.schema, writer.channel, writer.segmentMaxSize, writer.segmentMaxRowCount, writer.partitionID, writer.collectionID, bm25FieldIds, true) + } + + return writer, nil +} diff --git a/internal/datanode/compaction/split_cluster_writer_test.go b/internal/datanode/compaction/split_cluster_writer_test.go new file mode 100644 index 0000000000000..ae968f65b0272 --- /dev/null +++ b/internal/datanode/compaction/split_cluster_writer_test.go @@ -0,0 +1,246 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
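
Editor's note before the tests, as a reading aid rather than code from this PR: SplitClusterWriter.Write resolves a cluster key through the caller-supplied mappingFunc and forwards the row to that cluster's MultiSegmentWriter. Each per-cluster writer is constructed with supportConcurrent set to true, so Write and Flush on the same writer are serialized by its internal mutex. FlushLargest takes flushMutex with TryLock so only one flush pass runs at a time; it returns early while the buffered size is at or below the 30% low watermark, and otherwise submits per-cluster flushes to the worker pool and waits for them with conc.AwaitAll. A minimal hash-based mappingFunc, in the same spirit as the one used in TestSplitByHash below, might look like this (illustrative sketch only; hashMappingFunc is a hypothetical helper assumed to live in package compaction):

// Illustrative sketch: route each row to one of n cluster keys by hashing its
// primary key with the Hash() method added in this change. The returned key
// format ("0", "1", ...) must match the keys passed to SetSplitKeys.
func hashMappingFunc(n uint32) func(v *storage.Value) (string, error) {
	return func(v *storage.Value) (string, error) {
		h, err := v.PK.Hash()
		if err != nil {
			return "", err
		}
		return fmt.Sprint(h % n), nil
	}
}
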
+ +package compaction + +import ( + "fmt" + "math" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/flushcommon/io" + "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestSplitClusterWriterSuite(t *testing.T) { + suite.Run(t, new(SplitClusterWriterSuite)) +} + +type SplitClusterWriterSuite struct { + suite.Suite + mockBinlogIO *io.MockBinlogIO + allocator allocator.Interface + meta *etcdpb.CollectionMeta +} + +func (s *SplitClusterWriterSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (s *SplitClusterWriterSuite) SetupTest() { + s.mockBinlogIO = io.NewMockBinlogIO(s.T()) + s.meta = genTestCollectionMeta() + s.allocator = allocator.NewLocalAllocator(time.Now().UnixMilli(), math.MaxInt64) + + paramtable.Get().Save(paramtable.Get().CommonCfg.EntityExpirationTTL.Key, "0") +} + +func (s *SplitClusterWriterSuite) SetupSubTest() { + s.SetupTest() +} + +func (s *SplitClusterWriterSuite) TearDownTest() { + paramtable.Get().Reset(paramtable.Get().CommonCfg.EntityExpirationTTL.Key) +} + +func (s *SplitClusterWriterSuite) TestSplitByHash() { + s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil) + + mappingFunc := func(value *storage.Value) (string, error) { + pkHash, err := value.PK.Hash() + if err != nil { + return "", err + } + return fmt.Sprint(pkHash % uint32(2)), nil + } + + splitWriter, err := NewSplitClusterWriterBuilder(). + SetBinlogIO(s.mockBinlogIO). + SetAllocator(&compactionAlloactor{ + segmentAlloc: s.allocator, + logIDAlloc: s.allocator, + }). + SetCollectionID(1). + SetPartitionID(2). + SetSchema(s.meta.Schema). + SetChannel("ch-1"). + SetSegmentMaxSize(1000000). + SetSegmentMaxRowCount(1000). + SetSplitKeys([]string{"0", "1"}). + SetMemoryBufferSize(math.MaxInt64). + SetMappingFunc(mappingFunc). + SetWorkerPoolSize(1). + Build() + s.NoError(err) + + for i := int64(0); i < 1000; i++ { + err := splitWriter.Write(generateInt64PKEntitiy(i)) + s.NoError(err) + } + + s.Equal(int64(500), splitWriter.clusterWriters["1"].GetRowNum()) + s.Equal(int64(500), splitWriter.clusterWriters["0"].GetRowNum()) + + segments, err := splitWriter.Finish() + s.NoError(err) + s.Equal(2, len(segments)) + s.Equal(1, len(segments["0"])) + s.Equal(1, len(segments["0"][0].GetField2StatslogPaths())) + s.Equal(1, len(segments["1"])) + s.Equal(1, len(segments["1"][0].GetField2StatslogPaths())) +} + +func (s *SplitClusterWriterSuite) TestSplitByRange() { + s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil) + + mappingFunc := func(value *storage.Value) (string, error) { + if value.PK.LT(storage.NewInt64PrimaryKey(200)) { + return "[0,200)", nil + } else if value.PK.LT(storage.NewInt64PrimaryKey(400)) { + return "[200,400)", nil + } else if value.PK.LT(storage.NewInt64PrimaryKey(600)) { + return "[400,600)", nil + } else if value.PK.LT(storage.NewInt64PrimaryKey(800)) { + return "[600,800)", nil + } else { + return "[800,1000)", nil + } + } + + splitWriter, err := NewSplitClusterWriterBuilder(). + SetBinlogIO(s.mockBinlogIO). + SetAllocator(&compactionAlloactor{ + segmentAlloc: s.allocator, + logIDAlloc: s.allocator, + }). + SetCollectionID(1). + SetPartitionID(2). + SetSchema(s.meta.Schema). + SetChannel("ch-1"). + SetSegmentMaxSize(1000000). 
+ SetSegmentMaxRowCount(10000). + SetSplitKeys([]string{"[0,200)", "[200,400)", "[400,600)", "[600,800)", "[800,1000)"}). + SetMemoryBufferSize(math.MaxInt64). + SetMappingFunc(mappingFunc). + SetWorkerPoolSize(1). + Build() + s.NoError(err) + + for i := int64(0); i < 1000; i++ { + err := splitWriter.Write(generateInt64PKEntitiy(i)) + s.NoError(err) + } + + s.Equal(5, len(splitWriter.clusterWriters)) + s.Equal(int64(200), splitWriter.clusterWriters["[0,200)"].GetRowNum()) + s.Equal(int64(200), splitWriter.clusterWriters["[200,400)"].GetRowNum()) + s.Equal(int64(200), splitWriter.clusterWriters["[400,600)"].GetRowNum()) + s.Equal(int64(200), splitWriter.clusterWriters["[600,800)"].GetRowNum()) + s.Equal(int64(200), splitWriter.clusterWriters["[800,1000)"].GetRowNum()) + + result, err := splitWriter.Finish() + s.NoError(err) + s.Equal(5, len(result)) + + var totalRows int64 = 0 + for _, segments := range result { + for _, seg := range segments { + totalRows = totalRows + seg.GetNumOfRows() + } + } + s.Equal(int64(1000), totalRows) + + s.Equal(1, len(result["[0,200)"][0].GetField2StatslogPaths())) + s.Equal(1, len(result["[200,400)"][0].GetField2StatslogPaths())) + s.Equal(1, len(result["[400,600)"][0].GetField2StatslogPaths())) + s.Equal(1, len(result["[600,800)"][0].GetField2StatslogPaths())) + s.Equal(1, len(result["[800,1000)"][0].GetField2StatslogPaths())) +} + +func (s *SplitClusterWriterSuite) TestConcurrentSplitByHash() { + s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil) + + mappingFunc := func(value *storage.Value) (string, error) { + pkHash, err := value.PK.Hash() + if err != nil { + return "", err + } + return fmt.Sprint(pkHash % uint32(2)), nil + } + + splitWriter, err := NewSplitClusterWriterBuilder(). + SetBinlogIO(s.mockBinlogIO). + SetAllocator(&compactionAlloactor{ + segmentAlloc: s.allocator, + logIDAlloc: s.allocator, + }). + SetCollectionID(1). + SetPartitionID(2). + SetSchema(s.meta.Schema). + SetChannel("ch-1"). + SetSegmentMaxRowCount(1000). + SetSplitKeys([]string{"0", "1"}). + SetMemoryBufferSize(math.MaxInt64). + SetMappingFunc(mappingFunc). + SetWorkerPoolSize(1). + Build() + s.NoError(err) + + pool := conc.NewPool[any](10) + futures := make([]*conc.Future[any], 0) + for i := 0; i < 10; i++ { + j := int64(i * 1000) + future := pool.Submit(func() (any, error) { + for i := j + int64(0); i < j+1000; i++ { + err := splitWriter.Write(generateInt64PKEntitiy(i)) + if i == j+100 { + splitWriter.FlushLargest() + } + s.NoError(err) + } + return struct{}{}, nil + }) + futures = append(futures, future) + } + err = conc.AwaitAll(futures...) 
+ s.NoError(err) + + for id, buffer := range splitWriter.clusterWriters { + println(id) + println(buffer.GetRowNum()) + } + + result, err := splitWriter.Finish() + s.NoError(err) + + var totalRows int64 = 0 + for _, segments := range result { + for _, seg := range segments { + totalRows = totalRows + seg.GetNumOfRows() + } + } + + s.Equal(1, len(result["0"][0].GetField2StatslogPaths())) + s.Equal(1, len(result["1"][0].GetField2StatslogPaths())) +} diff --git a/internal/storage/primary_key.go b/internal/storage/primary_key.go index 640ee2226a48c..fbbda1bb2ce92 100644 --- a/internal/storage/primary_key.go +++ b/internal/storage/primary_key.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type PrimaryKey interface { @@ -37,6 +38,7 @@ type PrimaryKey interface { GetValue() interface{} Type() schemapb.DataType Size() int64 + Hash() (uint32, error) } type Int64PrimaryKey struct { @@ -158,6 +160,10 @@ func (ip *Int64PrimaryKey) Size() int64 { return 16 } +func (ip *Int64PrimaryKey) Hash() (uint32, error) { + return typeutil.Hash32Int64(ip.Value) +} + type VarCharPrimaryKey struct { Value string } @@ -258,6 +264,10 @@ func (vcp *VarCharPrimaryKey) Size() int64 { return int64(8*len(vcp.Value) + 8) } +func (ip *VarCharPrimaryKey) Hash() (uint32, error) { + return typeutil.HashString2Uint32(ip.Value), nil +} + func GenPrimaryKeyByRawData(data interface{}, pkType schemapb.DataType) (PrimaryKey, error) { var result PrimaryKey switch pkType { diff --git a/internal/storage/primary_key_test.go b/internal/storage/primary_key_test.go index ff70531914ed7..309338f54aeef 100644 --- a/internal/storage/primary_key_test.go +++ b/internal/storage/primary_key_test.go @@ -177,3 +177,15 @@ func TestParsePrimaryKeysAndIDs(t *testing.T) { assert.ElementsMatch(t, c.pks, testPks) } } + +func TestPrimaryKeyHash(t *testing.T) { + pk := NewInt64PrimaryKey(1) + hash, err := pk.Hash() + assert.NoError(t, err) + assert.Equal(t, uint32(1392991556), hash) + + vcharPK := NewVarCharPrimaryKey("1") + vcharHash, err2 := vcharPK.Hash() + assert.NoError(t, err2) + assert.Equal(t, uint32(2212294583), vcharHash) +}
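
Editor's note: within this change the new SplitClusterWriter is only exercised by its unit tests. The sketch below is an assumption about intended use, not part of the PR; scatterRows, its IDs, sizes, and channel name are hypothetical, and the mapping argument could be something like the hashMappingFunc sketched before the tests. It shows the full lifecycle a compaction task in the same package might drive: build one writer for a fixed set of split keys, push every surviving row through Write, call FlushLargest periodically to stay within the memory budget, and collect the per-cluster segments from Finish.

// Sketch under assumptions: lives in package compaction; binlogIO, alloc,
// schema, splitKeys, mapping, and rows are provided by the caller.
func scatterRows(
	binlogIO io.BinlogIO,
	alloc *compactionAlloactor,
	schema *schemapb.CollectionSchema,
	splitKeys []string,
	mapping func(*storage.Value) (string, error),
	rows []*storage.Value,
) (map[string][]*datapb.CompactionSegment, error) {
	writer, err := NewSplitClusterWriterBuilder().
		SetBinlogIO(binlogIO).
		SetAllocator(alloc).
		SetCollectionID(1).     // hypothetical IDs
		SetPartitionID(2).
		SetSchema(schema).
		SetChannel("ch-1").     // hypothetical channel
		SetSegmentMaxSize(64 * 1024 * 1024).      // hypothetical sizes
		SetSegmentMaxRowCount(100000).
		SetSplitKeys(splitKeys).
		SetMappingFunc(mapping).
		SetMemoryBufferSize(128 * 1024 * 1024).
		SetWorkerPoolSize(4).
		Build()
	if err != nil {
		return nil, err
	}

	for _, row := range rows {
		if err := writer.Write(row); err != nil {
			return nil, err
		}
		// FlushLargest returns early while the buffered size is at or below the
		// low watermark, so calling it periodically is cheap.
		if writer.GetRowNum()%10000 == 0 {
			if err := writer.FlushLargest(); err != nil {
				return nil, err
			}
		}
	}

	// Finish seals every per-cluster MultiSegmentWriter and returns the
	// flushed CompactionSegments grouped by cluster key.
	return writer.Finish()
}
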