From 22bb84fa9df4fbb3598ea21f701707293dc4cf2c Mon Sep 17 00:00:00 2001 From: cqy123456 <39671710+cqy123456@users.noreply.github.com> Date: Fri, 5 Jan 2024 15:24:48 +0800 Subject: [PATCH] feat:add new gpu index:GPU_BRUTE_FORCE and limit gpu index metric type (#29590) issue: https://github.com/milvus-io/milvus/issues/29230 this pr do these things: 1. add gpu brute force; 2. limit gpu index only support l2 / ip; Signed-off-by: cqy123456 --- internal/core/src/config/ConfigKnowhere.cpp | 11 +- pkg/util/indexparamcheck/conf_adapter_mgr.go | 3 +- pkg/util/indexparamcheck/constraints.go | 7 +- pkg/util/indexparamcheck/index_type.go | 1 + .../raft_brute_force_checker.go | 22 +++ .../raft_brute_force_checker_test.go | 64 ++++++++ .../indexparamcheck/raft_ivf_flat_checker.go | 23 +++ .../raft_ivf_flat_checker_test.go | 152 ++++++++++++++++++ .../indexparamcheck/raft_ivf_pq_checker.go | 4 +- .../raft_ivf_pq_checker_test.go | 2 +- 10 files changed, 281 insertions(+), 8 deletions(-) create mode 100644 pkg/util/indexparamcheck/raft_brute_force_checker.go create mode 100644 pkg/util/indexparamcheck/raft_brute_force_checker_test.go create mode 100644 pkg/util/indexparamcheck/raft_ivf_flat_checker.go create mode 100644 pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go diff --git a/internal/core/src/config/ConfigKnowhere.cpp b/internal/core/src/config/ConfigKnowhere.cpp index c10cd366b3af7..47fa65ad3ba3d 100644 --- a/internal/core/src/config/ConfigKnowhere.cpp +++ b/internal/core/src/config/ConfigKnowhere.cpp @@ -93,9 +93,16 @@ KnowhereInitGPUMemoryPool(const uint32_t init_size, const uint32_t max_size) { if (init_size == 0 && max_size == 0) { knowhere::KnowhereConfig::SetRaftMemPool(); return; + } else if (init_size > max_size) { + PanicInfo(ConfigInvalid, + "Error Gpu memory pool params: init_size {} can't not large " + "than max_size {}.", + init_size, + max_size); + } else { + knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size}, + size_t{max_size}); } - knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size}, - size_t{max_size}); } int32_t diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/pkg/util/indexparamcheck/conf_adapter_mgr.go index 6a8e2a8407189..ca2d53a3f39b9 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr.go @@ -43,9 +43,10 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro } func (mgr *indexCheckerMgrImpl) registerIndexChecker() { - mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker() + mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker() mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker() mgr.checkers[IndexRaftCagra] = newCagraChecker() + mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker() mgr.checkers[IndexFaissIDMap] = newFlatChecker() mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() diff --git a/pkg/util/indexparamcheck/constraints.go b/pkg/util/indexparamcheck/constraints.go index f3f8d64c6ba35..9d4ffd64986aa 100644 --- a/pkg/util/indexparamcheck/constraints.go +++ b/pkg/util/indexparamcheck/constraints.go @@ -50,9 +50,10 @@ var ( BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE} // const BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD} // const HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const - CagraMetrics = []string{metric.L2} // const - supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const - supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const + RaftMetrics = []string{metric.L2, metric.IP} + CagraMetrics = []string{metric.L2} // const + supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const + supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const ) const ( diff --git a/pkg/util/indexparamcheck/index_type.go b/pkg/util/indexparamcheck/index_type.go index ebef1bc7a699c..4b8291ed9d021 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/pkg/util/indexparamcheck/index_type.go @@ -20,6 +20,7 @@ const ( IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT" IndexRaftIvfPQ IndexType = "GPU_IVF_PQ" IndexRaftCagra IndexType = "GPU_CAGRA" + IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE" IndexFaissIDMap IndexType = "FLAT" // no index is built. IndexFaissIvfFlat IndexType = "IVF_FLAT" IndexFaissIvfPQ IndexType = "IVF_PQ" diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker.go b/pkg/util/indexparamcheck/raft_brute_force_checker.go new file mode 100644 index 0000000000000..38872da7ec773 --- /dev/null +++ b/pkg/util/indexparamcheck/raft_brute_force_checker.go @@ -0,0 +1,22 @@ +package indexparamcheck + +import "fmt" + +type raftBruteForceChecker struct { + floatVectorBaseChecker +} + +// raftBrustForceChecker checks if a Brute_Force index can be built. +func (c raftBruteForceChecker) CheckTrain(params map[string]string) error { + if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil { + return err + } + if !CheckStrByValues(params, Metric, RaftMetrics) { + return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics) + } + return nil +} + +func newRaftBruteForceChecker() IndexChecker { + return &raftBruteForceChecker{} +} diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go b/pkg/util/indexparamcheck/raft_brute_force_checker_test.go new file mode 100644 index 0000000000000..ce037bc4dcb9c --- /dev/null +++ b/pkg/util/indexparamcheck/raft_brute_force_checker_test.go @@ -0,0 +1,64 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/metric" +) + +func Test_raftbfChecker_CheckTrain(t *testing.T) { + p1 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.L2, + } + p2 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.IP, + } + p3 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.COSINE, + } + + p4 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.HAMMING, + } + p5 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.JACCARD, + } + p6 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.SUBSTRUCTURE, + } + p7 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: metric.SUPERSTRUCTURE, + } + cases := []struct { + params map[string]string + errIsNil bool + }{ + {p1, true}, + {p2, true}, + {p3, false}, + {p4, false}, + {p5, false}, + {p6, false}, + {p7, false}, + } + + c := newRaftBruteForceChecker() + for _, test := range cases { + err := c.CheckTrain(test.params) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go b/pkg/util/indexparamcheck/raft_ivf_flat_checker.go new file mode 100644 index 0000000000000..5846e2616825c --- /dev/null +++ b/pkg/util/indexparamcheck/raft_ivf_flat_checker.go @@ -0,0 +1,23 @@ +package indexparamcheck + +import "fmt" + +// raftIVFChecker checks if a RAFT_IVF_Flat index can be built. +type raftIVFFlatChecker struct { + ivfBaseChecker +} + +// CheckTrain checks if ivf-flat index can be built with the specific index parameters. +func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(params); err != nil { + return err + } + if !CheckStrByValues(params, Metric, RaftMetrics) { + return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics) + } + return nil +} + +func newRaftIVFFlatChecker() IndexChecker { + return &raftIVFFlatChecker{} +} diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go b/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go new file mode 100644 index 0000000000000..949dd6c552c45 --- /dev/null +++ b/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -0,0 +1,152 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/metric" +) + +func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) { + validParams := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + p1 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + p2 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.IP, + } + p3 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.COSINE, + } + + p4 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.HAMMING, + } + p5 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.JACCARD, + } + p6 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUBSTRUCTURE, + } + p7 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUPERSTRUCTURE, + } + + cases := []struct { + params map[string]string + errIsNil bool + }{ + {validParams, true}, + {invalidIVFParamsMin(), false}, + {invalidIVFParamsMax(), false}, + {p1, true}, + {p2, true}, + {p3, false}, + {p4, false}, + {p5, false}, + {p6, false}, + {p7, false}, + } + + c := newRaftIVFFlatChecker() + for _, test := range cases { + err := c.CheckTrain(test.params) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} + +func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) { + cases := []struct { + dType schemapb.DataType + errIsNil bool + }{ + { + dType: schemapb.DataType_Bool, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int8, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int16, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int32, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int64, + errIsNil: false, + }, + { + dType: schemapb.DataType_Float, + errIsNil: false, + }, + { + dType: schemapb.DataType_Double, + errIsNil: false, + }, + { + dType: schemapb.DataType_String, + errIsNil: false, + }, + { + dType: schemapb.DataType_VarChar, + errIsNil: false, + }, + { + dType: schemapb.DataType_Array, + errIsNil: false, + }, + { + dType: schemapb.DataType_JSON, + errIsNil: false, + }, + { + dType: schemapb.DataType_FloatVector, + errIsNil: true, + }, + { + dType: schemapb.DataType_BinaryVector, + errIsNil: false, + }, + } + + c := newRaftIVFFlatChecker() + for _, test := range cases { + err := c.CheckValidDataType(test.dType) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go index 65f6d1d1b7503..dd502bccb0128 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go @@ -15,7 +15,9 @@ func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error { if err := c.ivfBaseChecker.CheckTrain(params); err != nil { return err } - + if !CheckStrByValues(params, Metric, RaftMetrics) { + return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics) + } return c.checkPQParams(params) } diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go index f1b743359727f..d21ed166ccb05 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -123,7 +123,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { {validParamsMzero, true}, {p1, true}, {p2, true}, - {p3, true}, + {p3, false}, {p4, false}, {p5, false}, {p6, false},