Skip to content

Commit

Permalink
Enable sort for segcore RetrieveResults (milvus-io#18894)
Browse files Browse the repository at this point in the history
See also: milvus-io#18893

Signed-off-by: yangxuan <xuan.yang@zilliz.com>

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
  • Loading branch information
XuanYang-cn authored Sep 14, 2022
1 parent 65fffa6 commit 086eb92
Show file tree
Hide file tree
Showing 12 changed files with 475 additions and 119 deletions.
16 changes: 16 additions & 0 deletions internal/querynode/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,22 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
},
FieldId: fieldID,
}
case schemapb.DataType_VarChar:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: fieldName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: fieldValue.([]string),
},
},
},
},
FieldId: fieldID,
}

case schemapb.DataType_BinaryVector:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
Expand Down
92 changes: 0 additions & 92 deletions internal/querynode/query_shard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
)

Expand Down Expand Up @@ -143,93 +141,3 @@ func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32, to
Topks: topks,
}
}

func TestReduceSearchResultData(t *testing.T) {
const (
nq = 1
topk = 4
metricType = "L2"
)
t.Run("case1", func(t *testing.T) {
ids := []int64{1, 2, 3, 4}
scores := []float32{-1.0, -2.0, -3.0, -4.0}
topks := []int64{int64(len(ids))}
data1 := genSearchResultData(nq, topk, ids, scores, topks)
data2 := genSearchResultData(nq, topk, ids, scores, topks)
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err)
assert.Equal(t, ids, res.Ids.GetIntId().Data)
assert.Equal(t, scores, res.Scores)
})
t.Run("case2", func(t *testing.T) {
ids1 := []int64{1, 2, 3, 4}
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
topks1 := []int64{int64(len(ids1))}
ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
})
}

// TestMergeInternalRetrieveResults merges two retrieve results that carry the
// same ids ({0, 1}) and the same field data, and checks the size of the merged
// columns. It also checks that a nil input slice is accepted without error.
func TestMergeInternalRetrieveResults(t *testing.T) {
	const (
		dim                  = 8
		int64FieldName       = "Int64Field"
		floatVectorFieldName = "FloatVectorField"
		int64FieldID         = common.StartOfUserFieldID + 1
		floatVectorFieldID   = common.StartOfUserFieldID + 2
	)
	pkValues := []int64{11, 22}
	vectorValues := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}

	// newFields builds one int64 column (2 rows) and one float-vector column
	// (2 rows of dim floats) from the shared fixture values.
	newFields := func() []*schemapb.FieldData {
		return []*schemapb.FieldData{
			genFieldData(int64FieldName, int64FieldID, schemapb.DataType_Int64, pkValues[0:2], 1),
			genFieldData(floatVectorFieldName, floatVectorFieldID, schemapb.DataType_FloatVector, vectorValues[0:16], dim),
		}
	}

	// newResult wraps freshly built fields in a result whose ids are {0, 1}.
	newResult := func(fields []*schemapb.FieldData) *internalpb.RetrieveResults {
		return &internalpb.RetrieveResults{
			Ids: &schemapb.IDs{
				IdField: &schemapb.IDs_IntId{
					IntId: &schemapb.LongArray{
						Data: []int64{0, 1},
					},
				},
			},
			// Offset: []int64{0, 1},
			FieldsData: fields,
		}
	}

	firstFields := newFields()
	secondFields := newFields()
	first := newResult(firstFields)
	second := newResult(secondFields)

	ctx := context.TODO()

	// Both inputs hold the same two ids, so the merged result is expected to
	// keep only two rows per column.
	merged, err := mergeInternalRetrieveResults(ctx, []*internalpb.RetrieveResults{first, second})
	assert.NoError(t, err)
	assert.Equal(t, 2, len(merged.FieldsData[0].GetScalars().GetLongData().Data))
	assert.Equal(t, 2*dim, len(merged.FieldsData[1].GetVectors().GetFloatVector().Data))

	_, err = mergeInternalRetrieveResults(ctx, nil)
	assert.NoError(t, err)
}
14 changes: 10 additions & 4 deletions internal/querynode/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,12 @@ func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*intern
}

func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
var ret *segcorepb.RetrieveResults
var skipDupCnt int64
var idSet = make(map[interface{}]struct{})

var (
ret *segcorepb.RetrieveResults
skipDupCnt int64
idSet = make(map[interface{}]struct{})
)

// merge results and remove duplicates
for _, rr := range retrieveResults {
Expand Down Expand Up @@ -320,7 +323,10 @@ func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
}
}
}
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))

if skipDupCnt > 0 {
log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
}

// not found, return default values indicating not result found
if ret == nil {
Expand Down
87 changes: 87 additions & 0 deletions internal/querynode/result_sorter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package querynode

import (
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)

// byPK wraps a segcore RetrieveResults so that it can be handed to sort.Sort;
// its Len, Swap and Less methods operate on the wrapped result in place.
type byPK struct {
// r is the result being sorted; Len treats a nil r as empty.
r *segcorepb.RetrieveResults
}

// Len returns the number of primary keys held by the wrapped result.
// A nil result, or an id field that is neither int nor string ids, yields 0.
func (s *byPK) Len() int {
	count := 0
	if s.r != nil {
		switch pks := s.r.GetIds().GetIdField().(type) {
		case *schemapb.IDs_IntId:
			count = len(pks.IntId.GetData())
		case *schemapb.IDs_StrId:
			count = len(pks.StrId.GetData())
		}
	}
	return count
}

// Swap exchanges rows i and j in the offsets, the primary keys and every
// field column of the wrapped result, so that all of them stay row-aligned.
func (s *byPK) Swap(i, j int) {
	// Offsets first, then primary keys, then the per-field payloads.
	offsets := s.r.Offset
	offsets[i], offsets[j] = offsets[j], offsets[i]

	typeutil.SwapPK(s.r.GetIds(), i, j)

	columns := s.r.GetFieldsData()
	for k := range columns {
		swapFieldData(columns[k], i, j)
	}
}

// Less delegates the ordering of rows i and j to typeutil.ComparePK on the
// wrapped result's ids.
// NOTE(review): assumes ComparePK reports whether the PK at i orders before
// the PK at j — confirm against typeutil.
func (s *byPK) Less(i, j int) bool {
	ids := s.r.GetIds()
	before := typeutil.ComparePK(ids, i, j)
	return before
}

// swapFieldData exchanges rows i and j of a single field column in place.
//
// Scalar columns swap one element per row. Vector columns swap one contiguous
// chunk per row: dim/8 bytes for binary vectors and dim floats for float
// vectors. Scalar or vector payload kinds not listed below are left untouched.
//
// It panics when the field holds neither scalars nor vectors.
func swapFieldData(field *schemapb.FieldData, i int, j int) {
	switch field.GetField().(type) {
	case *schemapb.FieldData_Scalars:
		switch col := field.GetScalars().GetData().(type) {
		case *schemapb.ScalarField_BoolData:
			col.BoolData.Data[i], col.BoolData.Data[j] = col.BoolData.Data[j], col.BoolData.Data[i]
		case *schemapb.ScalarField_IntData:
			col.IntData.Data[i], col.IntData.Data[j] = col.IntData.Data[j], col.IntData.Data[i]
		case *schemapb.ScalarField_LongData:
			col.LongData.Data[i], col.LongData.Data[j] = col.LongData.Data[j], col.LongData.Data[i]
		case *schemapb.ScalarField_FloatData:
			col.FloatData.Data[i], col.FloatData.Data[j] = col.FloatData.Data[j], col.FloatData.Data[i]
		case *schemapb.ScalarField_DoubleData:
			col.DoubleData.Data[i], col.DoubleData.Data[j] = col.DoubleData.Data[j], col.DoubleData.Data[i]
		case *schemapb.ScalarField_StringData:
			col.StringData.Data[i], col.StringData.Data[j] = col.StringData.Data[j], col.StringData.Data[i]
		}
	case *schemapb.FieldData_Vectors:
		vectors := field.GetVectors()
		dim := int(vectors.GetDim())
		switch col := vectors.GetData().(type) {
		case *schemapb.VectorField_BinaryVector:
			// A binary vector packs 8 dimensions per byte, so dim must be a
			// multiple of 8 and each row occupies dim/8 bytes.
			width := dim / 8
			rowI := col.BinaryVector[i*width : (i+1)*width]
			rowJ := col.BinaryVector[j*width : (j+1)*width]

			for k := 0; k < len(rowI); k++ {
				rowI[k], rowJ[k] = rowJ[k], rowI[k]
			}
		case *schemapb.VectorField_FloatVector:
			// Each row occupies dim consecutive floats.
			rowI := col.FloatVector.Data[i*dim : (i+1)*dim]
			rowJ := col.FloatVector.Data[j*dim : (j+1)*dim]

			for k := 0; k < len(rowI); k++ {
				rowI[k], rowJ[k] = rowJ[k], rowI[k]
			}
		}
	default:
		panic("undefined data type " + field.Type.String())
	}
}
54 changes: 54 additions & 0 deletions internal/querynode/result_sorter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package querynode

import (
"sort"
"testing"

"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"

"github.com/stretchr/testify/assert"
)

func TestResultSorter_ByIntPK(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{5, 4, 3, 2, 9, 8, 7, 6},
}},
},
Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6},
FieldsData: []*schemapb.FieldData{
genFieldData("int64 field", 100, schemapb.DataType_Int64,
[]int64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("double field", 101, schemapb.DataType_Double,
[]float64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("string field", 102, schemapb.DataType_VarChar,
[]string{"5", "4", "3", "2", "9", "8", "7", "6"}, 1),
genFieldData("bool field", 103, schemapb.DataType_Bool,
[]bool{false, true, false, true, false, true, false, true}, 1),
genFieldData("float field", 104, schemapb.DataType_Float,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("int field", 105, schemapb.DataType_Int32,
[]int32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("float vector field", 106, schemapb.DataType_FloatVector,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("binary vector field", 107, schemapb.DataType_BinaryVector,
[]byte{5, 4, 3, 2, 9, 8, 7, 6}, 8),
},
}

sort.Sort(&byPK{result})

assert.Equal(t, []int64{2, 3, 4, 5, 6, 7, 8, 9}, result.GetIds().GetIntId().GetData())
assert.Equal(t, []int64{2, 3, 4, 5, 6, 7, 8, 9}, result.GetOffset())
assert.Equal(t, []int64{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[0].GetScalars().GetLongData().Data)
assert.InDeltaSlice(t, []float64{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[1].GetScalars().GetDoubleData().Data, 10e-10)
assert.Equal(t, []string{"2", "3", "4", "5", "6", "7", "8", "9"}, result.FieldsData[2].GetScalars().GetStringData().Data)
assert.Equal(t, []bool{true, false, true, false, true, false, true, false}, result.FieldsData[3].GetScalars().GetBoolData().Data)
assert.InDeltaSlice(t, []float32{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[4].GetScalars().GetFloatData().Data, 10e-10)
assert.Equal(t, []int32{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[5].GetScalars().GetIntData().Data)
assert.InDeltaSlice(t, []float32{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[6].GetVectors().GetFloatVector().GetData(), 10e-10)
assert.Equal(t, []byte{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[7].GetVectors().GetBinaryVector())
}
Loading

0 comments on commit 086eb92

Please sign in to comment.