Skip to content

Commit

Permalink
Move TopK check inside parseQueryInfo (milvus-io#18892)
Browse files Browse the repository at this point in the history
Signed-off-by: yangxuan <xuan.yang@zilliz.com>

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
  • Loading branch information
XuanYang-cn authored Aug 30, 2022
1 parent 1527aee commit 867ea63
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
**/cmake-build-release/*
**/cmake_build_release/*
**/cmake_build/*
.cache

internal/core/output/*
internal/core/build/*
Expand Down Expand Up @@ -87,4 +88,4 @@ deployments/docker/*/volumes

# rocksdb
cwrapper_rocksdb_build/
internal/kv/rocksdb/cwrapper/
internal/kv/rocksdb/cwrapper/
12 changes: 7 additions & 5 deletions internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,19 @@ import (
)

const (
AnnsFieldKey = "anns_field"
TopKKey = "topk"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"

InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask"
DropCollectionTaskName = "DropCollectionTask"
SearchTaskName = "SearchTask"
RetrieveTaskName = "RetrieveTask"
QueryTaskName = "QueryTask"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
MetricTypeKey = "metric_type"
SearchParamsKey = "params"
RoundDecimalKey = "round_decimal"
HasCollectionTaskName = "HasCollectionTask"
DescribeCollectionTaskName = "DescribeCollectionTask"
GetCollectionStatisticsTaskName = "GetCollectionStatisticsTask"
Expand Down
28 changes: 8 additions & 20 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
if err != nil {
return nil, errors.New(TopKKey + " not found in search_params")
}
topK, err := strconv.Atoi(topKStr)
topK, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
if err := validateTopK(topK); err != nil {
return nil, fmt.Errorf("invalid limit, %w", err)
}

metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair)
if err != nil {
Expand All @@ -112,7 +115,7 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
if err != nil {
roundDecimalStr = "-1"
}
roundDecimal, err := strconv.Atoi(roundDecimalStr)
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
Expand All @@ -122,10 +125,10 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf
}

return &planpb.QueryInfo{
Topk: int64(topK),
Topk: topK,
MetricType: metricType,
SearchParams: searchParams,
RoundDecimal: int64(roundDecimal),
RoundDecimal: roundDecimal,
}, nil
}

Expand Down Expand Up @@ -242,17 +245,14 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs

t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
}

t.SearchRequest.Topk = queryInfo.GetTopk()
if err := validateTopK(queryInfo.GetTopk()); err != nil {
return err
}
log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.String("plan", plan.String())) // may be very large if large term passed.
Expand Down Expand Up @@ -647,18 +647,6 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
// }
//}

// func printSearchResult(partialSearchResult *internalpb.SearchResults) {
// for i := 0; i < len(partialSearchResult.Hits); i++ {
// testHits := milvuspb.Hits{}
// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
// if err != nil {
// panic(err)
// }
// fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores)
// }
// }

func (t *searchTask) TraceCtx() context.Context {
return t.ctx
}
Expand Down
6 changes: 6 additions & 0 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,11 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
Value: "invalid",
})

spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{
Key: TopKKey,
Value: "65536",
})

spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{
Key: TopKKey,
Value: "10",
Expand Down Expand Up @@ -1727,6 +1732,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
}{
{"No_topk", spNoTopk},
{"Invalid_topk", spInvalidTopk},
{"Invalid_topk_65536", spInvalidTopk65536},
{"No_Metric_type", spNoMetricType},
{"No_search_params", spNoSearchParams},
{"Invalid_round_decimal", spInvalidRoundDecimal},
Expand Down

0 comments on commit 867ea63

Please sign in to comment.