Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: use GPU pool for gpu tasks #29678

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions internal/querynodev2/segments/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@
"unsafe"

"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

Expand Down Expand Up @@ -123,6 +127,7 @@
loadType querypb.LoadType
metricType atomic.String
schema atomic.Pointer[schemapb.CollectionSchema]
isGpuIndex bool

refCount *atomic.Uint32
}
Expand All @@ -137,6 +142,11 @@
return c.schema.Load()
}

// IsGpuIndex reports whether this collection has been flagged as using a GPU index.
// It returns the value of the isGpuIndex field unchanged.
func (c *Collection) IsGpuIndex() bool {
	return c.isGpuIndex
}

// GetPartitions returns the partition IDs of the collection.
func (c *Collection) GetPartitions() []int64 {
return c.partitions.Collect()
Expand Down Expand Up @@ -205,13 +215,23 @@

collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))

isGpuIndex := false
if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob, err := proto.Marshal(indexMeta)
if err != nil {
log.Warn("marshal index meta failed", zap.Error(err))
return nil
}
C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))

for _, indexMeta := range indexMeta.GetIndexMetas() {
isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool {
return param.Key == common.IndexTypeKey && indexparamcheck.IsGpuIndex(param.Value)
})
if isGpuIndex {
break

Check warning on line 232 in internal/querynodev2/segments/collection.go

View check run for this annotation

Codecov / codecov/patch

internal/querynodev2/segments/collection.go#L232

Added line #L232 was not covered by tests
}
}
}

coll := &Collection{
Expand All @@ -220,6 +240,7 @@
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
refCount: atomic.NewUint32(0),
isGpuIndex: isGpuIndex,
}
coll.schema.Store(schema)

Expand Down
12 changes: 11 additions & 1 deletion internal/querynodev2/tasks/concurrent_safe_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
receiveChan: make(chan addTaskReq, maxReceiveChanSize),
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
gpuPool: conc.NewPool[any](paramtable.Get().QueryNodeCfg.MaxGpuReadConcurrency.GetAsInt(), conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
Expand All @@ -46,6 +47,7 @@
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]
gpuPool *conc.Pool[any]

// wg is the waitgroup for internal worker goroutine
wg sync.WaitGroup
Expand Down Expand Up @@ -227,7 +229,7 @@
continue
}

s.pool.Submit(func() (any, error) {
s.getPool(t).Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
Expand All @@ -245,6 +247,14 @@
}
}

// getPool picks the worker pool a task is submitted to: s.gpuPool when
// t.IsGpuIndex() reports true, otherwise the regular s.pool.
func (s *scheduler) getPool(t Task) *conc.Pool[any] {
if t.IsGpuIndex() {
return s.gpuPool
}

Check warning on line 253 in internal/querynodev2/tasks/concurrent_safe_scheduler.go

View check run for this annotation

Codecov / codecov/patch

internal/querynodev2/tasks/concurrent_safe_scheduler.go#L252-L253

Added lines #L252 - L253 were not covered by tests

return s.pool
}

// setupExecListener setup the execChan and next task to run.
func (s *scheduler) setupExecListener(lastWaitingTask Task) (Task, int64, chan Task) {
var execChan chan Task
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/mock_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (t *MockTask) Username() string {
return t.username
}

// IsGpuIndex always reports false for the mock task.
func (t *MockTask) IsGpuIndex() bool {
	return false
}

func (t *MockTask) TimeRecorder() *timerecord.TimeRecorder {
return t.tr
}
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/query_stream_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func (t *QueryStreamTask) Username() string {
return t.req.Req.GetUsername()
}

// IsGpuIndex always reports false for a query stream task.
func (t *QueryStreamTask) IsGpuIndex() bool {
	return false
}

// PreExecute the task, only call once.
func (t *QueryStreamTask) PreExecute() error {
return nil
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/query_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func (t *QueryTask) Username() string {
return t.req.Req.GetUsername()
}

// IsGpuIndex always reports false for a query task.
func (t *QueryTask) IsGpuIndex() bool {
	return false
}

// PreExecute the task, only call once.
func (t *QueryTask) PreExecute() error {
// Update task wait time metric before execute
Expand Down
4 changes: 4 additions & 0 deletions internal/querynodev2/tasks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func (t *SearchTask) Username() string {
return t.req.Req.GetUsername()
}

// IsGpuIndex reports whether the task's collection is flagged as using a GPU
// index, by delegating to the collection's own IsGpuIndex method.
func (t *SearchTask) IsGpuIndex() bool {
	coll := t.collection
	return coll.IsGpuIndex()
}

func (t *SearchTask) PreExecute() error {
// Update task wait time metric before execute
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
Expand Down
3 changes: 3 additions & 0 deletions internal/querynodev2/tasks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ type Task interface {
// Return "" if the task does not contain any user info.
Username() string

// Return whether the task would be running on GPU.
IsGpuIndex() bool

// PreExecute the task, only call once.
PreExecute() error

Expand Down
8 changes: 8 additions & 0 deletions pkg/util/indexparamcheck/index_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type IndexType = string

// IndexType definitions
const (
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
Expand All @@ -29,3 +30,10 @@ const (
IndexHNSW IndexType = "HNSW"
IndexDISKANN IndexType = "DISKANN"
)

// IsGpuIndex reports whether indexType is one of the GPU index types:
// IndexGpuBF, IndexRaftIvfFlat, IndexRaftIvfPQ or IndexRaftCagra.
// The comparison is an exact string match.
func IsGpuIndex(indexType IndexType) bool {
	switch indexType {
	case IndexGpuBF, IndexRaftIvfFlat, IndexRaftIvfPQ, IndexRaftCagra:
		return true
	default:
		return false
	}
}
26 changes: 17 additions & 9 deletions pkg/util/paramtable/component_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -1754,15 +1754,16 @@ type queryNodeConfig struct {
// chunk cache
ReadAheadPolicy ParamItem `refreshable:"false"`

GroupEnabled ParamItem `refreshable:"true"`
MaxReceiveChanSize ParamItem `refreshable:"false"`
MaxUnsolvedQueueSize ParamItem `refreshable:"true"`
MaxReadConcurrency ParamItem `refreshable:"true"`
MaxGroupNQ ParamItem `refreshable:"true"`
TopKMergeRatio ParamItem `refreshable:"true"`
CPURatio ParamItem `refreshable:"true"`
MaxTimestampLag ParamItem `refreshable:"true"`
GCEnabled ParamItem `refreshable:"true"`
GroupEnabled ParamItem `refreshable:"true"`
MaxReceiveChanSize ParamItem `refreshable:"false"`
MaxUnsolvedQueueSize ParamItem `refreshable:"true"`
MaxReadConcurrency ParamItem `refreshable:"true"`
MaxGpuReadConcurrency ParamItem `refreshable:"false"`
MaxGroupNQ ParamItem `refreshable:"true"`
TopKMergeRatio ParamItem `refreshable:"true"`
CPURatio ParamItem `refreshable:"true"`
MaxTimestampLag ParamItem `refreshable:"true"`
GCEnabled ParamItem `refreshable:"true"`

GCHelperEnabled ParamItem `refreshable:"false"`
MinimumGOGCConfig ParamItem `refreshable:"false"`
Expand Down Expand Up @@ -1999,6 +2000,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to
}
p.MaxReadConcurrency.Init(base.mgr)

// MaxGpuReadConcurrency is read as an int to size the scheduler's GPU task
// pool. The key must be spelled "maxGpuReadConcurrency": the earlier
// "maGpuReadConcurrency" (missing "x") would never match a setting written
// with the name the field implies, so the default would always be used.
p.MaxGpuReadConcurrency = ParamItem{
	Key:          "queryNode.scheduler.maxGpuReadConcurrency",
	Version:      "2.0.0",
	DefaultValue: "8",
}
p.MaxGpuReadConcurrency.Init(base.mgr)

p.MaxUnsolvedQueueSize = ParamItem{
Key: "queryNode.scheduler.unsolvedQueueSize",
Version: "2.0.0",
Expand Down
Loading