From b7ffa8383c78108c49c956163605f4f9473a13ea Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 18 Oct 2024 11:19:24 +0800 Subject: [PATCH] enhance: [GoSDK] write back auto id value to row based input (#36964) Related to #33460 --------- Signed-off-by: Congqi Xia --- client/example/rowbase/main.go | 87 ++++++++++++++++++++++++++++++++++ client/row/data.go | 19 ++++++++ client/write.go | 4 +- client/write_options.go | 28 +++++++++++ 4 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 client/example/rowbase/main.go diff --git a/client/example/rowbase/main.go b/client/example/rowbase/main.go new file mode 100644 index 0000000000000..2b43378bb8071 --- /dev/null +++ b/client/example/rowbase/main.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "log" + "math/rand" + + "github.com/samber/lo" + + milvusclient "github.com/milvus-io/milvus/client/v2" + "github.com/milvus-io/milvus/client/v2/row" +) + +type Data struct { + ID int64 `milvus:"name:id;primary_key;auto_id"` + Vector []float32 `milvus:"name:vector;dim:128"` +} + +const ( + milvusAddr = `localhost:19530` + nEntities, dim = 10, 128 + collectionName = "hello_row_base" + + msgFmt = "==== %s ====\n" + idCol, randomCol, embeddingCol = "id", "random", "vector" + topK = 3 +) + +func main() { + schema, err := row.ParseSchema(&Data{}) + if err != nil { + log.Fatal("failed to parse schema from struct", err.Error()) + } + + for _, field := range schema.Fields { + log.Printf("Field name: %s, FieldType %s, IsPrimaryKey: %t", field.Name, field.DataType, field.PrimaryKey) + } + schema.WithName(collectionName) + + ctx := context.Background() + + log.Printf(msgFmt, "start connecting to Milvus") + c, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + log.Fatal("failed to connect to milvus, err: ", err.Error()) + } + defer c.Close(ctx) + + if has, err := c.HasCollection(ctx, milvusclient.NewHasCollectionOption(collectionName)); err != nil { + log.Fatal("failed to check collection exists or not", err.Error()) + } else if has { + log.Printf("collection %s alread exists, dropping it now\n", collectionName) + c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName)) + } + + err = c.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema)) + if err != nil { + log.Fatal("failed to create collection", err.Error()) + } + + var rows []*Data + for i := 0; i < nEntities; i++ { + vec := make([]float32, 0, dim) + for j := 0; j < dim; j++ { + vec = append(vec, rand.Float32()) + } + rows = append(rows, &Data{ + Vector: vec, + }) + } + + insertResult, err := c.Insert(ctx, milvusclient.NewRowBasedInsertOption(collectionName, lo.Map(rows, func(data *Data, _ int) any { + return data + })...)) + if err != nil { + log.Fatal("failed to insert data: ", err.Error()) + } + log.Println(insertResult.IDs) + for _, row := range rows { + // id shall be written back + log.Println(row.ID) + } + + c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName)) +} diff --git a/client/row/data.go b/client/row/data.go index e774a55b8605e..448509be911e4 100644 --- a/client/row/data.go +++ b/client/row/data.go @@ -269,6 +269,25 @@ func NewArrayColumn(f *entity.Field) column.Column { } } +func SetField(receiver any, fieldName string, value any) error { + candidates, err := reflectValueCandi(reflect.ValueOf(receiver)) + if err != nil { + return err + } + + candidate, ok := candidates[fieldName] + // if field not found, just return + if !ok { + return nil + } + + if candidate.v.CanSet() { + candidate.v.Set(reflect.ValueOf(value)) + } + + return nil +} + type fieldCandi struct { name string v reflect.Value diff --git a/client/write.go b/client/write.go index d358fc0982264..c4dc07a13f515 100644 --- a/client/write.go +++ b/client/write.go @@ -56,7 +56,9 @@ func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions .. return err } - return nil + // write back pks if needed + // pks values shall be written back to struct if receiver field exists + return option.WriteBackPKs(collection.Schema, result.IDs) }) return result, err } diff --git a/client/write_options.go b/client/write_options.go index dba1b864471f2..a122b385246c2 100644 --- a/client/write_options.go +++ b/client/write_options.go @@ -34,6 +34,7 @@ import ( type InsertOption interface { InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) CollectionName() string + WriteBackPKs(schema *entity.Schema, pks column.Column) error } type UpsertOption interface { @@ -52,6 +53,11 @@ type columnBasedDataOption struct { columns []column.Column } +func (opt *columnBasedDataOption) WriteBackPKs(_ *entity.Schema, _ column.Column) error { + // column based data option need not write back pk + return nil +} + func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, columns ...column.Column) ([]*schemapb.FieldData, int, error) { // setup dynamic related var isDynamic := colSchema.EnableDynamicField @@ -296,6 +302,28 @@ func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb }, nil } +func (opt *rowBasedDataOption) WriteBackPKs(sch *entity.Schema, pks column.Column) error { + pkField := sch.PKField() + // not auto id, return + if pkField == nil || !pkField.AutoID { + return nil + } + if len(opt.rows) != pks.Len() { + return errors.New("input row count is not equal to result pk length") + } + + for i, r := range opt.rows { + // index range checked + v, _ := pks.Get(i) + err := row.SetField(r, pkField.Name, v) + if err != nil { + return err + } + } + + return nil +} + type DeleteOption interface { Request() *milvuspb.DeleteRequest }