Skip to content

Commit

Permalink
enhance: use GPU pool for gpu tasks (milvus-io#29678)
Browse files Browse the repository at this point in the history
- this much improve the performance for GPU index

Signed-off-by: yah01 <yang.cen@zilliz.com>
  • Loading branch information
yah01 committed Jan 8, 2024
1 parent 0c83440 commit 417d456
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 10 deletions.
21 changes: 21 additions & 0 deletions internal/querynodev2/segments/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ import (
"unsafe"

"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

Expand Down Expand Up @@ -123,6 +127,7 @@ type Collection struct {
loadType querypb.LoadType
metricType atomic.String
schema atomic.Pointer[schemapb.CollectionSchema]
isGpuIndex bool

refCount *atomic.Uint32
}
Expand All @@ -137,6 +142,11 @@ func (c *Collection) Schema() *schemapb.CollectionSchema {
return c.schema.Load()
}

// IsGpuIndex returns a boolean value indicating whether the collection is using a GPU index.
func (c *Collection) IsGpuIndex() bool {
return c.isGpuIndex
}

// getPartitionIDs return partitionIDs of collection
func (c *Collection) GetPartitions() []int64 {
return c.partitions.Collect()
Expand Down Expand Up @@ -203,10 +213,20 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM

collection := C.NewCollection(cSchemaBlob)

isGpuIndex := false
if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob := proto.MarshalTextString(indexMeta)
cIndexMetaBlob := C.CString(indexMetaBlob)
C.SetIndexMeta(collection, cIndexMetaBlob)

for _, indexMeta := range indexMeta.GetIndexMetas() {
isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool {
return param.Key == common.IndexTypeKey && indexparamcheck.IsGpuIndex(param.Value)
})
if isGpuIndex {
break
}
}
}

coll := &Collection{
Expand All @@ -215,6 +235,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
refCount: atomic.NewUint32(0),
isGpuIndex: isGpuIndex,
}
coll.schema.Store(schema)

Expand Down
12 changes: 11 additions & 1 deletion internal/querynodev2/tasks/concurrent_safe_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func newScheduler(policy schedulePolicy) Scheduler {
receiveChan: make(chan addTaskReq, maxReceiveChanSize),
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
gpuPool: conc.NewPool[any](paramtable.Get().QueryNodeCfg.MaxGpuReadConcurrency.GetAsInt(), conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
Expand All @@ -46,6 +47,7 @@ type scheduler struct {
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]
gpuPool *conc.Pool[any]

// wg is the waitgroup for internal worker goroutine
wg sync.WaitGroup
Expand Down Expand Up @@ -227,7 +229,7 @@ func (s *scheduler) exec() {
continue
}

s.pool.Submit(func() (any, error) {
s.getPool(t).Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
Expand All @@ -245,6 +247,14 @@ func (s *scheduler) exec() {
}
}

func (s *scheduler) getPool(t Task) *conc.Pool[any] {
if t.IsGpuIndex() {
return s.gpuPool
}

return s.pool
}

// setupExecListener setup the execChan and next task to run.
func (s *scheduler) setupExecListener(lastWaitingTask Task) (Task, int64, chan Task) {
var execChan chan Task
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/mock_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (t *MockTask) Username() string {
return t.username
}

func (t *MockTask) IsGpuIndex() bool {
return false
}

func (t *MockTask) TimeRecorder() *timerecord.TimeRecorder {
return t.tr
}
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/query_stream_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func (t *QueryStreamTask) Username() string {
return t.req.Req.GetUsername()
}

func (t *QueryStreamTask) IsGpuIndex() bool {
return false
}

// PreExecute the task, only call once.
func (t *QueryStreamTask) PreExecute() error {
return nil
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/query_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func (t *QueryTask) Username() string {
return t.req.Req.GetUsername()
}

func (t *QueryTask) IsGpuIndex() bool {
return false
}

// PreExecute the task, only call once.
func (t *QueryTask) PreExecute() error {
// Update task wait time metric before execute
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func (t *SearchTask) Username() string {
return t.req.Req.GetUsername()
}

func (t *SearchTask) IsGpuIndex() bool {
return t.collection.IsGpuIndex()
}

func (t *SearchTask) PreExecute() error {
// Update task wait time metric before execute
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
Expand Down
3 changes: 3 additions & 0 deletions internal/querynodev2/tasks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ type Task interface {
// Return "" if the task do not contain any user info.
Username() string

// Return whether the task would be running on GPU.
IsGpuIndex() bool

// PreExecute the task, only call once.
PreExecute() error

Expand Down
9 changes: 9 additions & 0 deletions pkg/util/indexparamcheck/index_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ type IndexType = string

// IndexType definitions
const (
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
IndexFaissIDMap IndexType = "FLAT" // no index is built.
IndexFaissIvfFlat IndexType = "IVF_FLAT"
IndexFaissIvfPQ IndexType = "IVF_PQ"
Expand All @@ -28,3 +30,10 @@ const (
IndexHNSW IndexType = "HNSW"
IndexDISKANN IndexType = "DISKANN"
)

func IsGpuIndex(indexType IndexType) bool {
return indexType == IndexGpuBF ||
indexType == IndexRaftIvfFlat ||
indexType == IndexRaftIvfPQ ||
indexType == IndexRaftCagra
}
26 changes: 17 additions & 9 deletions pkg/util/paramtable/component_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -1693,15 +1693,16 @@ type queryNodeConfig struct {
// chunk cache
ReadAheadPolicy ParamItem `refreshable:"false"`

GroupEnabled ParamItem `refreshable:"true"`
MaxReceiveChanSize ParamItem `refreshable:"false"`
MaxUnsolvedQueueSize ParamItem `refreshable:"true"`
MaxReadConcurrency ParamItem `refreshable:"true"`
MaxGroupNQ ParamItem `refreshable:"true"`
TopKMergeRatio ParamItem `refreshable:"true"`
CPURatio ParamItem `refreshable:"true"`
MaxTimestampLag ParamItem `refreshable:"true"`
GCEnabled ParamItem `refreshable:"true"`
GroupEnabled ParamItem `refreshable:"true"`
MaxReceiveChanSize ParamItem `refreshable:"false"`
MaxUnsolvedQueueSize ParamItem `refreshable:"true"`
MaxReadConcurrency ParamItem `refreshable:"true"`
MaxGpuReadConcurrency ParamItem `refreshable:"false"`
MaxGroupNQ ParamItem `refreshable:"true"`
TopKMergeRatio ParamItem `refreshable:"true"`
CPURatio ParamItem `refreshable:"true"`
MaxTimestampLag ParamItem `refreshable:"true"`
GCEnabled ParamItem `refreshable:"true"`

GCHelperEnabled ParamItem `refreshable:"false"`
MinimumGOGCConfig ParamItem `refreshable:"false"`
Expand Down Expand Up @@ -1936,6 +1937,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to
}
p.MaxReadConcurrency.Init(base.mgr)

p.MaxGpuReadConcurrency = ParamItem{
Key: "queryNode.scheduler.maGpuReadConcurrency",
Version: "2.0.0",
DefaultValue: "8",
}
p.MaxGpuReadConcurrency.Init(base.mgr)

p.MaxUnsolvedQueueSize = ParamItem{
Key: "queryNode.scheduler.unsolvedQueueSize",
Version: "2.0.0",
Expand Down

0 comments on commit 417d456

Please sign in to comment.