diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 69402f9f1be9e..1cdc27c53f995 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -964,9 +964,13 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN searchParams := generateSearchParams(ctx, c, httpReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}) + if httpReq.GroupByField != "" { + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) + } + if httpReq.GroupByField != "" && httpReq.GroupSize > 0 { + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}) + } searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) body, _ := c.Get(gin.BodyBytesKey) placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) @@ -1064,9 +1068,13 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq {Key: proxy.RankParamsKey, Value: string(bs)}, {Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, {Key: ParamRoundDecimal, Value: "-1"}, - {Key: ParamGroupByField, Value: httpReq.GroupByField}, - {Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}, - {Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}, + } + if httpReq.GroupByField != "" { + req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) + } + if httpReq.GroupByField != "" && httpReq.GroupSize > 0 { + req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}) + req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}) } resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest)) diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 116d8b656a4af..2d873a23dafca 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -370,8 +370,15 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap groupSize = 1 } else { groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64) - if err != nil || groupSize <= 0 { - groupSize = 1 + if err != nil { + ret.err = merr.WrapErrParameterInvalidMsg( + fmt.Sprintf("failed to parse input group size:%s", groupSizeStr)) + return ret + } + if groupSize <= 0 { + ret.err = merr.WrapErrParameterInvalidMsg( + fmt.Sprintf("input group size:%d is negative, failed to do search_groupby", groupSize)) + return ret } } if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index b0b0e77d954d7..9ed4a03f27dec 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2538,7 +2538,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk()) }) - t.Run("check max group size", func(t *testing.T) { + t.Run("check correctness of group size", func(t *testing.T) { normalParam := getValidSearchParams() normalParam = append(normalParam, &commonpb.KeyValuePair{ Key: GroupSizeKey, @@ -2553,14 +2553,26 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { Fields: fields, } searchInfo := parseSearchInfo(normalParam, schema, nil) - assert.Nil(t, searchInfo.planInfo) assert.Error(t, searchInfo.parseError) assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size")) - - resetSearchParamsValue(normalParam, GroupSizeKey, `10`) - searchInfo = parseSearchInfo(normalParam, schema, nil) - assert.NotNil(t, searchInfo.planInfo) - assert.NoError(t, searchInfo.parseError) + { + resetSearchParamsValue(normalParam, GroupSizeKey, `10`) + searchInfo = parseSearchInfo(normalParam, schema, nil) + assert.NoError(t, searchInfo.parseError) + assert.Equal(t, int64(10), searchInfo.planInfo.GroupSize) + } + { + resetSearchParamsValue(normalParam, GroupSizeKey, `-1`) + searchInfo = parseSearchInfo(normalParam, schema, nil) + assert.Error(t, searchInfo.parseError) + assert.True(t, strings.Contains(searchInfo.parseError.Error(), "is negative")) + } + { + resetSearchParamsValue(normalParam, GroupSizeKey, `xxx`) + searchInfo = parseSearchInfo(normalParam, schema, nil) + assert.Error(t, searchInfo.parseError) + assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size")) + } }) }