Skip to content

Commit

Permalink
Guard bad distance/score from knowhere (#21034)
Browse files Browse the repository at this point in the history
/kind bug

issue: #21138
Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>

Signed-off-by: Yuchen Gao <yuchen.gao@zilliz.com>
  • Loading branch information
soothing-rain authored Dec 13, 2022
1 parent fc10c74 commit b01546e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 45 deletions.
11 changes: 7 additions & 4 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,10 +648,13 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
resultDataIdx = sIdx
maxScore = sScore
} else if sScore == maxScore {
sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx)
tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)

if typeutil.ComparePK(sID, tmpID) {
if subSearchIdx == -1 {
// A bad case happens where Knowhere returns distance/score == +/-maxFloat32
// by mistake.
log.Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore))
} else if typeutil.ComparePK(
typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx),
typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) {
subSearchIdx = i
resultDataIdx = sIdx
maxScore = sScore
Expand Down
63 changes: 63 additions & 0 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -1270,6 +1271,68 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
}
})

t.Run("Integer ID with bad score", func(t *testing.T) {
type args struct {
subSearchResultData []*schemapb.SearchResultData
subSearchNqOffset [][]int64
cursors []int64
topk int64
nq int64
}
tests := []struct {
description string
args args

expectedIdx []int
expectedDataIdx []int
}{
{
description: "reduce 2 subSearchResultData",
args: args{
subSearchResultData: []*schemapb.SearchResultData{
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 8, 5, 3, 1},
},
},
},
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
Topks: []int64{2, 2, 2},
},
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{12, 10, 7, 6, 4, 2},
},
},
},
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
Topks: []int64{2, 2, 2},
},
},
subSearchNqOffset: [][]int64{{0, 2, 4}, {0, 2, 4}},
cursors: []int64{0, 0},
topk: 2,
nq: 3,
},
expectedIdx: []int{-1, -1, -1},
expectedDataIdx: []int{-1, -1, -1},
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
for nqNum := int64(0); nqNum < test.args.nq; nqNum++ {
idx, dataIdx := selectHighestScoreIndex(test.args.subSearchResultData, test.args.subSearchNqOffset, test.args.cursors, nqNum)
assert.Equal(t, test.expectedIdx[nqNum], idx)
assert.Equal(t, test.expectedDataIdx[nqNum], int(dataIdx))
}
})
}
})

t.Run("String ID", func(t *testing.T) {
type args struct {
subSearchResultData []*schemapb.SearchResultData
Expand Down
11 changes: 7 additions & 4 deletions internal/querynode/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset
maxDistance = distance
resultDataIdx = idx
} else if distance == maxDistance {
sID := typeutil.GetPK(dataArray[i].GetIds(), idx)
tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)

if typeutil.ComparePK(sID, tmpID) {
if sel == -1 {
// A bad case happens where knowhere returns distance == +/-maxFloat32
// by mistake.
log.Error("a bad distance is found, something is wrong here!", zap.Float32("score", distance))
} else if typeutil.ComparePK(
typeutil.GetPK(dataArray[i].GetIds(), idx),
typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)) {
sel = i
maxDistance = distance
resultDataIdx = idx
Expand Down
128 changes: 91 additions & 37 deletions internal/querynode/result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package querynode

import (
"context"
"math"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -424,51 +425,104 @@ func TestResult_selectSearchResultData_int(t *testing.T) {
nq int64
qi int64
}
tests := []struct {
name string
args args
want int
}{
{
args: args{
dataArray: []*schemapb.SearchResultData{
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 7, 5, 3, 1},
t.Run("Integer ID", func(t *testing.T) {
tests := []struct {
name string
args args
want int
}{
{
args: args{
dataArray: []*schemapb.SearchResultData{
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 7, 5, 3, 1},
},
},
},
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
Topks: []int64{2, 2, 2},
},
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{12, 10, 8, 6, 4, 2},
},
},
},
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
Topks: []int64{2, 2, 2},
},
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
Topks: []int64{2, 2, 2},
},
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{12, 10, 8, 6, 4, 2},
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
offsets: []int64{0, 1},
topk: 2,
nq: 3,
qi: 0,
},
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
}
})
}
})

t.Run("Integer ID with bad score", func(t *testing.T) {
tests := []struct {
name string
args args
want int
}{
{
args: args{
dataArray: []*schemapb.SearchResultData{
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 7, 5, 3, 1},
},
},
},
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
Topks: []int64{2, 2, 2},
},
{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{12, 10, 8, 6, 4, 2},
},
},
},
Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32},
Topks: []int64{2, 2, 2},
},
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
Topks: []int64{2, 2, 2},
},
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
offsets: []int64{0, 1},
topk: 2,
nq: 3,
qi: 0,
},
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
offsets: []int64{0, 1},
topk: 2,
nq: 3,
qi: 0,
want: -1,
},
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
}
})
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
}
})
}
})

}

0 comments on commit b01546e

Please sign in to comment.