Skip to content

Commit

Permalink
feat: fp32 vector to fp16/bf16 vector conversion for RESTful API
Browse files Browse the repository at this point in the history
RESTful API has 3 handlers. The influenced API are as follows:

- Handler. insert
- HandlerV1. insert/upsert
- HandlerV2. insert/upsert/search

issue: #37448

Signed-off-by: Yinzuo Jiang <yinzuo.jiang@zilliz.com>
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
  • Loading branch information
jiangyinzuo committed Nov 10, 2024
1 parent f42869c commit a900400
Show file tree
Hide file tree
Showing 11 changed files with 732 additions and 261 deletions.
261 changes: 194 additions & 67 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,29 @@ func TestMethodPost(t *testing.T) {
}
}

func validateTestCases(t *testing.T, testEngine *gin.Engine, queryTestCases []requestBodyTestCase, allowInt64 bool) {
for i, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
if allowInt64 {
req.Header.Set(HTTPHeaderAllowInt64, "true")
}
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code, "case %d: ", i, string(testcase.requestBody))
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err, "case %d: ", i)
assert.Equal(t, testcase.errCode, returnBody.Code, "case %d: ", i, string(testcase.requestBody))
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message, "case %d: ", i, string(testcase.requestBody))
}
fmt.Println(w.Body.String())
})
}
}

func TestDML(t *testing.T) {
paramtable.Init()
// disable rate limit
Expand Down Expand Up @@ -1692,23 +1715,7 @@ func TestDML(t *testing.T) {
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`),
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
fmt.Println(w.Body.String())
})
}
validateTestCases(t, testEngine, queryTestCases, false)
}

func TestAllowInt64(t *testing.T) {
Expand Down Expand Up @@ -1736,24 +1743,164 @@ func TestAllowInt64(t *testing.T) {
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
req.Header.Set(HTTPHeaderAllowInt64, "true")
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
fmt.Println(w.Body.String())
})
validateTestCases(t, testEngine, queryTestCases, true)
}

func generateCollectionSchemaWithVectorFields() *schemapb.CollectionSchema {
collSchema := generateCollectionSchema(schemapb.DataType_Int64, false)
binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector)
binaryVectorField.Name = "binaryVector"
float16VectorField := generateVectorFieldSchema(schemapb.DataType_Float16Vector)
float16VectorField.Name = "float16Vector"
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
bfloat16VectorField.Name = "bfloat16Vector"
sparseFloatVectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
sparseFloatVectorField.Name = "sparseFloatVector"
collSchema.Fields = append(collSchema.Fields, binaryVectorField)
collSchema.Fields = append(collSchema.Fields, float16VectorField)
collSchema.Fields = append(collSchema.Fields, bfloat16VectorField)
collSchema.Fields = append(collSchema.Fields, sparseFloatVectorField)
return collSchema
}

func TestFp16Bf16Vectors(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
mp := mocks.NewMockProxy(t)
collSchema := generateCollectionSchemaWithVectorFields()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
for _, path := range []string{InsertAction, UpsertAction} {
queryTestCases = append(queryTestCases,
requestBodyTestCase{
path: path,
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": [3.0],
"bfloat16Vector": [4.4, 442],
"sparseFloatVector": {"1": 0.1, "2": 0.44}
}
]
}`),
errCode: 1804,
errMsg: "fail to deal the insert data, error: []byte size 2 doesn't equal to vector dimension 2 of Float16Vector",
}, requestBodyTestCase{
path: path,
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": [3, 3.0],
"bfloat16Vector": [4.4, 442],
"sparseFloatVector": {"1": 0.1, "2": 0.44}
}
]
}`),
}, requestBodyTestCase{
path: path,
// [3, 3] shouble be converted to [float(3), float(3)]
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": [3, 3],
"bfloat16Vector": [4.4, 442],
"sparseFloatVector": {"1": 0.1, "2": 0.44}
}
]
}`),
}, requestBodyTestCase{
path: path,
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": "AQIDBA==",
"bfloat16Vector": "AQIDBA==",
"sparseFloatVector": {"1": 0.1, "2": 0.44}
}
]
}`),
}, requestBodyTestCase{
path: path,
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": [3, 3.0, 3],
"bfloat16Vector": [4.4, 442, 44],
"sparseFloatVector": {"1": 0.1, "2": 0.44}
}
]
}`),
errMsg: "fail to deal the insert data, error: []byte size 6 doesn't equal to vector dimension 2 of Float16Vector",
errCode: 1804,
}, requestBodyTestCase{
path: path,
requestBody: []byte(
`{
"collectionName": "book",
"data": [
{
"book_id": 0,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": "AQIDBA==",
"bfloat16Vector": [4.4, 442],
"sparseFloatVector": {"1": 0.1, "2": 0.44}
},
{
"book_id": 1,
"word_count": 0,
"book_intro": [0.11825, 0.6],
"binaryVector": "AQ==",
"float16Vector": [3.1, 3.1],
"bfloat16Vector": "AQIDBA==",
"sparseFloatVector": {"3": 1.1, "2": 0.44}
}
]
}`),
})
}
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(len(queryTestCases))
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Times(4)
mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Times(4)
validateTestCases(t, testEngine, queryTestCases, false)
}

func TestSearchV2(t *testing.T) {
Expand Down Expand Up @@ -1788,26 +1935,14 @@ func TestSearchV2(t *testing.T) {
Ids: generateIDs(schemapb.DataType_Int64, 3),
Scores: DefaultScores,
}}, nil).Once()
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
collSchema := generateCollectionSchema(schemapb.DataType_Int64, false)
binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector)
binaryVectorField.Name = "binaryVector"
float16VectorField := generateVectorFieldSchema(schemapb.DataType_Float16Vector)
float16VectorField.Name = "float16Vector"
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
bfloat16VectorField.Name = "bfloat16Vector"
sparseFloatVectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
sparseFloatVectorField.Name = "sparseFloatVector"
collSchema.Fields = append(collSchema.Fields, binaryVectorField)
collSchema.Fields = append(collSchema.Fields, float16VectorField)
collSchema.Fields = append(collSchema.Fields, bfloat16VectorField)
collSchema.Fields = append(collSchema.Fields, sparseFloatVectorField)
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(4)
collSchema := generateCollectionSchemaWithVectorFields()
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(10)
}, nil).Times(11)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
Expand Down Expand Up @@ -1906,6 +2041,15 @@ func TestSearchV2(t *testing.T) {
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.23]], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.43]], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
Expand Down Expand Up @@ -1983,22 +2127,5 @@ func TestSearchV2(t *testing.T) {
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
fmt.Println(w.Body.String())
})
}
validateTestCases(t, testEngine, queryTestCases, false)
}
Loading

0 comments on commit a900400

Please sign in to comment.