Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support unmarshal resultset into orm receiver #827

Merged
merged 3 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []st
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
sch: schema,
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
Expand Down
151 changes: 151 additions & 0 deletions client/results.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package client

import (
"go/ast"
"reflect"

"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

Expand All @@ -9,6 +13,9 @@ import (
// Fields contains the data of `outputFieleds` specified or all columns if non
// Scores is actually the distance between the vector current record contains and the search target vector
type SearchResult struct {
// internal schema for unmarshaling
sch *entity.Schema

ResultCount int // the returning entry count
GroupByValue entity.Column
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Expand Down Expand Up @@ -44,6 +51,66 @@ func (sr *SearchResult) Slice(start, end int) *SearchResult {
return result
}

func (sr *SearchResult) Unmarshal(receiver interface{}) (err error) {
err = sr.Fields.Unmarshal(receiver)
if err != nil {
return err
}
return sr.fillPKEntry(receiver)
}

func (sr *SearchResult) fillPKEntry(receiver interface{}) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
}
}()
rr := reflect.ValueOf(receiver)

if rr.Kind() == reflect.Ptr {
if rr.IsNil() && rr.CanAddr() {
rr.Set(reflect.New(rr.Type().Elem()))
}
rr = rr.Elem()
}

rt := rr.Type()
rv := rr

switch rt.Kind() {
case reflect.Slice:
pkField := sr.sch.PKField()

et := rt.Elem()
for et.Kind() == reflect.Ptr {
et = et.Elem()
}

candidates := parseCandidates(et)
candi, ok := candidates[pkField.Name]
if !ok {
// pk field not found in struct, skip
return nil
}
for i := 0; i < sr.IDs.Len(); i++ {
row := rv.Index(i)
for row.Kind() == reflect.Ptr {
row = row.Elem()
}

val, err := sr.IDs.Get(i)
if err != nil {
return err
}
row.Field(candi).Set(reflect.ValueOf(val))
}
rr.Set(rv)
default:
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind())
}
return nil
}

// ResultSet is an alias type for column slice.
type ResultSet []entity.Column

Expand Down Expand Up @@ -71,3 +138,87 @@ func (rs ResultSet) GetColumn(fieldName string) entity.Column {
}
return nil
}

func (rs ResultSet) Unmarshal(receiver interface{}) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
}
}()
rr := reflect.ValueOf(receiver)

if rr.Kind() == reflect.Ptr {
if rr.IsNil() && rr.CanAddr() {
rr.Set(reflect.New(rr.Type().Elem()))
}
rr = rr.Elem()
}

rt := rr.Type()
rv := rr

switch rt.Kind() {
// TODO maybe support Array and just fill data
// case reflect.Array:
case reflect.Slice:
et := rt.Elem()
if et.Kind() != reflect.Ptr {
return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind())
}
for et.Kind() == reflect.Ptr {
et = et.Elem()
}
for i := 0; i < rs.Len(); i++ {
data := reflect.New(et)
err := rs.fillData(data.Elem(), et, i)
if err != nil {
return err
}
rv = reflect.Append(rv, data)
}
rr.Set(rv)
default:
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind())
}
return nil
}

func parseCandidates(dataType reflect.Type) map[string]int {
result := make(map[string]int)
for i := 0; i < dataType.NumField(); i++ {
f := dataType.Field(i)
// ignore anonymous field for now
if f.Anonymous || !ast.IsExported(f.Name) {
continue
}

name := f.Name
tag := f.Tag.Get(entity.MilvusTag)
tagSettings := entity.ParseTagSetting(tag, entity.MilvusTagSep)
if tagName, has := tagSettings[entity.MilvusTagName]; has {
name = tagName
}

result[name] = i
}
return result
}

func (rs ResultSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error {
m := parseCandidates(dataType)
for i := 0; i < len(rs); i++ {
name := rs[i].Name()
fidx, ok := m[name]
if !ok {
// maybe return error
continue
}
val, err := rs[i].Get(idx)
if err != nil {
return err
}
// TODO check datatype
data.Field(fidx).Set(reflect.ValueOf(val))
}
return nil
}
109 changes: 109 additions & 0 deletions client/results_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package client

import (
"testing"

"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/stretchr/testify/suite"
)

type ResultSetSuite struct {
suite.Suite
}

func (s *ResultSetSuite) TestResultsetUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}

var (
idData = []int64{1, 2, 3}
vectorData = [][]float32{
{0.1, 0.2},
{0.1, 0.2},
{0.1, 0.2},
}
)

rs := ResultSet([]entity.Column{
entity.NewColumnInt64("id", idData),
entity.NewColumnFloatVector("vector", 2, vectorData),
})
err := rs.Unmarshal([]MyData{})
s.Error(err)

receiver := []MyData{}
err = rs.Unmarshal(&receiver)
s.Error(err)

var ptrReceiver []*MyData
err = rs.Unmarshal(&ptrReceiver)
s.NoError(err)

for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
}

var otherReceiver []*OtherData
err = rs.Unmarshal(&otherReceiver)
s.Error(err)
}

func (s *ResultSetSuite) TestSearchResultUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}

var (
idData = []int64{1, 2, 3}
vectorData = [][]float32{
{0.1, 0.2},
{0.1, 0.2},
{0.1, 0.2},
}
)

sr := SearchResult{
sch: entity.NewSchema().
WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)),
IDs: entity.NewColumnInt64("id", idData),
Fields: ResultSet([]entity.Column{
entity.NewColumnFloatVector("vector", 2, vectorData),
}),
}
err := sr.Unmarshal([]MyData{})
s.Error(err)

receiver := []MyData{}
err = sr.Unmarshal(&receiver)
s.Error(err)

var ptrReceiver []*MyData
err = sr.Unmarshal(&ptrReceiver)
s.NoError(err)

for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
}

var otherReceiver []*OtherData
err = sr.Unmarshal(&otherReceiver)
s.Error(err)
}

func TestResults(t *testing.T) {
suite.Run(t, new(ResultSetSuite))
}
Loading
Loading