Skip to content

Commit

Permalink
feat:add new gpu index:GPU_BRUTE_FORCE and limit gpu index metric type (
Browse files Browse the repository at this point in the history
#29590)

issue: #29230
this pr do these things:
1. add gpu brute force;
2. limit gpu index only support l2 / ip;

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 authored Jan 5, 2024
1 parent c8db36a commit 22bb84f
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 8 deletions.
11 changes: 9 additions & 2 deletions internal/core/src/config/ConfigKnowhere.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,16 @@ KnowhereInitGPUMemoryPool(const uint32_t init_size, const uint32_t max_size) {
if (init_size == 0 && max_size == 0) {
knowhere::KnowhereConfig::SetRaftMemPool();
return;
} else if (init_size > max_size) {
PanicInfo(ConfigInvalid,
"Error Gpu memory pool params: init_size {} can't not large "
"than max_size {}.",
init_size,
max_size);
} else {
knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size},
size_t{max_size});
}
knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size},
size_t{max_size});
}

int32_t
Expand Down
3 changes: 2 additions & 1 deletion pkg/util/indexparamcheck/conf_adapter_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro
}

func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker()
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
mgr.checkers[IndexRaftCagra] = newCagraChecker()
mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker()
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
Expand Down
7 changes: 4 additions & 3 deletions pkg/util/indexparamcheck/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ var (
BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE} // const
BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD} // const
HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const
CagraMetrics = []string{metric.L2} // const
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
RaftMetrics = []string{metric.L2, metric.IP}
CagraMetrics = []string{metric.L2} // const
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
)

const (
Expand Down
1 change: 1 addition & 0 deletions pkg/util/indexparamcheck/index_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE"
IndexFaissIDMap IndexType = "FLAT" // no index is built.
IndexFaissIvfFlat IndexType = "IVF_FLAT"
IndexFaissIvfPQ IndexType = "IVF_PQ"
Expand Down
22 changes: 22 additions & 0 deletions pkg/util/indexparamcheck/raft_brute_force_checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package indexparamcheck

import "fmt"

type raftBruteForceChecker struct {
floatVectorBaseChecker
}

// raftBrustForceChecker checks if a Brute_Force index can be built.
func (c raftBruteForceChecker) CheckTrain(params map[string]string) error {
if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return nil
}

func newRaftBruteForceChecker() IndexChecker {
return &raftBruteForceChecker{}
}
64 changes: 64 additions & 0 deletions pkg/util/indexparamcheck/raft_brute_force_checker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package indexparamcheck

import (
"strconv"
"testing"

"github.com/stretchr/testify/assert"

"github.com/milvus-io/milvus/pkg/util/metric"
)

func Test_raftbfChecker_CheckTrain(t *testing.T) {
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.COSINE,
}

p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.SUBSTRUCTURE,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{p1, true},
{p2, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
}

c := newRaftBruteForceChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
23 changes: 23 additions & 0 deletions pkg/util/indexparamcheck/raft_ivf_flat_checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package indexparamcheck

import "fmt"

// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
type raftIVFFlatChecker struct {
ivfBaseChecker
}

// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return nil
}

func newRaftIVFFlatChecker() IndexChecker {
return &raftIVFFlatChecker{}
}
152 changes: 152 additions & 0 deletions pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package indexparamcheck

import (
"strconv"
"testing"

"github.com/stretchr/testify/assert"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
)

func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.L2,
}

p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.COSINE,
}

p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.SUBSTRUCTURE,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.SUPERSTRUCTURE,
}

cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{p1, true},
{p2, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
}

c := newRaftIVFFlatChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}

c := newRaftIVFFlatChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
4 changes: 3 additions & 1 deletion pkg/util/indexparamcheck/raft_ivf_pq_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}

if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return c.checkPQParams(params)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
{validParamsMzero, true},
{p1, true},
{p2, true},
{p3, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},
Expand Down

0 comments on commit 22bb84f

Please sign in to comment.