Skip to content

Commit

Permalink
Support bulkinsert
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
  • Loading branch information
junjiejiangjjj committed Nov 4, 2024
1 parent 4e4ecd0 commit 38a573d
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 93 deletions.
99 changes: 99 additions & 0 deletions internal/datanode/importv2/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/importutilv2"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/common"
Expand Down Expand Up @@ -435,6 +436,104 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() {
s.NoError(err)
}

func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) *conc.Future[struct{}] {
future := conc.Go(func() (struct{}, error) {
return struct{}{}, nil
})
return future
})
ts := function.CreateEmbeddingServer()
defer ts.Close()
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "128"},
},
},
{
FieldID: 101,
Name: "vec",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: 102,
Name: "int64",
DataType: schemapb.DataType_Int64,
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_OpenAIEmbedding,
InputFieldIds: []int64{100},
OutputFieldIds: []int64{101},
Params: []*commonpb.KeyValuePair{
{Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"},
{Key: function.OpenaiApiKeyParamKey, Value: "mock"},
{Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL},
{Key: function.DimParamKey, Value: "4"},
},
},
},
}

var once sync.Once
data, err := testutil.CreateInsertData(schema, s.numRows)
s.NoError(err)
s.reader = importutilv2.NewMockReader(s.T())
s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
var res *storage.InsertData
once.Do(func() {
res = data
})
if res != nil {
return res, nil
}
return nil, io.EOF
})
importReq := &datapb.ImportRequest{
JobID: 10,
TaskID: 11,
CollectionID: 12,
PartitionIDs: []int64{13},
Vchannels: []string{"v0"},
Schema: schema,
Files: []*internalpb.ImportFile{
{
Paths: []string{"dummy.json"},
},
},
Ts: 1000,
IDRange: &datapb.IDRange{
Begin: 0,
End: int64(s.numRows),
},
RequestSegments: []*datapb.ImportRequestSegment{
{
SegmentID: 14,
PartitionID: 13,
Vchannel: "v0",
},
},
}
importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
s.manager.Add(importTask)
err = importTask.(*ImportTask).importFile(s.reader)
s.NoError(err)
}

func TestScheduler(t *testing.T) {
suite.Run(t, new(SchedulerSuite))
}
27 changes: 27 additions & 0 deletions internal/datanode/importv2/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,39 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData) error {
}

func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
if err := RunBm25Function(task, data); err != nil {
return err
}
if err := RunDenseEmbedding(task, data); err != nil {
return err
}
return nil
}

func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
schema := task.GetSchema()
if function.HasFunctions(schema.Functions, []int64{}) {
exec, err := function.NewFunctionExecutor(schema)
if err != nil {
return err
}
if err := exec.ProcessBulkInsert(data); err != nil {
return err
}
}
return nil
}

func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
fns := task.GetSchema().GetFunctions()
for _, fn := range fns {
runner, err := function.NewFunctionRunner(task.GetSchema(), fn)
if err != nil {
return err
}
if runner == nil {
continue
}
inputDatas := make([]any, 0, len(fn.InputFieldIds))
for _, inputFieldID := range fn.InputFieldIds {
inputDatas = append(inputDatas, data.Data[inputFieldID].GetDataRows())
Expand Down
3 changes: 3 additions & 0 deletions internal/flushcommon/pipeline/flow_graph_embedding_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
if err != nil {
return nil, err
}
if functionRunner == nil {
continue
}
node.functionRunners[tf.GetId()] = functionRunner
}
return node, nil
Expand Down
31 changes: 1 addition & 30 deletions internal/proxy/task_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ package proxy

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -16,7 +12,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/models"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
Expand Down Expand Up @@ -319,31 +314,7 @@ func TestMaxInsertSize(t *testing.T) {
}

func TestInsertTask_Function(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req models.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)

var res models.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, models.EmbeddingData{
Object: "embedding",
Embedding: make([]float32, req.Dimensions),
Index: i,
})
}

res.Usage = models.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
ts := function.CreateEmbeddingServer()
defer ts.Close()
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Expand Down
32 changes: 1 addition & 31 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ package proxy

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
Expand All @@ -40,7 +36,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/models"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
Expand Down Expand Up @@ -364,33 +359,8 @@ func TestSearchTask_PreExecute(t *testing.T) {
}

func TestSearchTask_WithFunctions(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req models.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)

var res models.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, models.EmbeddingData{
Object: "embedding",
Embedding: make([]float32, req.Dimensions),
Index: i,
})
}

res.Usage = models.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
ts := function.CreateEmbeddingServer()
defer ts.Close()

collectionName := "TestInsertTask_function"
schema := &schemapb.CollectionSchema{
Name: collectionName,
Expand Down
3 changes: 3 additions & 0 deletions internal/querynodev2/pipeline/embedding_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ func newEmbeddingNode(collectionID int64, channelName string, manager *DataManag
if err != nil {
return nil, err
}
if functionRunner == nil {
continue
}
node.functionRunners = append(node.functionRunners, functionRunner)
}
return node, nil
Expand Down
2 changes: 2 additions & 0 deletions internal/util/function/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio
switch schema.GetType() {
case schemapb.FunctionType_BM25:
return NewBM25FunctionRunner(coll, schema)
case schemapb.FunctionType_OpenAIEmbedding:
return nil, nil
default:
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
}
Expand Down
33 changes: 33 additions & 0 deletions internal/util/function/function_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
)
Expand All @@ -38,6 +39,7 @@ type Runner interface {
MaxBatch() int
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
}

type FunctionExecutor struct {
Expand Down Expand Up @@ -210,3 +212,34 @@ func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) e
return executor.prcessAdvanceSearch(req)
}
}

func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds))
for idx, id := range runner.GetSchema().InputFieldIds {
field, exist := data.Data[id]
if !exist {
return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx])
}
inputs = append(inputs, field)
}

outputs, err := runner.ProcessBulkInsert(inputs)
if err != nil {
return nil, err
}
return outputs, nil
}

func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error {
// Since concurrency has already been used in the outer layer, only a serial logic access model is used here.
for _, runner := range executor.runners {
output, err := executor.processSingleBulkInsert(runner, data)
if err != nil {
return nil
}
for k, v := range output {
data.Data[k] = v
}
}
return nil
}
29 changes: 1 addition & 28 deletions internal/util/function/function_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,34 +128,7 @@ func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]flo
}

func (s *FunctionExecutorSuite) TestExecutor() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req models.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)

var res models.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
embs := s.createEmbedding(req.Input, req.Dimensions)
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, models.EmbeddingData{
Object: "embedding",
Embedding: embs[i],
Index: i,
})
}

res.Usage = models.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)

}))

ts := CreateEmbeddingServer()
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema)
Expand Down
Loading

0 comments on commit 38a573d

Please sign in to comment.