diff --git a/internal/querynode/result.go b/internal/querynode/result.go
index edcf7dc04fbc7..4410010fdc80f 100644
--- a/internal/querynode/result.go
+++ b/internal/querynode/result.go
@@ -35,6 +35,11 @@ import (
 	"github.com/milvus-io/milvus/internal/util/typeutil"
 )
 
+const (
+	// unlimited disables the row cap in the retrieve-result merge helpers.
+	unlimited int = -1
+)
+
 func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*internalpb.GetStatisticsResponse, error) {
 	mergedResults := map[string]interface{}{
 		"row_count": int64(0),
@@ -234,6 +239,70 @@ func encodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6
 	return
 }
 
+// mergeInternalRetrieveResultsV2 merges retrieveResults into a single result by
+// repeatedly taking the row with the smallest primary key (see selectMinPK) and
+// skipping rows whose primary key has already been emitted. Results that are
+// nil, have no field data, or have no IDs are ignored. When limit is not
+// unlimited, at most limit distinct rows are returned.
+func mergeInternalRetrieveResultsV2(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, limit int) (*internalpb.RetrieveResults, error) {
+	log.Ctx(ctx).Debug("reduceInternelRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
+	var (
+		ret = &internalpb.RetrieveResults{
+			Ids: &schemapb.IDs{},
+		}
+
+		skipDupCnt int64
+		loopEnd    int
+	)
+
+	validRetrieveResults := []*internalpb.RetrieveResults{}
+	for _, r := range retrieveResults {
+		size := typeutil.GetSizeOfIDs(r.GetIds())
+		if r == nil || len(r.GetFieldsData()) == 0 || size == 0 {
+			continue
+		}
+		validRetrieveResults = append(validRetrieveResults, r)
+		loopEnd += size
+	}
+
+	if len(validRetrieveResults) == 0 {
+		return ret, nil
+	}
+
+	if limit != unlimited {
+		loopEnd = limit
+	}
+
+	ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
+	idSet := make(map[interface{}]struct{})
+	cursors := make([]int64, len(validRetrieveResults))
+	// j counts emitted rows only, so skipped duplicates do not use up the limit.
+	for j := 0; j < loopEnd; {
+		sel := selectMinPK(validRetrieveResults, cursors)
+		if sel == -1 {
+			break
+		}
+
+		pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
+		if _, ok := idSet[pk]; !ok {
+			typeutil.AppendPKs(ret.Ids, pk)
+			typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
+			idSet[pk] = struct{}{}
+			j++
+		} else {
+			// primary keys duplicate
+			skipDupCnt++
+		}
+		cursors[sel]++
+	}
+
+	if skipDupCnt > 0 {
+		log.Ctx(ctx).Debug("skip duplicated query result while reducing internal.RetrieveResults", zap.Int64("count", skipDupCnt))
+	}
+
+	return ret, nil
+}
+
 // TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting
 func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
 	var ret *internalpb.RetrieveResults
@@ -284,5 +353,121 @@ func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*intern
 	return ret, nil
 }
 
+// mergeSegcoreRetrieveResultsV2 merges retrieveResults into a single result by
+// repeatedly taking the row with the smallest primary key (see selectMinPK) and
+// skipping rows whose primary key has already been emitted. Results that are
+// nil, have no offsets, or have no IDs are ignored. When limit is not
+// unlimited, at most limit distinct rows are returned.
+func mergeSegcoreRetrieveResultsV2(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, limit int) (*segcorepb.RetrieveResults, error) {
+	log.Ctx(ctx).Debug("reduceSegcoreRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
+	var (
+		ret = &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{},
+		}
+
+		skipDupCnt int64
+		loopEnd    int
+	)
+
+	validRetrieveResults := []*segcorepb.RetrieveResults{}
+	for _, r := range retrieveResults {
+		size := typeutil.GetSizeOfIDs(r.GetIds())
+		if r == nil || len(r.GetOffset()) == 0 || size == 0 {
+			continue
+		}
+		validRetrieveResults = append(validRetrieveResults, r)
+		loopEnd += size
+	}
+
+	if len(validRetrieveResults) == 0 {
+		return ret, nil
+	}
+
+	if limit != unlimited {
+		loopEnd = limit
+	}
+
+	ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData()))
+	idSet := make(map[interface{}]struct{})
+	cursors := make([]int64, len(validRetrieveResults))
+	// j counts emitted rows only, so skipped duplicates do not use up the limit.
+	for j := 0; j < loopEnd; {
+		sel := selectMinPK(validRetrieveResults, cursors)
+		if sel == -1 {
+			break
+		}
+
+		pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
+		if _, ok := idSet[pk]; !ok {
+			typeutil.AppendPKs(ret.Ids, pk)
+			typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
+			idSet[pk] = struct{}{}
+			j++
+		} else {
+			// primary keys duplicate
+			skipDupCnt++
+		}
+		cursors[sel]++
+	}
+
+	if skipDupCnt > 0 {
+		log.Ctx(ctx).Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("count", skipDupCnt))
+	}
+
+	return ret, nil
+}
+
+// ResultWithID is any retrieve result that exposes its primary keys.
+type ResultWithID interface {
+	GetIds() *schemapb.IDs
+}
+
+// Both retrieve-result protobuf types must satisfy ResultWithID.
+var _ ResultWithID = &internalpb.RetrieveResults{}
+var _ ResultWithID = &segcorepb.RetrieveResults{}
+
+// selectMinPK returns the index i whose current row (results[i] at cursors[i])
+// holds the smallest primary key, or -1 when every cursor has reached the end of
+// its result. On ties the lowest index wins. String and int64 keys each keep
+// their own running minimum; keys of any other type are skipped.
+func selectMinPK[T ResultWithID](results []T, cursors []int64) int {
+	var (
+		sel = -1
+
+		// firstInt lets a key equal to math.MaxInt64 still be selected.
+		firstInt       = true
+		minIntPK int64 = math.MaxInt64
+
+		firstStr = true
+		minStrPK = ""
+	)
+
+	for i, cursor := range cursors {
+		if int(cursor) >= typeutil.GetSizeOfIDs(results[i].GetIds()) {
+			continue
+		}
+
+		pkInterface := typeutil.GetPK(results[i].GetIds(), cursor)
+		switch pk := pkInterface.(type) {
+		case string:
+			if firstStr || pk < minStrPK {
+				firstStr = false
+				minStrPK = pk
+				sel = i
+			}
+		case int64:
+			if firstInt || pk < minIntPK {
+				firstInt = false
+				minIntPK = pk
+				sel = i
+			}
+		default:
+			continue
+		}
+	}
+
+	return sel
+}
+
 func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
 	var (
diff --git a/internal/querynode/result_test.go b/internal/querynode/result_test.go
index 4434961073570..5074164ca9786 100644
--- a/internal/querynode/result_test.go
+++ b/internal/querynode/result_test.go
@@ -25,8 +25,179 @@ import (
 	"github.com/milvus-io/milvus/internal/common"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 )
 
+func TestResult_mergeSegcoreRetrieveResults(t *testing.T) {
+	const (
+		Dim                  = 8
+		Int64FieldName       = "Int64Field"
+		FloatVectorFieldName = "FloatVectorField"
+		Int64FieldID         = common.StartOfUserFieldID + 1
+		FloatVectorFieldID   = common.StartOfUserFieldID + 2
+	)
+	Int64Array := []int64{11, 22}
+	FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
+
+	var fieldDataArray1 []*schemapb.FieldData
+	fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
+	fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
+
+	var fieldDataArray2 []*schemapb.FieldData
+	fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
+	fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
+
+	t.Run("test skip dupPK 2", func(t *testing.T) {
+		result1 := &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{0, 1},
+					},
+				},
+			},
+			Offset:     []int64{0, 1},
+			FieldsData: fieldDataArray1,
+		}
+		result2 := &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{0, 1},
+					},
+				},
+			},
+			Offset:     []int64{0, 1},
+			FieldsData: fieldDataArray2,
+		}
+
+		result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, unlimited)
+		assert.NoError(t, err)
+		assert.Equal(t, 2, len(result.GetFieldsData()))
+		assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
+		assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+		assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+
+		// A limit counts distinct rows only, so duplicated PKs must not shrink the output.
+		result, err = mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, 2)
+		assert.NoError(t, err)
+		assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
+	})
+
+	t.Run("test nil results", func(t *testing.T) {
+		ret, err := mergeSegcoreRetrieveResultsV2(context.Background(), nil, unlimited)
+		assert.NoError(t, err)
+		assert.Empty(t, ret.GetIds())
+		assert.Empty(t, ret.GetFieldsData())
+	})
+
+	t.Run("test no offset", func(t *testing.T) {
+		r := &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{0, 1},
+					},
+				},
+			},
+			FieldsData: fieldDataArray1,
+		}
+
+		ret, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r}, unlimited)
+		assert.NoError(t, err)
+		assert.Empty(t, ret.GetIds())
+		assert.Empty(t, ret.GetFieldsData())
+	})
+
+	t.Run("test merge", func(t *testing.T) {
+		r1 := &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{1, 3},
+					},
+				},
+			},
+			Offset:     []int64{0, 1},
+			FieldsData: fieldDataArray1,
+		}
+		r2 := &segcorepb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{2, 4},
+					},
+				},
+			},
+			Offset:     []int64{0, 1},
+			FieldsData: fieldDataArray2,
+		}
+
+		resultFloat := []float32{
+			1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+			1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+			11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0,
+			11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
+
+		t.Run("test limited", func(t *testing.T) {
+			tests := []struct {
+				description string
+				limit       int
+			}{
+				{"limit 1", 1},
+				{"limit 2", 2},
+				{"limit 3", 3},
+				{"limit 4", 4},
+			}
+			resultIDs := []int64{1, 2, 3, 4}
+			resultField0 := []int64{11, 11, 22, 22}
+			for _, test := range tests {
+				t.Run(test.description, func(t *testing.T) {
+					result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, test.limit)
+					assert.Equal(t, 2, len(result.GetFieldsData()))
+					assert.Equal(t, test.limit, len(result.GetIds().GetIntId().GetData()))
+					assert.Equal(t, resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
+					assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+					assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+					assert.NoError(t, err)
+				})
+			}
+		})
+
+		t.Run("test int ID", func(t *testing.T) {
+			result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, unlimited)
+			assert.Equal(t, 2, len(result.GetFieldsData()))
+			assert.Equal(t, []int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
+			assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+			assert.InDeltaSlice(t, resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+			assert.NoError(t, err)
+		})
+
+		t.Run("test string ID", func(t *testing.T) {
+			r1.Ids = &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{
+						Data: []string{"a", "c"},
+					}}}
+
+			r2.Ids = &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{
+						Data: []string{"b", "d"},
+					}}}
+
+			result, err := mergeSegcoreRetrieveResultsV2(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, unlimited)
+			assert.NoError(t, err)
+			assert.Equal(t, 2, len(result.GetFieldsData()))
+			assert.Equal(t, []string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
+			assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+			assert.InDeltaSlice(t, resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+			assert.NoError(t, err)
+		})
+
+	})
+}
+
 func TestResult_mergeInternalRetrieveResults(t *testing.T) {
 	const (
 		Dim                  = 8
@@ -46,37 +217,133 @@ func TestResult_mergeInternalRetrieveResults(t *testing.T) {
 	fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
 	fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
 
-	result1 := &internalpb.RetrieveResults{
-		Ids: &schemapb.IDs{
-			IdField: &schemapb.IDs_IntId{
-				IntId: &schemapb.LongArray{
-					Data: []int64{0, 1},
+	t.Run("test skip dupPK 2", func(t *testing.T) {
+		result1 := &internalpb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{0, 1},
+					},
 				},
 			},
-		},
-		// Offset: []int64{0, 1},
-		FieldsData: fieldDataArray1,
-	}
-	result2 := &internalpb.RetrieveResults{
-		Ids: &schemapb.IDs{
-			IdField: &schemapb.IDs_IntId{
-				IntId: &schemapb.LongArray{
-					Data: []int64{0, 1},
+			FieldsData: fieldDataArray1,
+		}
+		result2 := &internalpb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{0, 1},
+					},
 				},
 			},
-		},
-		// Offset: []int64{0, 1},
-		FieldsData: fieldDataArray2,
-	}
+			FieldsData: fieldDataArray2,
+		}
+
+		result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{result1, result2}, unlimited)
+		assert.NoError(t, err)
+		assert.Equal(t, 2, len(result.GetFieldsData()))
+		assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
+		assert.Equal(t, Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+		assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+
+		// A limit counts distinct rows only, so duplicated PKs must not shrink the output.
+		result, err = mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{result1, result2}, 2)
+		assert.NoError(t, err)
+		assert.Equal(t, []int64{0, 1}, result.GetIds().GetIntId().GetData())
+	})
 
-	result, err := mergeInternalRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result1, result2})
-	assert.NoError(t, err)
-	assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data))
-	assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data))
-	assert.InDeltaSlice(t, FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+	t.Run("test nil results", func(t *testing.T) {
+		ret, err := mergeInternalRetrieveResultsV2(context.Background(), nil, unlimited)
+		assert.NoError(t, err)
+		assert.Empty(t, ret.GetIds())
+		assert.Empty(t, ret.GetFieldsData())
+	})
 
-	_, err = mergeInternalRetrieveResults(context.Background(), nil)
-	assert.NoError(t, err)
+	t.Run("test merge", func(t *testing.T) {
+		r1 := &internalpb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{1, 3},
+					},
+				},
+			},
+			FieldsData: fieldDataArray1,
+		}
+		r2 := &internalpb.RetrieveResults{
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{
+						Data: []int64{2, 4},
+					},
+				},
+			},
+			FieldsData: fieldDataArray2,
+		}
+
+		resultFloat := []float32{
+			1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+			1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+			11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0,
+			11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
+
+		t.Run("test limited", func(t *testing.T) {
+			tests := []struct {
+				description string
+				limit       int
+			}{
+				{"limit 1", 1},
+				{"limit 2", 2},
+				{"limit 3", 3},
+				{"limit 4", 4},
+			}
+			resultIDs := []int64{1, 2, 3, 4}
+			resultField0 := []int64{11, 11, 22, 22}
+			for _, test := range tests {
+				t.Run(test.description, func(t *testing.T) {
+					result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, test.limit)
+					assert.Equal(t, 2, len(result.GetFieldsData()))
+					assert.Equal(t, test.limit, len(result.GetIds().GetIntId().GetData()))
+					assert.Equal(t, resultIDs[0:test.limit], result.GetIds().GetIntId().GetData())
+					assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+					assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+					assert.NoError(t, err)
+				})
+			}
+		})
+
+		t.Run("test int ID", func(t *testing.T) {
+			result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, unlimited)
+			assert.Equal(t, 2, len(result.GetFieldsData()))
+			assert.Equal(t, []int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData())
+			assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+			assert.InDeltaSlice(t, resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+			assert.NoError(t, err)
+		})
+
+		t.Run("test string ID", func(t *testing.T) {
+			r1.Ids = &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{
+						Data: []string{"a", "c"},
+					}}}
+
+			r2.Ids = &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{
+						Data: []string{"b", "d"},
+					}}}
+
+			result, err := mergeInternalRetrieveResultsV2(context.Background(), []*internalpb.RetrieveResults{r1, r2}, unlimited)
+			assert.NoError(t, err)
+			assert.Equal(t, 2, len(result.GetFieldsData()))
+			assert.Equal(t, []string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData())
+			assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
+			assert.InDeltaSlice(t, resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
+			assert.NoError(t, err)
+		})
+
+	})
 }
 
 func TestResult_reduceSearchResultData(t *testing.T) {